Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot.dialects.dialect import DialectType
  18
  19    DateTruncBinaryTransform = t.Callable[
  20        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  21    ]
  22
# Module logger; registered under the shared "sqlglot" logger name.
logger = logging.getLogger("sqlglot")

# Final means that an expression should not be simplified.
# It is used as a key in `Expression.meta`; `simplify` returns early for nodes carrying it.
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers
# (used by `_simplify_integer_cast` to decide whether a cast around a small literal can be dropped)
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied through `while_changing`, and `remove_where_true` is run on the
    final result.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed on to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped.
            A falsy value (None or 0) disables the check.
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure the depth at the top of a connector chain (parent is not a Connector),
        # so a chain is measured once rather than at each nested connector.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes flagged as FINAL (see the GROUP BY handling below) are left untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are simplified with root=False
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Give the (possibly new) node the original parent; rules below such as
        # simplify_parens inspect `expression.parent`.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        # Only the root call swaps the result into the tree; non-root results are
        # installed by the `exp.replace_children` call of the enclosing level.
        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 136
 137
 138def connector_depth(expression: exp.Expression) -> int:
 139    """
 140    Determine the maximum depth of a tree of Connectors.
 141
 142    For example:
 143        >>> from sqlglot import parse_one
 144        >>> connector_depth(parse_one("a AND b AND c AND d"))
 145        3
 146    """
 147    stack = deque([(expression, 0)])
 148    max_depth = 0
 149
 150    while stack:
 151        expression, depth = stack.pop()
 152
 153        if not isinstance(expression, exp.Connector):
 154            continue
 155
 156        depth += 1
 157        max_depth = max(depth, max_depth)
 158
 159        stack.append((expression.left, depth))
 160        stack.append((expression.right, depth))
 161
 162    return max_depth
 163
 164
 165def catch(*exceptions):
 166    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 167
 168    def decorator(func):
 169        def wrapped(expression, *args, **kwargs):
 170            try:
 171                return func(expression, *args, **kwargs)
 172            except exceptions:
 173                return expression
 174
 175        return wrapped
 176
 177    return decorator
 178
 179
 180def rewrite_between(expression: exp.Expression) -> exp.Expression:
 181    """Rewrite x between y and z to x >= y AND x <= z.
 182
 183    This is done because comparison simplification is only done on lt/lte/gt/gte.
 184    """
 185    if isinstance(expression, exp.Between):
 186        negate = isinstance(expression.parent, exp.Not)
 187
 188        expression = exp.and_(
 189            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 190            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 191            copy=False,
 192        )
 193
 194        if negate:
 195            expression = exp.paren(expression, copy=False)
 196
 197    return expression
 198
 199
# Maps a comparison class to the class of its logical negation.
# `simplify_not` uses it to rewrite e.g. NOT (a < b) as a >= b.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}

# ALL <-> ANY swap that `simplify_not` applies to the right-hand side of a negated comparison
COMPLEMENT_SUBQUERY_PREDICATES = {
    exp.All: exp.Any,
    exp.Any: exp.All,
}
 213
 214
 215def simplify_not(expression):
 216    """
 217    Demorgan's Law
 218    NOT (x OR y) -> NOT x AND NOT y
 219    NOT (x AND y) -> NOT x OR NOT y
 220    """
 221    if isinstance(expression, exp.Not):
 222        this = expression.this
 223        if is_null(this):
 224            return exp.null()
 225        if this.__class__ in COMPLEMENT_COMPARISONS:
 226            right = this.expression
 227            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
 228            if complement_subquery_predicate:
 229                right = complement_subquery_predicate(this=right.this)
 230
 231            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 232        if isinstance(this, exp.Paren):
 233            condition = this.unnest()
 234            if isinstance(condition, exp.And):
 235                return exp.paren(
 236                    exp.or_(
 237                        exp.not_(condition.left, copy=False),
 238                        exp.not_(condition.right, copy=False),
 239                        copy=False,
 240                    )
 241                )
 242            if isinstance(condition, exp.Or):
 243                return exp.paren(
 244                    exp.and_(
 245                        exp.not_(condition.left, copy=False),
 246                        exp.not_(condition.right, copy=False),
 247                        copy=False,
 248                    )
 249                )
 250            if is_null(condition):
 251                return exp.null()
 252        if always_true(this):
 253            return exp.false()
 254        if is_false(this):
 255            return exp.true()
 256        if isinstance(this, exp.Not):
 257            # double negation
 258            # NOT NOT x -> x
 259            return this.this
 260    return expression
 261
 262
 263def flatten(expression):
 264    """
 265    A AND (B AND C) -> A AND B AND C
 266    A OR (B OR C) -> A OR B OR C
 267    """
 268    if isinstance(expression, exp.Connector):
 269        for node in expression.args.values():
 270            child = node.unnest()
 271            if isinstance(child, expression.__class__):
 272                node.replace(child)
 273    return expression
 274
 275
 276def simplify_connectors(expression, root=True):
 277    def _simplify_connectors(expression, left, right):
 278        if isinstance(expression, exp.And):
 279            if is_false(left) or is_false(right):
 280                return exp.false()
 281            if is_zero(left) or is_zero(right):
 282                return exp.false()
 283            if is_null(left) or is_null(right):
 284                return exp.null()
 285            if always_true(left) and always_true(right):
 286                return exp.true()
 287            if always_true(left):
 288                return right
 289            if always_true(right):
 290                return left
 291            return _simplify_comparison(expression, left, right)
 292        elif isinstance(expression, exp.Or):
 293            if always_true(left) or always_true(right):
 294                return exp.true()
 295            if (
 296                (is_null(left) and is_null(right))
 297                or (is_null(left) and always_false(right))
 298                or (always_false(left) and is_null(right))
 299            ):
 300                return exp.null()
 301            if is_false(left):
 302                return right
 303            if is_false(right):
 304                return left
 305            return _simplify_comparison(expression, left, right, or_=True)
 306        elif isinstance(expression, exp.Xor):
 307            if left == right:
 308                return exp.false()
 309
 310    if isinstance(expression, exp.Connector):
 311        return _flat_simplify(expression, _simplify_connectors, root)
 312    return expression
 313
 314
# "Less than" and "greater than" comparison families, used for isinstance checks
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# All comparison node types considered by the comparison/equality simplification rules
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an ordering comparison to its mirror image, i.e. the operator that holds
# once the two operands are swapped (a < b <=> b > a)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# Operands containing any of these are excluded when `_simplify_comparison`
# looks for an operand shared by two comparisons
NONDETERMINISTIC = (exp.Rand, exp.Randn)
# Connectors handled by the complement / absorption rules (XOR is excluded)
AND_OR = (exp.And, exp.Or)
 335
 336
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons that share a non-constant, deterministic operand.

    Args:
        expression: the connector joining `left` and `right`
        left: first operand of the connector
        right: second operand of the connector
        or_: True when the connector is an OR, False for an AND
    Returns:
        The expression that replaces `left <connector> right`, or None when no rule applies.
        `expression` itself is returned when one side has no operand besides the shared one.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Operands present on both sides; only the non-constant, deterministic ones qualify
        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            # l and r are the remaining (non-shared) operands of each comparison
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # Convert both operands into comparable Python values: numbers, strings or dates
            if l.is_number and r.is_number:
                l = l.to_py()
                r = r.to_py()
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None
                # python won't compare date and datetime, but many engines will upcast
                l, r = cast_as_datetime(l), cast_as_datetime(r)

            # Try both orderings, (left, right) and (right, left), so each rule is written once
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two bounds in the same direction: keep one of them, chosen by comparing
                # the values (the choice is inverted for OR)
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    # Bounds that cannot both hold under AND collapse to FALSE
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # An equality either contradicts the other comparison (FALSE)
                        # or makes it redundant (keep the equality)
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 397
 398
 399def remove_complements(expression, root=True):
 400    """
 401    Removing complements.
 402
 403    A AND NOT A -> FALSE
 404    A OR NOT A -> TRUE
 405    """
 406    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 407        ops = set(expression.flatten())
 408        for op in ops:
 409            if isinstance(op, exp.Not) and op.this in ops:
 410                return exp.false() if isinstance(expression, exp.And) else exp.true()
 411
 412    return expression
 413
 414
 415def uniq_sort(expression, root=True):
 416    """
 417    Uniq and sort a connector.
 418
 419    C AND A AND B AND B -> A AND B AND C
 420    """
 421    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 422        flattened = tuple(expression.flatten())
 423
 424        if isinstance(expression, exp.Xor):
 425            result_func = exp.xor
 426            # Do not deduplicate XOR as A XOR A != A if A == True
 427            deduped = None
 428            arr = tuple((gen(e), e) for e in flattened)
 429        else:
 430            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 431            deduped = {gen(e): e for e in flattened}
 432            arr = tuple(deduped.items())
 433
 434        # check if the operands are already sorted, if not sort them
 435        # A AND C AND B -> A AND B AND C
 436        for i, (sql, e) in enumerate(arr[1:]):
 437            if sql < arr[i][0]:
 438                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 439                break
 440        else:
 441            # we didn't have to sort but maybe we need to dedup
 442            if deduped and len(deduped) < len(flattened):
 443                expression = result_func(*deduped.values(), copy=False)
 444
 445    return expression
 446
 447
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND/OR, rewriting its operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are replaced inside the tree and the same `expression` object is returned.
    Redundant operands are replaced by boolean literals rather than removed.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the dual connector: the type the nested operands must have for the laws to apply
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Register the operand under the pair it would form with the negation stripped,
            # remembering which side survives an elimination
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated sub-operand whose positive form is a top-level operand is replaced
            # by the identity of `kind` (TRUE for AND, FALSE for OR)
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand whose sub-operands are a strict superset of another operand's is
            # redundant and replaced by the identity of the outer connector
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
 518
 519
 520def propagate_constants(expression, root=True):
 521    """
 522    Propagate constants for conjunctions in DNF:
 523
 524    SELECT * FROM t WHERE a = b AND b = 5 becomes
 525    SELECT * FROM t WHERE a = 5 AND b = 5
 526
 527    Reference: https://www.sqlite.org/optoverview.html
 528    """
 529
 530    if (
 531        isinstance(expression, exp.And)
 532        and (root or not expression.same_parent)
 533        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 534    ):
 535        constant_mapping = {}
 536        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 537            if isinstance(expr, exp.EQ):
 538                l, r = expr.left, expr.right
 539
 540                # TODO: create a helper that can be used to detect nested literal expressions such
 541                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 542                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 543                    constant_mapping[l] = (id(l), r)
 544
 545        if constant_mapping:
 546            for column in find_all_in_scope(expression, exp.Column):
 547                parent = column.parent
 548                column_id, constant = constant_mapping.get(column) or (None, None)
 549                if (
 550                    column_id is not None
 551                    and id(column) != column_id
 552                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 553                ):
 554                    column.replace(constant.copy())
 555
 556    return expression
 557
 558
# Date/datetime arithmetic functions mapped to the binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison,
# mapped to the operator applied on that other side
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 571
 572
def _is_number(expression: exp.Expression) -> bool:
    """Function form of `expression.is_number`, usable as a predicate (see `simplify_equality`)."""
    return expression.is_number
 575
 576
 577def _is_interval(expression: exp.Expression) -> bool:
 578    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 579
 580
 581@catch(ModuleNotFoundError, UnsupportedUnit)
 582def simplify_equality(expression: exp.Expression) -> exp.Expression:
 583    """
 584    Use the subtraction and addition properties of equality to simplify expressions:
 585
 586        x + 1 = 3 becomes x = 2
 587
 588    There are two binary operations in the above expression: + and =
 589    Here's how we reference all the operands in the code below:
 590
 591          l     r
 592        x + 1 = 3
 593        a   b
 594    """
 595    if isinstance(expression, COMPARISONS):
 596        l, r = expression.left, expression.right
 597
 598        if l.__class__ not in INVERSE_OPS:
 599            return expression
 600
 601        if r.is_number:
 602            a_predicate = _is_number
 603            b_predicate = _is_number
 604        elif _is_date_literal(r):
 605            a_predicate = _is_date_literal
 606            b_predicate = _is_interval
 607        else:
 608            return expression
 609
 610        if l.__class__ in INVERSE_DATE_OPS:
 611            l = t.cast(exp.IntervalOp, l)
 612            a = l.this
 613            b = l.interval()
 614        else:
 615            l = t.cast(exp.Binary, l)
 616            a, b = l.left, l.right
 617
 618        if not a_predicate(a) and b_predicate(b):
 619            pass
 620        elif not a_predicate(b) and b_predicate(a):
 621            a, b = b, a
 622        else:
 623            return expression
 624
 625        return expression.__class__(
 626            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 627        )
 628    return expression
 629
 630
 631def simplify_literals(expression, root=True):
 632    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 633        return _flat_simplify(expression, _simplify_binary, root)
 634
 635    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 636        return expression.this.this
 637
 638    if type(expression) in INVERSE_DATE_OPS:
 639        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 640
 641    return expression
 642
 643
# Operators that `_simplify_binary` never folds; in particular a NULL operand
# does not turn them into NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 645
 646
 647def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 648    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 649        this = _simplify_integer_cast(expr.this)
 650    else:
 651        this = expr.this
 652
 653    if isinstance(expr, exp.Cast) and this.is_int:
 654        num = this.to_py()
 655
 656        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 657        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 658        # engine-dependent
 659        if (
 660            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 661        ) or (
 662            UTINYINT_MIN <= num <= UTINYINT_MAX
 663            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 664        ):
 665            return this
 666
 667    return expr
 668
 669
 670def _simplify_binary(expression, a, b):
 671    if isinstance(expression, COMPARISONS):
 672        a = _simplify_integer_cast(a)
 673        b = _simplify_integer_cast(b)
 674
 675    if isinstance(expression, exp.Is):
 676        if isinstance(b, exp.Not):
 677            c = b.this
 678            not_ = True
 679        else:
 680            c = b
 681            not_ = False
 682
 683        if is_null(c):
 684            if isinstance(a, exp.Literal):
 685                return exp.true() if not_ else exp.false()
 686            if is_null(a):
 687                return exp.false() if not_ else exp.true()
 688    elif isinstance(expression, NULL_OK):
 689        return None
 690    elif is_null(a) or is_null(b):
 691        return exp.null()
 692
 693    if a.is_number and b.is_number:
 694        num_a = a.to_py()
 695        num_b = b.to_py()
 696
 697        if isinstance(expression, exp.Add):
 698            return exp.Literal.number(num_a + num_b)
 699        if isinstance(expression, exp.Mul):
 700            return exp.Literal.number(num_a * num_b)
 701
 702        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 703        if isinstance(expression, exp.Sub):
 704            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 705        if isinstance(expression, exp.Div):
 706            # engines have differing int div behavior so intdiv is not safe
 707            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 708                return None
 709            return exp.Literal.number(num_a / num_b)
 710
 711        boolean = eval_boolean(expression, num_a, num_b)
 712
 713        if boolean:
 714            return boolean
 715    elif a.is_string and b.is_string:
 716        boolean = eval_boolean(expression, a.this, b.this)
 717
 718        if boolean:
 719            return boolean
 720    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 721        date, b = extract_date(a), extract_interval(b)
 722        if date and b:
 723            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 724                return date_literal(date + b, extract_type(a))
 725            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 726                return date_literal(date - b, extract_type(a))
 727    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 728        a, date = extract_interval(a), extract_date(b)
 729        # you cannot subtract a date from an interval
 730        if a and b and isinstance(expression, exp.Add):
 731            return date_literal(a + date, extract_type(b))
 732    elif _is_date_literal(a) and _is_date_literal(b):
 733        if isinstance(expression, exp.Predicate):
 734            a, b = extract_date(a), extract_date(b)
 735            boolean = eval_boolean(expression, a, b)
 736            if boolean:
 737                return boolean
 738
 739    return None
 740
 741
def simplify_parens(expression):
    """
    Remove a redundant `Paren` wrapper.

    Returns the wrapped expression when the checks below deem the parentheses unnecessary,
    otherwise `expression` unchanged. Non-Paren inputs are returned as is.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parenthesized SELECTs and operands of subquery predicates always keep their parentheses.
    # Beyond that, any single alternative of the inner `and (...)` group allows unwrapping.
    if (
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # the parent is neither a Condition nor a Binary
            not isinstance(parent, (exp.Condition, exp.Binary))
            # directly nested parentheses
            or isinstance(parent, exp.Paren)
            # the content is not a Binary, and is not a NOT / IS sitting under a predicate
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # a predicate under a non-predicate parent
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # a + (b + c), a * (b * c)
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            # a product under an addition or subtraction
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
 768
 769
def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True if `expression` is one of `exp.NONNULL_CONSTANTS` or a date literal."""
    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 772
 773
def _is_constant(expression: exp.Expression) -> bool:
    """Return True if `expression` is one of `exp.CONSTANTS` or a date literal."""
    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 776
 777
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    Two rewrites are performed:
      - a COALESCE with a single argument, or whose first argument is a non-null constant,
        is replaced by that first argument
      - a comparison between a COALESCE and a constant is expanded into an OR of two
        branches, one per NULL / NOT NULL outcome of the leading COALESCE arguments

    Anything else is returned unchanged. Note that the COALESCE node found in a comparison
    is modified in place (its trailing arguments are removed).
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE side of the comparison; `other` is the opposite operand
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    # (`expressions` holds the arguments after the first one, which is `this`)
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments that come before the constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    # Result: (<coalesce> IS NOT NULL AND <original comparison>)
    #         OR (<coalesce> IS NULL AND <constant arg> <op> <other>)
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 834
 835
# Node types handled by `simplify_concat`: CONCAT-style function calls and the `||` operator
CONCATS = (exp.Concat, exp.DPipe)
 837
 838
 839def simplify_concat(expression):
 840    """Reduces all groups that contain string literals by concatenating them."""
 841    if not isinstance(expression, CONCATS) or (
 842        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 843        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 844    ):
 845        return expression
 846
 847    if isinstance(expression, exp.ConcatWs):
 848        sep_expr, *expressions = expression.expressions
 849        sep = sep_expr.name
 850        concat_type = exp.ConcatWs
 851        args = {}
 852    else:
 853        expressions = expression.expressions
 854        sep = ""
 855        concat_type = exp.Concat
 856        args = {
 857            "safe": expression.args.get("safe"),
 858            "coalesce": expression.args.get("coalesce"),
 859        }
 860
 861    new_args = []
 862    for is_string_group, group in itertools.groupby(
 863        expressions or expression.flatten(), lambda e: e.is_string
 864    ):
 865        if is_string_group:
 866            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 867        else:
 868            new_args.extend(group)
 869
 870    if len(new_args) == 1 and new_args[0].is_string:
 871        return new_args[0]
 872
 873    if concat_type is exp.ConcatWs:
 874        new_args = [sep_expr] + new_args
 875    elif isinstance(expression, exp.DPipe):
 876        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 877
 878    return concat_type(expressions=new_args, **args)
 879
 880
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For a CASE, the branches are visited in order: the first branch whose condition
    is always true is replaced by its result, and branches whose condition is always
    false are removed. When no branch is left, the default (or NULL) is returned.
    A standalone IF (one that is not a branch of a CASE) is reduced to its true or
    false result when its condition is statically known.

    Returns:
        The simplified expression, or `expression` itself when nothing applies.
        Note that `expression` may have been modified in place.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # NOTE(review): this pops the branch out of the list being iterated;
                # presumably any branch skipped as a result is handled by a later
                # simplification pass -- confirm.
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 905
 906
 907def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 908    """
 909    Reduces a prefix check to either TRUE or FALSE if both the string and the
 910    prefix are statically known.
 911
 912    Example:
 913        >>> from sqlglot import parse_one
 914        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 915        'TRUE'
 916    """
 917    if (
 918        isinstance(expression, exp.StartsWith)
 919        and expression.this.is_string
 920        and expression.expression.is_string
 921    ):
 922        return exp.convert(expression.name.startswith(expression.expression.name))
 923
 924    return expression
 925
 926
# A pair of dates; _datetrunc_range documents it as the half-open interval [min, max)
DateRange = t.Tuple[datetime.date, datetime.date]
 928
 929
 930def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 931    """
 932    Get the date range for a DATE_TRUNC equality comparison:
 933
 934    Example:
 935        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 936    Returns:
 937        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 938    """
 939    floor = date_floor(date, unit, dialect)
 940
 941    if date != floor:
 942        # This will always be False, except for NULL values.
 943        return None
 944
 945    return floor, floor + interval(unit)
 946
 947
 948def _datetrunc_eq_expression(
 949    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 950) -> exp.Expression:
 951    """Get the logical expression for a date range"""
 952    return exp.and_(
 953        left >= date_literal(drange[0], target_type),
 954        left < date_literal(drange[1], target_type),
 955        copy=False,
 956    )
 957
 958
 959def _datetrunc_eq(
 960    left: exp.Expression,
 961    date: datetime.date,
 962    unit: str,
 963    dialect: Dialect,
 964    target_type: t.Optional[exp.DataType],
 965) -> t.Optional[exp.Expression]:
 966    drange = _datetrunc_range(date, unit, dialect)
 967    if not drange:
 968        return None
 969
 970    return _datetrunc_eq_expression(left, drange, target_type)
 971
 972
 973def _datetrunc_neq(
 974    left: exp.Expression,
 975    date: datetime.date,
 976    unit: str,
 977    dialect: Dialect,
 978    target_type: t.Optional[exp.DataType],
 979) -> t.Optional[exp.Expression]:
 980    drange = _datetrunc_range(date, unit, dialect)
 981    if not drange:
 982        return None
 983
 984    return exp.and_(
 985        left < date_literal(drange[0], target_type),
 986        left >= date_literal(drange[1], target_type),
 987        copy=False,
 988    )
 989
 990
# Maps a comparison type to the function that rewrites `DATE_TRUNC(unit, x) <op> date`
# as a comparison on `x`. simplify_datetrunc calls each entry with
# (truncated operand, date, unit, dialect, target type); in the lambdas these are
# named l, dt, u, d and t (so `t` shadows the `typing` alias inside each lambda).
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # x is below the date itself if the date is aligned to the unit, else below the next boundary
    exp.LT: lambda l, dt, u, d, t: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
    # x is at or after the start of the unit following the one containing the date
    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
    # x is before the start of the unit following the one containing the date
    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
    # x is at or after the date rounded up to a unit boundary
    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every expression type simplify_datetrunc looks at: IN plus the binary comparisons above
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Expression types treated as date truncations
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
1002
1003
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Tell whether `left` is a date truncation and `right` a statically known date."""
    if not isinstance(left, DATETRUNCS):
        return False
    return _is_date_literal(right)
1006
1007
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Three shapes are handled:
      * a truncation of a statically known date, which is folded into a date literal;
      * a binary comparison between a truncation and a date literal, which is
        rewritten via DATETRUNC_BINARY_COMPARISONS;
      * `truncation IN (date literals...)`, which becomes an OR of range checks.

    Because of the `catch` decorator, the input expression is returned if
    ModuleNotFoundError or UnsupportedUnit is raised while simplifying.

    Args:
        expression: the expression to simplify.
        dialect: dialect forwarded to the date floor/ceil helpers.

    Returns:
        The simplified expression, or `expression` when nothing applies.
    """
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncation of a constant date: compute the floor right away
        this = expression.this
        trunc_type = extract_type(this)
        date = extract_date(this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        trunc_arg = l.this
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The transform may return None (e.g. the date is not aligned to the unit);
        # in that case the original expression is kept.
        return (
            DATETRUNC_BINARY_COMPARISONS[comparison](
                trunc_arg, date, unit, dialect, extract_type(r)
            )
            or expression
        )

    if isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates that are not aligned to the unit yield no range and are dropped
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)
            target_type = extract_type(*rs)

            return exp.or_(
                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
            )

    return expression
1071
1072
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; otherwise the operands
    are ordered by their `gen` string. When the operands are swapped, the operator
    is replaced by its entry in INVERSE_COMPARISONS (if it has one).
    """
    comparison = expression.__class__
    if comparison not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(comparison, comparison)
    return swapped_type(this=right, expression=left)
1092
1093
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
#
# Each entry is a (side, kind) pair; remove_where_true only converts a join whose
# (join.side, join.kind) is listed here.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
1103
1104
def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn `JOIN ... ON TRUE` into CROSS JOIN.

    The expression is modified in place. A join is only converted when it has no
    USING clause, no join method, and its (side, kind) pair is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1119
1120
def always_true(expression):
    """Truthy for a TRUE boolean and for any literal that is not zero."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
1125
1126
def always_false(expression):
    """True for a FALSE boolean, a NULL, or a literal equal to zero."""
    if is_false(expression):
        return True
    if is_null(expression):
        return True
    return is_zero(expression)
1129
1130
def is_zero(expression):
    """True when `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
1133
1134
def is_complement(a, b):
    """True when `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1137
1138
def is_false(a: exp.Expression) -> bool:
    """True when `a` is exactly a Boolean node (not a subclass) holding a false value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1141
1142
def is_null(a: exp.Expression) -> bool:
    """True when `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1145
1146
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE boolean literal, or None if `expression` is not one of
    the supported comparison types.
    """
    # Checked in order; the first matching entry decides the result
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for kinds, compare in comparisons:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))

    return None
1161
1162
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, if possible.

    Args:
        value: a datetime (its date part is used), a date (returned as is), or an
            ISO-8601 formatted string.

    Returns:
        The date, or None when `value` is neither a date nor a parseable ISO string.
    """
    # datetime must be checked first because it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string; TypeError: not a string at all
        return None
1172
1173
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of that
            day), or an ISO-8601 formatted string.

    Returns:
        The datetime, or None when `value` is neither a date nor a parseable ISO string.
    """
    # datetime must be checked first because it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO formatted string; TypeError: not a string at all
        return None
1183
1184
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date type matching the data type `to`.

    Args:
        value: the value to convert; any falsy value yields None.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is any other temporal type,
        and None when `to` is not temporal or the conversion fails.
    """
    if not value:
        return None
    # DATE is tested before the general temporal check so that it gets a plain date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1193
1194
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the statically known date from a cast-like expression.

    Handles a CAST, or a TsOrDsToDate without a format argument (treated as a cast
    to DATE), wrapped around either a literal or another such expression.

    Returns:
        The date or datetime converted to the outermost target type, or None when
        the expression has a different shape or its value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: resolve the inner one first, then convert its result
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1210
1211
def _is_date_literal(expression: exp.Expression) -> bool:
    """Tell whether `extract_date` can get a date out of `expression`."""
    extracted = extract_date(expression)
    return extracted is not None
1214
1215
def extract_interval(expression):
    """Turn an interval expression into the object returned by `interval`.

    Returns None when the unit is unsupported, the amount is not an integer,
    or the date library needed by `interval` is not installed.
    """
    try:
        amount = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1223
1224
def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    For a Cast the type is its target (`to`); for anything else it is the `type`
    attribute. If no expression has a truthy type, the (falsy) type of the last
    expression is returned, or None when there are no expressions.
    """
    found = None

    for candidate in expressions:
        if isinstance(candidate, exp.Cast):
            found = candidate.to
        else:
            found = candidate.type

        if found:
            return found

    return found
1233
1234
def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    If `target_type` is missing or is not a temporal type, DATETIME is used for
    datetime values and DATE for everything else.
    """
    has_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)

    if not has_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
1244
1245
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` times the given `unit`.

    Raises:
        ModuleNotFoundError: if dateutil is not installed.
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = span
    return relativedelta(**{keyword: multiplier * n})
1267
1268
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to round down.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only its `WEEK_OFFSET` attribute is read, and only for "week".

    Returns:
        The first day of the enclosing year/quarter/month/week, or `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is none of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # date.weekday() is 0 for Monday .. 6 for Sunday. WEEK_OFFSET shifts the first
        # day of the week (presumably 0 = Monday, -1 = Sunday -- confirm against the
        # dialect definition). The difference is taken modulo 7 so that we never step
        # back a whole week (or forward) when `d` already is the first day of its week.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1290
1291
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; an already aligned date is returned as is."""
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1299
1300
def boolean_literal(condition):
    """Return the TRUE expression if `condition` is truthy, else the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
1303
1304
def _flat_simplify(expression, simplifier, root=True):
    """Apply a pairwise `simplifier` across all operands of a flattened expression.

    Args:
        expression: the node whose flattened operands are combined.
        simplifier: called as `simplifier(expression, a, b)` for pairs of operands;
            a truthy result other than `expression` itself replaces both operands.
        root: when false, the work is only done if `expression.same_parent` is falsy.

    Returns:
        A new left-nested chain of `expression.__class__` nodes if at least one
        pair was merged, otherwise `expression` itself.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # a and b were merged: drop b and put the merged node at the front
                    # so it is tried against the remaining operands. Breaking out right
                    # away keeps the mutation of `queue` during iteration safe.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # a could not be merged with any remaining operand, so it is final
                operands.append(a)

        # Only rebuild the expression if something was actually merged
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1329
1330
def gen(expression: t.Any) -> str:
    """Render `expression` with the minimal pseudo SQL generator `Gen`.

    The optimizer has to sort and dedupe expressions, which needs a string that is
    sortable and unique per expression. The real generator is expensive to call,
    so this bare minimum one is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
1338
1339
class Gen:
    """Minimal, non-recursive pseudo SQL generator used to build sort/dedupe keys.

    Generation is driven by an explicit stack. Each `<key>_sql` handler pushes the
    pieces of its node (strings and child nodes) in REVERSE output order, because
    the stack is popped last-in-first-out. Plain strings popped off the stack are
    appended to `sqls`, which is joined into the final result.
    """

    def __init__(self):
        # Pending nodes / string fragments, consumed from the end
        self.stack = []
        # Output fragments in final order
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Generate the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Dispatch on the node key, e.g. a node with key "add" is handled by `add_sql`
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the node's args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the items comma separated, then drop the extra trailing comma
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Push `NAME(args)`; unquoted names are upper-cased, quoted ones kept verbatim."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        """Push the column parts joined by dots."""
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the extra trailing dot
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Skip the first arg type; the type name is emitted explicitly instead
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        # Skip the first arg type; `this` is emitted explicitly before " IN "
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed case, unlike " ILIKE " above. Left as is because the
        # output is used as a sort key, so changing it could alter operand ordering.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Skip the first two arg types; `this` and the alias are emitted explicitly below
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Skip the first four arg types; the table parts and alias are emitted explicitly below
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the extra trailing dot
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this <op> expression`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op>this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `SQL_NAME(arg values...)` using all of the function's arg values."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the `:name value` pairs of the node's args that are not None.

        Args:
            node: the node whose args are emitted.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one pair was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by the date helpers (`interval`, `date_floor`) for a unit they do not handle."""

    pass

Common base class for all non-exit exceptions.

def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, max_depth: Optional[int] = None):
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied repeatedly (via `while_changing`) until the expression
    stops changing; afterwards always-true WHERE / JOIN conditions are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed to the date truncation rule
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only measure the depth at the top of a connector chain, not at every link
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Mark every projection that contains a group by key as final
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are simplified as non-root nodes
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Restore the parent link; the rules above may have produced a detached node
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • max_depth: Chains of Connectors (AND, OR, etc) exceeding max_depth will be skipped
Returns:

sqlglot.Expression: simplified expression

def connector_depth(expression: sqlglot.expressions.Expression) -> int:
def connector_depth(expression: exp.Expression) -> int:
    """
    Compute how deeply Connector nodes (AND, OR, etc) are nested below `expression`.

    A non-Connector expression has depth 0.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    # Each entry is (node, depth that node has if it turns out to be a Connector)
    pending = [(expression, 1)]

    while pending:
        node, level = pending.pop()

        if isinstance(node, exp.Connector):
            if level > deepest:
                deepest = level

            pending.append((node.left, level + 1))
            pending.append((node.right, level + 1))

    return deepest

Determine the maximum depth of a tree of Connectors.

For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    When the decorated function raises one of `exceptions`, its first argument
    (the expression) is returned unchanged instead. Any other exception propagates.
    """

    def decorator(func):
        # functools.wraps keeps the name, docstring and signature metadata of `func`
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, hence the rewrite.
    If the BETWEEN sits directly under a NOT, the result is wrapped in parentheses.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    if negated:
        return exp.paren(conjunction, copy=False)
    return conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

COMPLEMENT_SUBQUERY_PREDICATES = {<class 'sqlglot.expressions.All'>: <class 'sqlglot.expressions.Any'>, <class 'sqlglot.expressions.Any'>: <class 'sqlglot.expressions.All'>}
def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify NOT expressions, including via De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Other rewrites performed here: NOT NULL -> NULL, NOT <comparison> -> the
    complementary comparison, NOT <always true> -> FALSE, NOT FALSE -> TRUE and
    NOT NOT x -> x. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Not):
        this = expression.this
        if is_null(this):
            return exp.null()
        if this.__class__ in COMPLEMENT_COMPARISONS:
            right = this.expression
            # When the right-hand side is an ALL / ANY subquery predicate, the quantifier
            # is swapped as well (ALL <-> ANY), per COMPLEMENT_SUBQUERY_PREDICATES.
            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
            if complement_subquery_predicate:
                right = complement_subquery_predicate(this=right.this)

            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
        if isinstance(this, exp.Paren):
            condition = this.unnest()
            if isinstance(condition, exp.And):
                # NOT (x AND y) -> (NOT x OR NOT y)
                return exp.paren(
                    exp.or_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if isinstance(condition, exp.Or):
                # NOT (x OR y) -> (NOT x AND NOT y)
                return exp.paren(
                    exp.and_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if is_null(condition):
                # NOT (NULL) -> NULL
                return exp.null()
        if always_true(this):
            return exp.false()
        if is_false(this):
            return exp.true()
        if isinstance(this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return this.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove redundant nesting of same-typed connectors:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        # Only unwrap when the inner node is the same kind of connector as the outer one.
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold AND / OR / XOR connectors whose operands are statically known.

    Non-connector expressions are returned unchanged. For connectors, the pairwise
    rule `_simplify_connectors` is handed to `_flat_simplify` together with `root`.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule for two operands of `expression`. The checks are ordered:
        # earlier rules take precedence over later ones.
        if isinstance(expression, exp.And):
            # FALSE (or a zero literal) on either side makes the conjunction FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_zero(left) or is_zero(right):
                return exp.false()
            # Otherwise a NULL operand makes the conjunction NULL
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE AND x -> x
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            # NULL combined with NULL or with an always-false operand stays NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_false(right))
                or (always_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE OR x -> x
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        elif isinstance(expression, exp.Xor):
            # x XOR x -> FALSE; any other XOR pair falls through and yields None
            if left == right:
                return exp.false()

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression
    # Only the top of a connector chain is processed, unless this is the root call.
    if not root and expression.same_parent:
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are compared by the string produced by `gen`. XOR operands are sorted
    but never deduplicated.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            # Keyed by generated string, so operands with the same string collapse to one
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        # (`i` indexes `arr[1:]`, so `arr[i]` is the element just before `sql`'s)
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing operands inside `expression`;
    `expression` itself is what gets returned.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands
        # (e.g. the ORs inside an AND) that are candidates for rewriting.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            # Register the operand under the pair it would complement, remembering
            # which side survives if the complement is found (used for elimination).
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A OR (NOT A AND B): NOT A is replaced with TRUE (FALSE in the dual
            # AND-of-ORs case), i.e. the neutral element of the nested connector.
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # If another operand's sub-operands form a strict subset of this one's,
            # this operand is redundant; replace it with the neutral element of the
            # outer connector (FALSE inside an OR, TRUE inside an AND).
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression

Absorption:
- A AND (A OR B) -> A
- A OR (A AND B) -> A
- A AND (NOT A OR B) -> A AND B
- A OR (NOT A AND B) -> A OR B

Elimination:
- (A AND B) OR (A AND NOT B) -> A
- (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Columns are replaced in place inside `expression`, which is then returned.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node in the defining `col = literal`, literal)
        constant_mapping = {}
        # IF nodes are pruned from the walk, so equalities inside them are ignored
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip columns without a known constant, the defining occurrence itself
                # (same node identity), and columns used in `<column> IS NULL`.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
170        def wrapped(expression, *args, **kwargs):
171            try:
172                return func(expression, *args, **kwargs)
173            except exceptions:
174                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal arithmetic, double negation and inverse date operations.

    Returns the simplified node, or `expression` itself when no rule applies.
    """
    is_binary = isinstance(expression, exp.Binary)
    if is_binary and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    # -(-x) -> x
    if isinstance(expression, exp.Neg):
        operand = expression.this
        if isinstance(operand, exp.Neg):
            return operand.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Remove a pair of parentheses when one of the conditions below holds.

    Returns the wrapped expression when the Paren is dropped, otherwise the
    (unchanged) input. Non-Paren inputs are returned as is.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parentheses around a SELECT, or directly under a subquery predicate, are
    # always kept. Otherwise they are dropped if ANY of these holds:
    #   - the parent is not a Condition / Binary
    #   - the parent is itself a Paren
    #   - the content is not a Binary (except NOT / IS under a predicate parent)
    #   - the content is a predicate and the parent is not
    #   - Add inside Add, Mul inside Mul, or Mul inside Add / Sub
    if (
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            not isinstance(parent, (exp.Condition, exp.Binary))
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons of a COALESCE against a constant.

    Two rewrites are performed:
      * a COALESCE with no fallback arguments, or whose first argument is a non-null
        constant, is replaced by its first argument (unless it is a hint);
      * `COALESCE(x, ..., <const>, ...) <cmp> <const>` is split into
        `(<coalesce> IS NOT NULL AND <coalesce> <cmp> <const>)
         OR (<coalesce> IS NULL AND <first const arg> <cmp> <const>)`,
        where `<coalesce>` keeps only the arguments before the first constant one.

    The COALESCE node is truncated in place as part of the second rewrite.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    # (`arg` and `arg_index` are reused below; the `else` fires when none is constant)
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant argument and everything after it (mutates the node in place,
    # so `expression.copy()` below already sees the truncated COALESCE)
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Handles CONCAT, CONCAT_WS (only when the separator is a string literal) and the
    `||` operator. If everything collapses into one string literal, that literal is
    returned; otherwise a new node of the same flavor is built from the merged args.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        # Carry over the original node's `safe` / `coalesce` args
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Group consecutive operands by whether they are string literals; when there is no
    # `expressions` list (the `||` case) the operands come from `flatten()`
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            # Merge a run of adjacent string literals into a single literal
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild a left-nested chain of `||` nodes
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
      * `CASE x WHEN v ...` is first rewritten to `CASE WHEN x = v ...`;
      * branches whose condition is always false are removed;
      * a branch whose condition is always true replaces the whole CASE, but only if
        every branch before it has been removed. If an earlier branch with an unknown
        condition is still there, it may match first at runtime, so the CASE is kept;
      * if all branches are removed, the default (or NULL) is returned.

    IF (outside of a CASE): replaced by its true / false result when the condition is
    statically known (NULL when there is no false branch).

    Returns the simplified node, or `expression` when nothing can be decided.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Becomes True once a branch with a statically unknown condition has been kept
        has_live_branch = False

        # Iterate over a snapshot: `case.pop()` below detaches the branch from the CASE,
        # and removing entries from a list that is being iterated skips the next branch.
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond) and not has_live_branch:
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                has_live_branch = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH call into TRUE or FALSE when both the string and the prefix
    are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack, prefix = expression.this, expression.expression
    if haystack.is_string and prefix.is_string:
        return exp.convert(expression.name.startswith(prefix.name))

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.GT'>}
def simplify_datetrunc(expression, *args, **kwargs):
170        def wrapped(expression, *args, **kwargs):
171            try:
172                return func(expression, *args, **kwargs)
173            except exceptions:
174                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; otherwise operands are
    ordered by their `gen` string. When operands are swapped the operator is replaced
    by its inverse (from INVERSE_COMPARISONS, falling back to the same operator).
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_canonical = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_canonical:
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not needs_swap:
        return expression

    flipped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return flipped_type(this=rhs, expression=lhs)
JOINS = {('RIGHT', ''), ('', ''), ('', 'INNER'), ('RIGHT', 'OUTER')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn trivially-true joins into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        join_args = join.args
        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        # Only the (side, kind) combinations listed in JOINS are converted
        if (join.side, join.kind) not in JOINS:
            continue

        join_args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is a true boolean or a non-zero literal."""
    if isinstance(expression, exp.Boolean):
        value = expression.this
        if value:
            return value
    return isinstance(expression, exp.Literal) and not is_zero(expression)
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is FALSE, NULL or a zero literal."""
    if is_false(expression):
        return True
    if is_null(expression):
        return True
    return is_zero(expression)
def is_zero(expression):
def is_zero(expression):
    """Return True if `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a, b):
def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Boolean node (no subclasses) with a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal node for EQ / IS, NEQ, GT, GTE, LT and LTE, and None
    for any other kind of expression.
    """
    # Checked in order; each comparison is evaluated lazily, only for the matching type.
    comparators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, or return None if that is not possible.

    Args:
        value: a `datetime.datetime` (its date part is returned), a `datetime.date`
            (returned as is), or an ISO 8601 string.

    Returns:
        The resulting date, or None when `value` is not a valid ISO string or is of an
        unsupported type.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, or return None if that is not possible.

    Args:
        value: a `datetime.datetime` (returned as is), a `datetime.date` (converted to
            midnight of that day), or an ISO 8601 string.

    Returns:
        The resulting datetime, or None when `value` is not a valid ISO string or is of
        an unsupported type.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to a Python date or datetime according to the target type `to`.

    Args:
        value: the value to convert; falsy values short-circuit to None.
        to: the target data type.

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for the other
        temporal types, and None for falsy values, non-temporal targets, or values
        that cannot be converted.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a (possibly nested) cast of a literal to a Python date or datetime.

    Supported shapes are a Cast node, or a TsOrDsToDate node without a `format`
    argument (treated as a cast to DATE), wrapping either a literal or another such
    node.

    Returns:
        The converted value (see `cast_value`), or None if `cast` has an unsupported
        shape or the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are evaluated inside-out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Build a time delta from an interval expression via `interval`.

    Returns None when the unit is unsupported, the amount cannot be converted to an
    int, or the module that `interval` imports is missing.
    """
    try:
        amount = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    For a Cast the target type (`to`) is used, otherwise the expression's `type`.
    When nothing truthy is found, the last value inspected is returned (None when
    called without arguments).
    """
    found = None
    for candidate in expressions:
        if isinstance(candidate, exp.Cast):
            found = candidate.to
        else:
            found = candidate.type
        if found:
            return found

    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """Wrap `date` in a string literal cast to a temporal type.

    `target_type` is used when it is a temporal type; otherwise the cast target is
    DATETIME for `datetime.datetime` values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` units of `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week, day,
            hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the enclosing `unit`.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, the first day of the week relative to Monday
            (0 = Monday, -1 = Sunday).

    Returns:
        The floored date.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=((d.month - 1) // 3) * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is Monday (0) .. Sunday (6). The number of days since the start of
        # the week has to wrap around modulo 7: without it, a Sunday in a Sunday-based
        # week (offset -1) would be floored to the *previous* Sunday, and positive
        # offsets could floor forward in time.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; dates already on a boundary are returned as is."""
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return a TRUE node if `condition` is truthy, otherwise a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string for `expression`, suitable for sorting and deduping.

    Optimization needs to sort and dedupe SQL, and the real generator is expensive to
    call, so a bare-minimum generator (`Gen`) is used instead.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
    """Iterative, stack-based pseudo-SQL generator.

    `gen` repeatedly pops items from `self.stack`. Expressions are expanded by
    a handler named `<expression key>_sql` (falling back to `_function` for
    `exp.Func` nodes and to a generic `KEY args` form otherwise), lists are
    expanded into comma-separated items, and anything else is stringified and
    appended to `self.sqls`.

    Because the stack is LIFO, every handler pushes its pieces in the reverse
    of the order in which they should appear in the output.
    """

    def __init__(self):
        # Pending work items: expressions, lists, or plain tokens.
        self.stack = []
        # Output fragments, joined once the stack is drained.
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # _args runs first so the key lands on top of the args and
                    # is therefore emitted before them.
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Drop the separator sitting on top of the first item. Only
                    # do so when something was pushed: a list holding nothing
                    # but None would otherwise pop an unrelated pending token.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(args)`; unquoted names are upper-cased, quoted ones kept verbatim."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator that would precede the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Remaining args (everything after the first arg type) come after the
        # type name, which is always followed by a space.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator that would precede the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression` (in reverse, for LIFO emission)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this` (in reverse, for LIFO emission)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `SQL_NAME(arg values...)` for a function expression."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of `node` as a list of `[":name", value]` pairs.

        Args:
            node: expression whose `arg_types` and `args` are read.
            arg_index: number of leading arg types to skip.

        Returns:
            True if at least one arg was not None (and something was pushed),
            False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression) -> str:
    """Return the pseudo-SQL string for `expression`.

    Items are popped from `self.stack` until it is empty: expressions are
    expanded by their `<key>_sql` handler (or `_function` / a generic
    `KEY args` form), lists become comma-separated items, and every other
    non-None item is stringified into `self.sqls`.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            handler = getattr(self, f"{node.key}_sql", None)

            if handler is not None:
                handler(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                # _args runs first so the key lands on top of the args and is
                # therefore emitted before them.
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            if pushed:
                # Drop the separator sitting on top of the first item. Only do
                # so when something was pushed: a list holding nothing but
                # None would otherwise pop an unrelated pending token.
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Emit `e` as an infix ``+`` operation."""
    self._binary(e, op=" + ")
def alias_sql(self, e: exp.Alias) -> None:
    """Emit `this AS alias` (pushed in reverse because the stack is LIFO)."""
    args = e.args
    alias = args.get("alias")
    aliased = args.get("this")
    self.stack.extend([alias, " AS ", aliased])
def and_sql(self, e: exp.And) -> None:
    """Emit `e` as an infix ``AND`` operation."""
    self._binary(e, op=" AND ")
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Emit an anonymous function call as `NAME(expressions)`.

    A plain-string or unquoted-identifier name is upper-cased; a quoted
    identifier keeps its spelling and is wrapped in double quotes.

    Raises:
        ValueError: if `e.this` is neither a str nor an `exp.Identifier`.
    """
    target = e.this

    if isinstance(target, str):
        func_name = target.upper()
    elif not isinstance(target, exp.Identifier):
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )
    else:
        raw_name = target.this
        if target.quoted:
            func_name = f'"{raw_name}"'
        else:
            func_name = raw_name.upper()

    self.stack.extend([")", e.expressions, "(", func_name])
def between_sql(self, e: exp.Between) -> None:
    """Emit `this BETWEEN low AND high` (pushed in reverse order)."""
    upper = e.args.get("high")
    lower = e.args.get("low")
    self.stack.extend([upper, " AND ", lower, " BETWEEN ", e.this])
def boolean_sql(self, e: exp.Boolean) -> None:
    """Emit `TRUE` or `FALSE` depending on the truthiness of `e.this`."""
    if e.this:
        token = "TRUE"
    else:
        token = "FALSE"
    self.stack.append(token)
def bracket_sql(self, e: exp.Bracket) -> None:
    """Emit `this[expressions]` (pushed in reverse order)."""
    pending = self.stack
    pending.append("]")
    pending.append(e.expressions)
    pending.append("[")
    pending.append(e.this)
def column_sql(self, e: exp.Column) -> None:
    """Emit the parts of the column joined by dots."""
    pending = self.stack
    for part in reversed(e.parts):
        pending.append(part)
        pending.append(".")
    # The topmost "." would precede the first part; discard it.
    pending.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Emit the type name followed by a space and any args after the first arg type."""
    self._args(e, arg_index=1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Emit `e` as an infix ``/`` operation."""
    self._binary(e, op=" / ")
def dot_sql(self, e: exp.Dot) -> None:
    """Emit `e` as `this.expression`, with no surrounding spaces."""
    self._binary(e, op=".")
def eq_sql(self, e: exp.EQ) -> None:
    """Emit `e` as an infix ``=`` comparison."""
    self._binary(e, op=" = ")
def from_sql(self, e: exp.From) -> None:
    """Emit `FROM ` followed by the source expression."""
    source = e.this
    self.stack.append(source)
    self.stack.append("FROM ")
def gt_sql(self, e: exp.GT) -> None:
    """Emit `e` as an infix ``>`` comparison."""
    self._binary(e, op=" > ")
def gte_sql(self, e: exp.GTE) -> None:
    """Emit `e` as an infix ``>=`` comparison."""
    self._binary(e, op=" >= ")
def identifier_sql(self, e: exp.Identifier) -> None:
    """Emit the identifier name, wrapped in double quotes when it is quoted."""
    if e.quoted:
        rendered = f'"{e.this}"'
    else:
        rendered = e.this
    self.stack.append(rendered)
def ilike_sql(self, e: exp.ILike) -> None:
    """Emit `e` as an infix ``ILIKE`` operation."""
    self._binary(e, op=" ILIKE ")
def in_sql(self, e: exp.In) -> None:
    """Emit `this IN (args)` where args are all arg types after the first."""
    self.stack.append(")")
    # The args must be pushed between the closing and opening parenthesis.
    self._args(e, arg_index=1)
    self.stack.append("(")
    self.stack.append(" IN ")
    self.stack.append(e.this)
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Emit `e` as an infix ``DIV`` operation."""
    self._binary(e, op=" DIV ")
def is_sql(self, e: exp.Is) -> None:
    """Emit `e` as an infix ``IS`` operation."""
    self._binary(e, op=" IS ")
def like_sql(self, e: exp.Like) -> None:
    """Emit `e` as an infix LIKE operation.

    The operator text is pushed exactly as spelled below (mixed case).
    """
    self._binary(e, op=" Like ")
def literal_sql(self, e: exp.Literal) -> None:
    """Emit the literal value, single-quoted when it is a string literal."""
    if e.is_string:
        rendered = f"'{e.this}'"
    else:
        rendered = e.this
    self.stack.append(rendered)
def lt_sql(self, e: exp.LT) -> None:
    """Emit `e` as an infix ``<`` comparison."""
    self._binary(e, op=" < ")
def lte_sql(self, e: exp.LTE) -> None:
    """Emit `e` as an infix ``<=`` comparison."""
    self._binary(e, op=" <= ")
def mod_sql(self, e: exp.Mod) -> None:
    """Emit `e` as an infix ``%`` operation."""
    self._binary(e, op=" % ")
def mul_sql(self, e: exp.Mul) -> None:
    """Emit `e` as an infix ``*`` operation."""
    self._binary(e, op=" * ")
def neg_sql(self, e: exp.Neg) -> None:
    """Emit `e` as a prefix ``-`` operation."""
    self._unary(e, op="-")
def neq_sql(self, e: exp.NEQ) -> None:
    """Emit `e` as an infix ``<>`` comparison."""
    self._binary(e, op=" <> ")
def not_sql(self, e: exp.Not) -> None:
    """Emit `e` as a prefix ``NOT`` operation."""
    self._unary(e, op="NOT ")
def null_sql(self, e: exp.Null) -> None:
    """Emit the `NULL` keyword; the node itself carries nothing to render."""
    keyword = "NULL"
    self.stack.append(keyword)
def or_sql(self, e: exp.Or) -> None:
    """Emit `e` as an infix ``OR`` operation."""
    self._binary(e, op=" OR ")
def paren_sql(self, e: exp.Paren) -> None:
    """Emit the wrapped expression surrounded by parentheses."""
    inner = e.this
    self.stack.extend([")", inner, "("])
def sub_sql(self, e: exp.Sub) -> None:
    """Emit `e` as an infix ``-`` operation."""
    self._binary(e, op=" - ")
def subquery_sql(self, e: exp.Subquery) -> None:
    """Emit `(this)`, then the alias if set, then the args after the first two arg types."""
    self._args(e, arg_index=2)

    subquery_alias = e.args.get("alias")
    if subquery_alias:
        self.stack.append(subquery_alias)

    self.stack.append(")")
    self.stack.append(e.this)
    self.stack.append("(")
def table_sql(self, e: exp.Table) -> None:
    """Emit the dotted table parts, then the alias if set, then the args after the first four arg types."""
    self._args(e, arg_index=4)

    table_alias = e.args.get("alias")
    if table_alias:
        self.stack.append(table_alias)

    for part in reversed(e.parts):
        self.stack.append(part)
        self.stack.append(".")
    # The topmost "." would precede the first part; discard it.
    self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Emit ` AS this`, followed by `(columns)` when the alias has columns."""
    alias_columns = e.columns
    if alias_columns:
        self.stack.append(")")
        self.stack.append(alias_columns)
        self.stack.append("(")

    self.stack.append(e.this)
    self.stack.append(" AS ")
def var_sql(self, e: exp.Var) -> None:
    """Emit the variable's name as-is."""
    var_name = e.this
    self.stack.append(var_name)