Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot.dialects.dialect import DialectType
  18
  19    DateTruncBinaryTransform = t.Callable[
  20        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  21    ]
  22
  23logger = logging.getLogger("sqlglot")
  24
  25# Final means that an expression should not be simplified
  26FINAL = "final"
  27
  28# Value ranges for byte-sized signed/unsigned integers
  29TINYINT_MIN = -128
  30TINYINT_MAX = 127
  31UTINYINT_MIN = 0
  32UTINYINT_MAX = 255
  33
  34
  35class UnsupportedUnit(Exception):
  36    pass
  37
  38
  39def simplify(
  40    expression: exp.Expression,
  41    constant_propagation: bool = False,
  42    dialect: DialectType = None,
  43):
  44    """
  45    Rewrite sqlglot AST to simplify expressions.
  46
  47    Example:
  48        >>> import sqlglot
  49        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  50        >>> simplify(expression).sql()
  51        'TRUE'
  52
  53    Args:
  54        expression: expression to simplify
  55        constant_propagation: whether the constant propagation rule should be used
  56    Returns:
  57        sqlglot.Expression: simplified expression
  58    """
  59
  60    dialect = Dialect.get_or_raise(dialect)
  61
  62    def _simplify(expression):
  63        pre_transformation_stack = [expression]
  64        post_transformation_stack = []
  65
  66        while pre_transformation_stack:
  67            node = pre_transformation_stack.pop()
  68
  69            if node.meta.get(FINAL):
  70                continue
  71
  72            # group by expressions cannot be simplified, for example
  73            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  74            # the projection must exactly match the group by key
  75            group = node.args.get("group")
  76
  77            if group and hasattr(node, "selects"):
  78                groups = set(group.expressions)
  79                group.meta[FINAL] = True
  80
  81                for s in node.selects:
  82                    for n in s.walk():
  83                        if n in groups:
  84                            s.meta[FINAL] = True
  85                            break
  86
  87                having = node.args.get("having")
  88                if having:
  89                    for n in having.walk():
  90                        if n in groups:
  91                            having.meta[FINAL] = True
  92                            break
  93
  94            parent = node.parent
  95            root = node is expression
  96
  97            new_node = rewrite_between(node)
  98            new_node = uniq_sort(new_node, root)
  99            new_node = absorb_and_eliminate(new_node, root)
 100            new_node = simplify_concat(new_node)
 101            new_node = simplify_conditionals(new_node)
 102
 103            if constant_propagation:
 104                new_node = propagate_constants(new_node, root)
 105
 106            if new_node is not node:
 107                node.replace(new_node)
 108
 109            pre_transformation_stack.extend(
 110                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 111            )
 112            post_transformation_stack.append((new_node, parent))
 113
 114        while post_transformation_stack:
 115            node, parent = post_transformation_stack.pop()
 116            root = node is expression
 117
 118            # Resets parent, arg_key, index pointers– this is needed because some of the
 119            # previous transformations mutate the AST, leading to an inconsistent state
 120            for k, v in tuple(node.args.items()):
 121                node.set(k, v)
 122
 123            # Post-order transformations
 124            new_node = simplify_not(node)
 125            new_node = flatten(new_node)
 126            new_node = simplify_connectors(new_node, root)
 127            new_node = remove_complements(new_node, root)
 128            new_node = simplify_coalesce(new_node, dialect)
 129
 130            new_node.parent = parent
 131
 132            new_node = simplify_literals(new_node, root)
 133            new_node = simplify_equality(new_node)
 134            new_node = simplify_parens(new_node)
 135            new_node = simplify_datetrunc(new_node, dialect)
 136            new_node = sort_comparison(new_node)
 137            new_node = simplify_startswith(new_node)
 138
 139            if new_node is not node:
 140                node.replace(new_node)
 141
 142        return new_node
 143
 144    expression = while_changing(expression, _simplify)
 145    remove_where_true(expression)
 146    return expression
 147
 148
 149def catch(*exceptions):
 150    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 151
 152    def decorator(func):
 153        def wrapped(expression, *args, **kwargs):
 154            try:
 155                return func(expression, *args, **kwargs)
 156            except exceptions:
 157                return expression
 158
 159        return wrapped
 160
 161    return decorator
 162
 163
 164def rewrite_between(expression: exp.Expression) -> exp.Expression:
 165    """Rewrite x between y and z to x >= y AND x <= z.
 166
 167    This is done because comparison simplification is only done on lt/lte/gt/gte.
 168    """
 169    if isinstance(expression, exp.Between):
 170        negate = isinstance(expression.parent, exp.Not)
 171
 172        expression = exp.and_(
 173            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 174            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 175            copy=False,
 176        )
 177
 178        if negate:
 179            expression = exp.paren(expression, copy=False)
 180
 181    return expression
 182
 183
 184COMPLEMENT_COMPARISONS = {
 185    exp.LT: exp.GTE,
 186    exp.GT: exp.LTE,
 187    exp.LTE: exp.GT,
 188    exp.GTE: exp.LT,
 189    exp.EQ: exp.NEQ,
 190    exp.NEQ: exp.EQ,
 191}
 192
 193COMPLEMENT_SUBQUERY_PREDICATES = {
 194    exp.All: exp.Any,
 195    exp.Any: exp.All,
 196}
 197
 198
 199def simplify_not(expression):
 200    """
 201    Demorgan's Law
 202    NOT (x OR y) -> NOT x AND NOT y
 203    NOT (x AND y) -> NOT x OR NOT y
 204    """
 205    if isinstance(expression, exp.Not):
 206        this = expression.this
 207        if is_null(this):
 208            return exp.null()
 209        if this.__class__ in COMPLEMENT_COMPARISONS:
 210            right = this.expression
 211            complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__)
 212            if complement_subquery_predicate:
 213                right = complement_subquery_predicate(this=right.this)
 214
 215            return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 216        if isinstance(this, exp.Paren):
 217            condition = this.unnest()
 218            if isinstance(condition, exp.And):
 219                return exp.paren(
 220                    exp.or_(
 221                        exp.not_(condition.left, copy=False),
 222                        exp.not_(condition.right, copy=False),
 223                        copy=False,
 224                    )
 225                )
 226            if isinstance(condition, exp.Or):
 227                return exp.paren(
 228                    exp.and_(
 229                        exp.not_(condition.left, copy=False),
 230                        exp.not_(condition.right, copy=False),
 231                        copy=False,
 232                    )
 233                )
 234            if is_null(condition):
 235                return exp.null()
 236        if always_true(this):
 237            return exp.false()
 238        if is_false(this):
 239            return exp.true()
 240        if isinstance(this, exp.Not):
 241            # double negation
 242            # NOT NOT x -> x
 243            return this.this
 244    return expression
 245
 246
 247def flatten(expression):
 248    """
 249    A AND (B AND C) -> A AND B AND C
 250    A OR (B OR C) -> A OR B OR C
 251    """
 252    if isinstance(expression, exp.Connector):
 253        for node in expression.args.values():
 254            child = node.unnest()
 255            if isinstance(child, expression.__class__):
 256                node.replace(child)
 257    return expression
 258
 259
 260def simplify_connectors(expression, root=True):
 261    def _simplify_connectors(expression, left, right):
 262        if isinstance(expression, exp.And):
 263            if is_false(left) or is_false(right):
 264                return exp.false()
 265            if is_zero(left) or is_zero(right):
 266                return exp.false()
 267            if is_null(left) or is_null(right):
 268                return exp.null()
 269            if always_true(left) and always_true(right):
 270                return exp.true()
 271            if always_true(left):
 272                return right
 273            if always_true(right):
 274                return left
 275            return _simplify_comparison(expression, left, right)
 276        elif isinstance(expression, exp.Or):
 277            if always_true(left) or always_true(right):
 278                return exp.true()
 279            if (
 280                (is_null(left) and is_null(right))
 281                or (is_null(left) and always_false(right))
 282                or (always_false(left) and is_null(right))
 283            ):
 284                return exp.null()
 285            if is_false(left):
 286                return right
 287            if is_false(right):
 288                return left
 289            return _simplify_comparison(expression, left, right, or_=True)
 290        elif isinstance(expression, exp.Xor):
 291            if left == right:
 292                return exp.false()
 293
 294    if isinstance(expression, exp.Connector):
 295        return _flat_simplify(expression, _simplify_connectors, root)
 296    return expression
 297
 298
 299LT_LTE = (exp.LT, exp.LTE)
 300GT_GTE = (exp.GT, exp.GTE)
 301
 302COMPARISONS = (
 303    *LT_LTE,
 304    *GT_GTE,
 305    exp.EQ,
 306    exp.NEQ,
 307    exp.Is,
 308)
 309
 310INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 311    exp.LT: exp.GT,
 312    exp.GT: exp.LT,
 313    exp.LTE: exp.GTE,
 314    exp.GTE: exp.LTE,
 315}
 316
 317NONDETERMINISTIC = (exp.Rand, exp.Randn)
 318AND_OR = (exp.And, exp.Or)
 319
 320
 321def _simplify_comparison(expression, left, right, or_=False):
 322    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 323        ll, lr = left.args.values()
 324        rl, rr = right.args.values()
 325
 326        largs = {ll, lr}
 327        rargs = {rl, rr}
 328
 329        matching = largs & rargs
 330        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
 331
 332        if matching and columns:
 333            try:
 334                l = first(largs - columns)
 335                r = first(rargs - columns)
 336            except StopIteration:
 337                return expression
 338
 339            if l.is_number and r.is_number:
 340                l = l.to_py()
 341                r = r.to_py()
 342            elif l.is_string and r.is_string:
 343                l = l.name
 344                r = r.name
 345            else:
 346                l = extract_date(l)
 347                if not l:
 348                    return None
 349                r = extract_date(r)
 350                if not r:
 351                    return None
 352                # python won't compare date and datetime, but many engines will upcast
 353                l, r = cast_as_datetime(l), cast_as_datetime(r)
 354
 355            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 356                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 357                    return left if (av > bv if or_ else av <= bv) else right
 358                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 359                    return left if (av < bv if or_ else av >= bv) else right
 360
 361                # we can't ever shortcut to true because the column could be null
 362                if not or_:
 363                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 364                        if av <= bv:
 365                            return exp.false()
 366                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 367                        if av >= bv:
 368                            return exp.false()
 369                    elif isinstance(a, exp.EQ):
 370                        if isinstance(b, exp.LT):
 371                            return exp.false() if av >= bv else a
 372                        if isinstance(b, exp.LTE):
 373                            return exp.false() if av > bv else a
 374                        if isinstance(b, exp.GT):
 375                            return exp.false() if av <= bv else a
 376                        if isinstance(b, exp.GTE):
 377                            return exp.false() if av < bv else a
 378                        if isinstance(b, exp.NEQ):
 379                            return exp.false() if av == bv else a
 380    return None
 381
 382
 383def remove_complements(expression, root=True):
 384    """
 385    Removing complements.
 386
 387    A AND NOT A -> FALSE
 388    A OR NOT A -> TRUE
 389    """
 390    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 391        ops = set(expression.flatten())
 392        for op in ops:
 393            if isinstance(op, exp.Not) and op.this in ops:
 394                return exp.false() if isinstance(expression, exp.And) else exp.true()
 395
 396    return expression
 397
 398
 399def uniq_sort(expression, root=True):
 400    """
 401    Uniq and sort a connector.
 402
 403    C AND A AND B AND B -> A AND B AND C
 404    """
 405    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 406        flattened = tuple(expression.flatten())
 407
 408        if isinstance(expression, exp.Xor):
 409            result_func = exp.xor
 410            # Do not deduplicate XOR as A XOR A != A if A == True
 411            deduped = None
 412            arr = tuple((gen(e), e) for e in flattened)
 413        else:
 414            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 415            deduped = {gen(e): e for e in flattened}
 416            arr = tuple(deduped.items())
 417
 418        # check if the operands are already sorted, if not sort them
 419        # A AND C AND B -> A AND B AND C
 420        for i, (sql, e) in enumerate(arr[1:]):
 421            if sql < arr[i][0]:
 422                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 423                break
 424        else:
 425            # we didn't have to sort but maybe we need to dedup
 426            if deduped and len(deduped) < len(flattened):
 427                expression = result_func(*deduped.values(), copy=False)
 428
 429    return expression
 430
 431
 432def absorb_and_eliminate(expression, root=True):
 433    """
 434    absorption:
 435        A AND (A OR B) -> A
 436        A OR (A AND B) -> A
 437        A AND (NOT A OR B) -> A AND B
 438        A OR (NOT A AND B) -> A OR B
 439    elimination:
 440        (A AND B) OR (A AND NOT B) -> A
 441        (A OR B) AND (A OR NOT B) -> A
 442    """
 443    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
 444        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 445
 446        ops = tuple(expression.flatten())
 447
 448        # Initialize lookup tables:
 449        # Set of all operands, used to find complements for absorption.
 450        op_set = set()
 451        # Sub-operands, used to find subsets for absorption.
 452        subops = defaultdict(list)
 453        # Pairs of complements, used for elimination.
 454        pairs = defaultdict(list)
 455
 456        # Populate the lookup tables
 457        for op in ops:
 458            op_set.add(op)
 459
 460            if not isinstance(op, kind):
 461                # In cases like: A OR (A AND B)
 462                # Subop will be: ^
 463                subops[op].append({op})
 464                continue
 465
 466            # In cases like: (A AND B) OR (A AND B AND C)
 467            # Subops will be: ^     ^
 468            subset = set(op.flatten())
 469            for i in subset:
 470                subops[i].append(subset)
 471
 472            a, b = op.unnest_operands()
 473            if isinstance(a, exp.Not):
 474                pairs[frozenset((a.this, b))].append((op, b))
 475            if isinstance(b, exp.Not):
 476                pairs[frozenset((a, b.this))].append((op, a))
 477
 478        for op in ops:
 479            if not isinstance(op, kind):
 480                continue
 481
 482            a, b = op.unnest_operands()
 483
 484            # Absorb
 485            if isinstance(a, exp.Not) and a.this in op_set:
 486                a.replace(exp.true() if kind == exp.And else exp.false())
 487                continue
 488            if isinstance(b, exp.Not) and b.this in op_set:
 489                b.replace(exp.true() if kind == exp.And else exp.false())
 490                continue
 491            superset = set(op.flatten())
 492            if any(any(subset < superset for subset in subops[i]) for i in superset):
 493                op.replace(exp.false() if kind == exp.And else exp.true())
 494                continue
 495
 496            # Eliminate
 497            for other, complement in pairs[frozenset((a, b))]:
 498                op.replace(complement)
 499                other.replace(complement)
 500
 501    return expression
 502
 503
 504def propagate_constants(expression, root=True):
 505    """
 506    Propagate constants for conjunctions in DNF:
 507
 508    SELECT * FROM t WHERE a = b AND b = 5 becomes
 509    SELECT * FROM t WHERE a = 5 AND b = 5
 510
 511    Reference: https://www.sqlite.org/optoverview.html
 512    """
 513
 514    if (
 515        isinstance(expression, exp.And)
 516        and (root or not expression.same_parent)
 517        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 518    ):
 519        constant_mapping = {}
 520        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 521            if isinstance(expr, exp.EQ):
 522                l, r = expr.left, expr.right
 523
 524                # TODO: create a helper that can be used to detect nested literal expressions such
 525                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 526                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 527                    constant_mapping[l] = (id(l), r)
 528
 529        if constant_mapping:
 530            for column in find_all_in_scope(expression, exp.Column):
 531                parent = column.parent
 532                column_id, constant = constant_mapping.get(column) or (None, None)
 533                if (
 534                    column_id is not None
 535                    and id(column) != column_id
 536                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 537                ):
 538                    column.replace(constant.copy())
 539
 540    return expression
 541
 542
 543INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 544    exp.DateAdd: exp.Sub,
 545    exp.DateSub: exp.Add,
 546    exp.DatetimeAdd: exp.Sub,
 547    exp.DatetimeSub: exp.Add,
 548}
 549
 550INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 551    **INVERSE_DATE_OPS,
 552    exp.Add: exp.Sub,
 553    exp.Sub: exp.Add,
 554}
 555
 556
 557def _is_number(expression: exp.Expression) -> bool:
 558    return expression.is_number
 559
 560
 561def _is_interval(expression: exp.Expression) -> bool:
 562    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 563
 564
 565@catch(ModuleNotFoundError, UnsupportedUnit)
 566def simplify_equality(expression: exp.Expression) -> exp.Expression:
 567    """
 568    Use the subtraction and addition properties of equality to simplify expressions:
 569
 570        x + 1 = 3 becomes x = 2
 571
 572    There are two binary operations in the above expression: + and =
 573    Here's how we reference all the operands in the code below:
 574
 575          l     r
 576        x + 1 = 3
 577        a   b
 578    """
 579    if isinstance(expression, COMPARISONS):
 580        l, r = expression.left, expression.right
 581
 582        if l.__class__ not in INVERSE_OPS:
 583            return expression
 584
 585        if r.is_number:
 586            a_predicate = _is_number
 587            b_predicate = _is_number
 588        elif _is_date_literal(r):
 589            a_predicate = _is_date_literal
 590            b_predicate = _is_interval
 591        else:
 592            return expression
 593
 594        if l.__class__ in INVERSE_DATE_OPS:
 595            l = t.cast(exp.IntervalOp, l)
 596            a = l.this
 597            b = l.interval()
 598        else:
 599            l = t.cast(exp.Binary, l)
 600            a, b = l.left, l.right
 601
 602        if not a_predicate(a) and b_predicate(b):
 603            pass
 604        elif not a_predicate(b) and b_predicate(a):
 605            a, b = b, a
 606        else:
 607            return expression
 608
 609        return expression.__class__(
 610            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 611        )
 612    return expression
 613
 614
 615def simplify_literals(expression, root=True):
 616    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 617        return _flat_simplify(expression, _simplify_binary, root)
 618
 619    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
 620        return expression.this.this
 621
 622    if type(expression) in INVERSE_DATE_OPS:
 623        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 624
 625    return expression
 626
 627
 628NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 629
 630
 631def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 632    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 633        this = _simplify_integer_cast(expr.this)
 634    else:
 635        this = expr.this
 636
 637    if isinstance(expr, exp.Cast) and this.is_int:
 638        num = this.to_py()
 639
 640        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 641        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 642        # engine-dependent
 643        if (
 644            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 645        ) or (
 646            UTINYINT_MIN <= num <= UTINYINT_MAX
 647            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 648        ):
 649            return this
 650
 651    return expr
 652
 653
 654def _simplify_binary(expression, a, b):
 655    if isinstance(expression, COMPARISONS):
 656        a = _simplify_integer_cast(a)
 657        b = _simplify_integer_cast(b)
 658
 659    if isinstance(expression, exp.Is):
 660        if isinstance(b, exp.Not):
 661            c = b.this
 662            not_ = True
 663        else:
 664            c = b
 665            not_ = False
 666
 667        if is_null(c):
 668            if isinstance(a, exp.Literal):
 669                return exp.true() if not_ else exp.false()
 670            if is_null(a):
 671                return exp.false() if not_ else exp.true()
 672    elif isinstance(expression, NULL_OK):
 673        return None
 674    elif is_null(a) or is_null(b):
 675        return exp.null()
 676
 677    if a.is_number and b.is_number:
 678        num_a = a.to_py()
 679        num_b = b.to_py()
 680
 681        if isinstance(expression, exp.Add):
 682            return exp.Literal.number(num_a + num_b)
 683        if isinstance(expression, exp.Mul):
 684            return exp.Literal.number(num_a * num_b)
 685
 686        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 687        if isinstance(expression, exp.Sub):
 688            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 689        if isinstance(expression, exp.Div):
 690            # engines have differing int div behavior so intdiv is not safe
 691            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 692                return None
 693            return exp.Literal.number(num_a / num_b)
 694
 695        boolean = eval_boolean(expression, num_a, num_b)
 696
 697        if boolean:
 698            return boolean
 699    elif a.is_string and b.is_string:
 700        boolean = eval_boolean(expression, a.this, b.this)
 701
 702        if boolean:
 703            return boolean
 704    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 705        date, b = extract_date(a), extract_interval(b)
 706        if date and b:
 707            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 708                return date_literal(date + b, extract_type(a))
 709            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 710                return date_literal(date - b, extract_type(a))
 711    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 712        a, date = extract_interval(a), extract_date(b)
 713        # you cannot subtract a date from an interval
 714        if a and b and isinstance(expression, exp.Add):
 715            return date_literal(a + date, extract_type(b))
 716    elif _is_date_literal(a) and _is_date_literal(b):
 717        if isinstance(expression, exp.Predicate):
 718            a, b = extract_date(a), extract_date(b)
 719            boolean = eval_boolean(expression, a, b)
 720            if boolean:
 721                return boolean
 722
 723    return None
 724
 725
 726def simplify_parens(expression):
 727    if not isinstance(expression, exp.Paren):
 728        return expression
 729
 730    this = expression.this
 731    parent = expression.parent
 732    parent_is_predicate = isinstance(parent, exp.Predicate)
 733
 734    if (
 735        not isinstance(this, exp.Select)
 736        and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket))
 737        and (
 738            not isinstance(parent, (exp.Condition, exp.Binary))
 739            or isinstance(parent, exp.Paren)
 740            or (
 741                not isinstance(this, exp.Binary)
 742                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
 743            )
 744            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
 745            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 746            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 747            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 748        )
 749    ):
 750        return this
 751    return expression
 752
 753
 754def _is_nonnull_constant(expression: exp.Expression) -> bool:
 755    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 756
 757
 758def _is_constant(expression: exp.Expression) -> bool:
 759    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 760
 761
 762def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
 763    # COALESCE(x) -> x
 764    if (
 765        isinstance(expression, exp.Coalesce)
 766        and (not expression.expressions or _is_nonnull_constant(expression.this))
 767        # COALESCE is also used as a Spark partitioning hint
 768        and not isinstance(expression.parent, exp.Hint)
 769    ):
 770        return expression.this
 771
 772    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
 773    # because they are not always equivalent. For example,  if `x` is `NULL` and it comes
 774    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
 775    if dialect == "redshift":
 776        return expression
 777
 778    if not isinstance(expression, COMPARISONS):
 779        return expression
 780
 781    if isinstance(expression.left, exp.Coalesce):
 782        coalesce = expression.left
 783        other = expression.right
 784    elif isinstance(expression.right, exp.Coalesce):
 785        coalesce = expression.right
 786        other = expression.left
 787    else:
 788        return expression
 789
 790    # This transformation is valid for non-constants,
 791    # but it really only does anything if they are both constants.
 792    if not _is_constant(other):
 793        return expression
 794
 795    # Find the first constant arg
 796    for arg_index, arg in enumerate(coalesce.expressions):
 797        if _is_constant(arg):
 798            break
 799    else:
 800        return expression
 801
 802    coalesce.set("expressions", coalesce.expressions[:arg_index])
 803
 804    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 805    # since we already remove COALESCE at the top of this function.
 806    coalesce = coalesce if coalesce.expressions else coalesce.this
 807
 808    # This expression is more complex than when we started, but it will get simplified further
 809    return exp.paren(
 810        exp.or_(
 811            exp.and_(
 812                coalesce.is_(exp.null()).not_(copy=False),
 813                expression.copy(),
 814                copy=False,
 815            ),
 816            exp.and_(
 817                coalesce.is_(exp.null()),
 818                type(expression)(this=arg.copy(), expression=other.copy()),
 819                copy=False,
 820            ),
 821            copy=False,
 822        )
 823    )
 824
 825
 826CONCATS = (exp.Concat, exp.DPipe)
 827
 828
 829def simplify_concat(expression):
 830    """Reduces all groups that contain string literals by concatenating them."""
 831    if not isinstance(expression, CONCATS) or (
 832        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 833        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 834    ):
 835        return expression
 836
 837    if isinstance(expression, exp.ConcatWs):
 838        sep_expr, *expressions = expression.expressions
 839        sep = sep_expr.name
 840        concat_type = exp.ConcatWs
 841        args = {}
 842    else:
 843        expressions = expression.expressions
 844        sep = ""
 845        concat_type = exp.Concat
 846        args = {
 847            "safe": expression.args.get("safe"),
 848            "coalesce": expression.args.get("coalesce"),
 849        }
 850
 851    new_args = []
 852    for is_string_group, group in itertools.groupby(
 853        expressions or expression.flatten(), lambda e: e.is_string
 854    ):
 855        if is_string_group:
 856            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 857        else:
 858            new_args.extend(group)
 859
 860    if len(new_args) == 1 and new_args[0].is_string:
 861        return new_args[0]
 862
 863    if concat_type is exp.ConcatWs:
 864        new_args = [sep_expr] + new_args
 865    elif isinstance(expression, exp.DPipe):
 866        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
 867
 868    return concat_type(expressions=new_args, **args)
 869
 870
 871def simplify_conditionals(expression):
 872    """Simplifies expressions like IF, CASE if their condition is statically known."""
 873    if isinstance(expression, exp.Case):
 874        this = expression.this
 875        for case in expression.args["ifs"]:
 876            cond = case.this
 877            if this:
 878                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 879                cond = cond.replace(this.pop().eq(cond))
 880
 881            if always_true(cond):
 882                return case.args["true"]
 883
 884            if always_false(cond):
 885                case.pop()
 886                if not expression.args["ifs"]:
 887                    return expression.args.get("default") or exp.null()
 888    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 889        if always_true(expression.this):
 890            return expression.args["true"]
 891        if always_false(expression.this):
 892            return expression.args.get("false") or exp.null()
 893
 894    return expression
 895
 896
 897def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 898    """
 899    Reduces a prefix check to either TRUE or FALSE if both the string and the
 900    prefix are statically known.
 901
 902    Example:
 903        >>> from sqlglot import parse_one
 904        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 905        'TRUE'
 906    """
 907    if (
 908        isinstance(expression, exp.StartsWith)
 909        and expression.this.is_string
 910        and expression.expression.is_string
 911    ):
 912        return exp.convert(expression.name.startswith(expression.expression.name))
 913
 914    return expression
 915
 916
 917DateRange = t.Tuple[datetime.date, datetime.date]
 918
 919
 920def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 921    """
 922    Get the date range for a DATE_TRUNC equality comparison:
 923
 924    Example:
 925        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 926    Returns:
 927        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 928    """
 929    floor = date_floor(date, unit, dialect)
 930
 931    if date != floor:
 932        # This will always be False, except for NULL values.
 933        return None
 934
 935    return floor, floor + interval(unit)
 936
 937
 938def _datetrunc_eq_expression(
 939    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 940) -> exp.Expression:
 941    """Get the logical expression for a date range"""
 942    return exp.and_(
 943        left >= date_literal(drange[0], target_type),
 944        left < date_literal(drange[1], target_type),
 945        copy=False,
 946    )
 947
 948
 949def _datetrunc_eq(
 950    left: exp.Expression,
 951    date: datetime.date,
 952    unit: str,
 953    dialect: Dialect,
 954    target_type: t.Optional[exp.DataType],
 955) -> t.Optional[exp.Expression]:
 956    drange = _datetrunc_range(date, unit, dialect)
 957    if not drange:
 958        return None
 959
 960    return _datetrunc_eq_expression(left, drange, target_type)
 961
 962
 963def _datetrunc_neq(
 964    left: exp.Expression,
 965    date: datetime.date,
 966    unit: str,
 967    dialect: Dialect,
 968    target_type: t.Optional[exp.DataType],
 969) -> t.Optional[exp.Expression]:
 970    drange = _datetrunc_range(date, unit, dialect)
 971    if not drange:
 972        return None
 973
 974    return exp.and_(
 975        left < date_literal(drange[0], target_type),
 976        left >= date_literal(drange[1], target_type),
 977        copy=False,
 978    )
 979
 980
 981DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 982    exp.LT: lambda l, dt, u, d, t: l
 983    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 984    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 985    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 986    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 987    exp.EQ: _datetrunc_eq,
 988    exp.NEQ: _datetrunc_neq,
 989}
 990DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 991DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 992
 993
 994def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 995    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 996
 997
 998@catch(ModuleNotFoundError, UnsupportedUnit)
 999def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
1000    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
1001    comparison = expression.__class__
1002
1003    if isinstance(expression, DATETRUNCS):
1004        this = expression.this
1005        trunc_type = extract_type(this)
1006        date = extract_date(this)
1007        if date and expression.unit:
1008            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
1009    elif comparison not in DATETRUNC_COMPARISONS:
1010        return expression
1011
1012    if isinstance(expression, exp.Binary):
1013        l, r = expression.left, expression.right
1014
1015        if not _is_datetrunc_predicate(l, r):
1016            return expression
1017
1018        l = t.cast(exp.DateTrunc, l)
1019        trunc_arg = l.this
1020        unit = l.unit.name.lower()
1021        date = extract_date(r)
1022
1023        if not date:
1024            return expression
1025
1026        return (
1027            DATETRUNC_BINARY_COMPARISONS[comparison](
1028                trunc_arg, date, unit, dialect, extract_type(r)
1029            )
1030            or expression
1031        )
1032
1033    if isinstance(expression, exp.In):
1034        l = expression.this
1035        rs = expression.expressions
1036
1037        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
1038            l = t.cast(exp.DateTrunc, l)
1039            unit = l.unit.name.lower()
1040
1041            ranges = []
1042            for r in rs:
1043                date = extract_date(r)
1044                if not date:
1045                    return expression
1046                drange = _datetrunc_range(date, unit, dialect)
1047                if drange:
1048                    ranges.append(drange)
1049
1050            if not ranges:
1051                return expression
1052
1053            ranges = merge_ranges(ranges)
1054            target_type = extract_type(*rs)
1055
1056            return exp.or_(
1057                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
1058            )
1059
1060    return expression
1061
1062
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Normalize the operand order of a comparison.

    Columns go to the left and constants to the right. When neither rule decides,
    operands are ordered by their `gen` string. Comparisons whose right side is a
    SubqueryPredicate are left alone. Swapping the operands maps the operator through
    INVERSE_COMPARISONS; operators not listed there keep their class.
    """
    klass = expression.__class__
    if klass not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_sorted = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_sorted:
        return expression

    should_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not should_swap:
        return expression

    return INVERSE_COMPARISONS.get(klass, klass)(this=right, expression=left)
1082
1083
# (side, kind) pairs of joins that may become CROSS joins when their ON is always true.
# A CROSS join yields an empty table whenever the right table is empty, so outer joins
# that preserve left rows can't be converted: LEFT JOIN x ON TRUE != CROSS JOIN x.
JOINS = {
    ("RIGHT", "OUTER"),
    ("RIGHT", ""),
    ("", "INNER"),
    ("", ""),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible `JOIN ... ON TRUE` into CROSS JOINs.

    Mutates `expression` in place; returns None.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        convertible = (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if not convertible:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1109
1110
def always_true(expression):
    """Whether `expression` is statically truthy: TRUE, or a literal not equal to 0."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
1115
1116
def always_false(expression):
    """Whether `expression` is statically falsy: FALSE, NULL or a literal equal to 0."""
    for check in (is_false, is_null, is_zero):
        if check(expression):
            return True
    return False
1119
1120
def is_zero(expression):
    """Whether `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
1123
1124
def is_complement(a, b):
    """Whether `b` is exactly `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1127
1128
def is_false(a: exp.Expression) -> bool:
    """Whether `a` is the FALSE literal (exact Boolean type; subclasses don't count)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1131
1132
def is_null(a: exp.Expression) -> bool:
    """Whether `a` is a NULL literal (exact Null type; subclasses don't count)."""
    node_type = type(a)
    return node_type is exp.Null
1135
1136
def eval_boolean(expression, a, b):
    """Evaluate the comparison node `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal, or None when `expression` is not one of
    EQ, IS, NEQ, GT, GTE, LT or LTE.
    """
    evaluators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in evaluators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
1151
1152
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce `value` to a `datetime.date`.

    Datetimes lose their time component and dates pass through unchanged. Any other
    value goes to `datetime.fromisoformat`; a string it can't parse yields None.
    """
    if isinstance(value, datetime.date):
        # datetime subclasses date, so it has to be narrowed explicitly
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
1162
1163
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce `value` to a `datetime.datetime`.

    Datetimes pass through unchanged and plain dates become midnight of that day.
    Any other value goes to `datetime.fromisoformat`; a string it can't parse
    yields None.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime.combine(value, datetime.time())

    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
1173
1174
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert a constant into the Python value of a temporal SQL type.

    Args:
        value: a date/datetime, or a string to parse; falsy values yield None.
        to: the target type. DATE produces a `datetime.date`, any other temporal
            type a `datetime.datetime`.

    Returns:
        The converted value, or None if `to` isn't temporal or `value` can't be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1183
1184
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal to a temporal type.

    Handles `CAST(<literal> AS <type>)`, `TS_OR_DS_TO_DATE(<literal>)` without a
    format, and nestings of these (e.g. `CAST(CAST('2021-01-01' AS DATE) AS TIMESTAMP)`).

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for other temporal
        targets, or None if `cast` isn't such an expression or its value can't be parsed.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without an explicit format, TS_OR_DS_TO_DATE is treated as a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Evaluate the inner cast first, then convert its result to the outer type
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1200
1201
def _is_date_literal(expression: exp.Expression) -> bool:
    """Whether `expression` statically evaluates to a constant date or datetime."""
    constant = extract_date(expression)
    return constant is not None
1204
1205
def extract_interval(expression):
    """Convert a constant INTERVAL expression into a `relativedelta`.

    Args:
        expression: an interval node whose `this` holds the amount and whose `unit`
            names the date part (e.g. `INTERVAL '2' DAY`).

    Returns:
        The equivalent `relativedelta`, or None if the amount isn't a whole number,
        the unit is unsupported, or `dateutil` isn't installed.
    """
    try:
        amount = expression.this.to_py()
        n = int(amount)
        # Reject fractional amounts (e.g. INTERVAL 1.5 DAY) instead of silently
        # truncating them; strings are already validated by int() above.
        if not isinstance(amount, str) and amount != n:
            return None
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError, TypeError, OverflowError):
        # TypeError: e.g. a NULL amount; OverflowError: an infinite amount
        return None
1213
1214
def extract_type(*expressions):
    """Return the first known data type among `expressions`.

    For casts the cast target is used, otherwise the node's annotated `type`. If none
    is set, the last value inspected (None when nothing was passed) is returned.
    """
    found = None
    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found
    return found
1223
1224
def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is a temporal type. Otherwise DATETIME is chosen for
    datetime values and DATE for plain dates.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
1234
1235
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` units of `unit`.

    Raises:
        ModuleNotFoundError: if `dateutil` isn't installed.
        UnsupportedUnit: for units other than year, quarter, month, week, day, hour,
            minute and second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, multiplier)
    for name, keyword, multiplier in (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    ):
        if unit == name:
            return relativedelta(**{keyword: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1257
1258
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` to the start of the `unit` period that contains it.

    Args:
        d: the date or datetime to truncate. For datetimes the time of day is reset to
            midnight, since every supported unit spans at least one whole day.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET`, the offset of the first day of the week from
            Monday (e.g. -1 makes weeks start on Sunday); used for the "week" unit.

    Returns:
        A value of the same type as `d`.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() counts Monday (0) .. Sunday (6); shift it so that the dialect's first
        # day of the week maps to 0. The modulo keeps the step in [0, 6]: without it a
        # Sunday with WEEK_OFFSET = -1 would go back a whole week.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1280
1281
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary; values already aligned are returned as-is."""
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1289
1290
def boolean_literal(condition):
    """Return the TRUE literal if `condition` is truthy, otherwise FALSE."""
    if condition:
        return exp.true()
    return exp.false()
1293
1294
def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the operands of a flattened chain of `expression`'s type.

    The operands (e.g. every term of `a AND b AND c`) are collected and
    `simplifier(expression, a, b)` is tried on pairs. Whenever it returns a new node,
    that node replaces both inputs and is retried first against the remaining
    operands. The chain is rebuilt only if at least one pair was merged.

    Non-root nodes whose `same_parent` is set are returned untouched (presumably
    because the enclosing chain is processed as a whole).
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    survivors = []

    while pending:
        current = pending.popleft()

        merge = None
        for other in pending:
            candidate = simplifier(expression, current, other)
            if candidate and candidate is not expression:
                merge = (other, candidate)
                break

        if merge is None:
            survivors.append(current)
            continue

        consumed, combined = merge
        pending.remove(consumed)
        pending.appendleft(combined)

    if len(survivors) >= original_count:
        return expression

    return functools.reduce(
        lambda left, right: expression.__class__(this=left, expression=right), survivors
    )
1319
1320
def gen(expression: t.Any, comments: bool = False) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduplicating SQL is a necessary optimization step, and running the
    real generator for it is expensive; this produces a minimal stand-in instead.

    Args:
        expression: the expression to render.
        comments: whether the expression's comments are included in the output.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
1332
1333
class Gen:
    """Stack-based pseudo-SQL generator used by `gen`.

    Rendering is iterative. Each handler pushes the pieces of its output onto
    `self.stack` in *reverse* order, since the stack is popped last-in-first-out.
    `gen` keeps popping: expressions are dispatched to `<key>_sql` handlers, lists are
    expanded into comma-separated items, and everything else is appended to
    `self.sqls` as text. The output is a cheap, stable key for sorting and
    deduplication, not valid SQL.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render `expression`; with `comments=True`, node comments are included."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, so it is emitted after them
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Comma-separated list, skipping None entries. Only a separator pushed
                # here may be removed: popping when every entry was None would discard
                # an unrelated stack item (e.g. a function's closing parenthesis).
                items = [n for n in node if n is not None]
                for n in reversed(items):
                    self.stack.extend((n, ","))
                if items:
                    # Drop the separator that would precede the first item
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Reverse order: left operand is popped (and emitted) first
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the non-None args of `node`, starting at position `arg_index`.

        They are pushed as one list of `[":name", value]` pairs; the return value tells
        whether anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit is not supported by the date helpers of this module."""

Common base class for all non-exit exceptions.

def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None):
 40def simplify(
 41    expression: exp.Expression,
 42    constant_propagation: bool = False,
 43    dialect: DialectType = None,
 44):
 45    """
 46    Rewrite sqlglot AST to simplify expressions.
 47
 48    Example:
 49        >>> import sqlglot
 50        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 51        >>> simplify(expression).sql()
 52        'TRUE'
 53
 54    Args:
 55        expression: expression to simplify
 56        constant_propagation: whether the constant propagation rule should be used
 57    Returns:
 58        sqlglot.Expression: simplified expression
 59    """
 60
 61    dialect = Dialect.get_or_raise(dialect)
 62
 63    def _simplify(expression):
 64        pre_transformation_stack = [expression]
 65        post_transformation_stack = []
 66
 67        while pre_transformation_stack:
 68            node = pre_transformation_stack.pop()
 69
 70            if node.meta.get(FINAL):
 71                continue
 72
 73            # group by expressions cannot be simplified, for example
 74            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 75            # the projection must exactly match the group by key
 76            group = node.args.get("group")
 77
 78            if group and hasattr(node, "selects"):
 79                groups = set(group.expressions)
 80                group.meta[FINAL] = True
 81
 82                for s in node.selects:
 83                    for n in s.walk():
 84                        if n in groups:
 85                            s.meta[FINAL] = True
 86                            break
 87
 88                having = node.args.get("having")
 89                if having:
 90                    for n in having.walk():
 91                        if n in groups:
 92                            having.meta[FINAL] = True
 93                            break
 94
 95            parent = node.parent
 96            root = node is expression
 97
 98            new_node = rewrite_between(node)
 99            new_node = uniq_sort(new_node, root)
100            new_node = absorb_and_eliminate(new_node, root)
101            new_node = simplify_concat(new_node)
102            new_node = simplify_conditionals(new_node)
103
104            if constant_propagation:
105                new_node = propagate_constants(new_node, root)
106
107            if new_node is not node:
108                node.replace(new_node)
109
110            pre_transformation_stack.extend(
111                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
112            )
113            post_transformation_stack.append((new_node, parent))
114
115        while post_transformation_stack:
116            node, parent = post_transformation_stack.pop()
117            root = node is expression
118
119            # Resets parent, arg_key, index pointers– this is needed because some of the
120            # previous transformations mutate the AST, leading to an inconsistent state
121            for k, v in tuple(node.args.items()):
122                node.set(k, v)
123
124            # Post-order transformations
125            new_node = simplify_not(node)
126            new_node = flatten(new_node)
127            new_node = simplify_connectors(new_node, root)
128            new_node = remove_complements(new_node, root)
129            new_node = simplify_coalesce(new_node, dialect)
130
131            new_node.parent = parent
132
133            new_node = simplify_literals(new_node, root)
134            new_node = simplify_equality(new_node)
135            new_node = simplify_parens(new_node)
136            new_node = simplify_datetrunc(new_node, dialect)
137            new_node = sort_comparison(new_node)
138            new_node = simplify_startswith(new_node)
139
140            if new_node is not node:
141                node.replace(new_node)
142
143        return new_node
144
145    expression = while_changing(expression, _simplify)
146    remove_where_true(expression)
147    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
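  • constant_propagation: whether the constant propagation rule should be used
  • dialect: the SQL dialect used by dialect-dependent rules (COALESCE and DATE_TRUNC simplification)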
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    The decorated rule is called as `func(expression, *args, **kwargs)`. If it raises
    one of `exceptions`, the input `expression` is returned unchanged; any other
    exception propagates. The wrapper keeps the rule's name and docstring.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of the given exceptions is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    When the BETWEEN is the direct operand of a NOT, the conjunction is parenthesized.
    """
    if not isinstance(expression, exp.Between):
        return expression

    wrap_in_parens = isinstance(expression.parent, exp.Not)

    subject = expression.this
    lower_check = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_check = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_check, upper_check, copy=False)

    return exp.paren(conjunction, copy=False) if wrap_in_parens else conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

COMPLEMENT_SUBQUERY_PREDICATES = {<class 'sqlglot.expressions.All'>: <class 'sqlglot.expressions.Any'>, <class 'sqlglot.expressions.Any'>: <class 'sqlglot.expressions.All'>}
def simplify_not(expression):
    """
    Push NOT inward and fold constant negations (De Morgan's laws).

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also: NOT NULL -> NULL; NOT <comparison> -> complementary comparison (swapping
    ALL/ANY on the right-hand side); NOT <always-true value> -> FALSE;
    NOT FALSE -> TRUE; NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    complement = COMPLEMENT_COMPARISONS.get(operand.__class__)
    if complement is not None:
        rhs = operand.expression
        flipped_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            # e.g. NOT (x = ALL (...)) -> x <> ANY (...)
            rhs = flipped_predicate(this=rhs.this)
        return complement(this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # AND <-> OR, negating each side
            flipped_connector = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                flipped_connector(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # Double negation: NOT NOT x -> x
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
    """
    Remove parentheses around children that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for child in expression.args.values():
        unwrapped = child.unnest()
        if isinstance(unwrapped, connector_type):
            child.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """Fold AND/OR/XOR connectors whose operands are constant or identical.

    AND: a FALSE or zero operand -> FALSE; a NULL operand -> NULL; always-true
         operands drop out (both always-true -> TRUE).
    OR:  an always-true operand -> TRUE; NULL with NULL/always-false -> NULL;
         FALSE operands drop out.
    XOR: identical operands -> FALSE.
    Remaining AND/OR pairs are handed to `_simplify_comparison`.
    """

    def _simplify_pair(connector, left, right):
        # Returns None when no rule applies to this pair.
        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            left_true, right_true = always_true(left), always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_null, right_null = is_null(left), is_null(right)
            if (left_null and (right_null or always_false(right))) or (
                always_false(left) and right_null
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        if isinstance(connector, exp.Xor) and left == right:
            return exp.false()
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_pair, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
def remove_complements(expression, root=True):
    """
    Collapse an AND/OR that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if has_complement:
        return exp.true() if isinstance(expression, exp.Or) else exp.false()
    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by their `gen` string. XOR operands are sorted but never
    deduplicated, because A XOR A is not equivalent to A.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # Only rebuild when the operands are out of order:
        # A AND C AND B -> A AND B AND C
        if any(later[0] < earlier[0] for earlier, later in zip(arr, arr[1:])):
            # Sort on the SQL string alone. Sorting the (sql, node) tuples would fall
            # back to comparing the nodes themselves whenever two XOR operands render
            # to the same string (XOR keeps duplicates); expression nodes presumably
            # define no ordering, so that comparison could raise TypeError.
            # `sorted` is stable, so tied operands keep their relative order.
            ordered = sorted(arr, key=lambda pair: pair[0])
            expression = result_func(*(node for _, node in ordered), copy=False)
        elif deduped and len(deduped) < len(flattened):
            # Already sorted, but duplicates were dropped.
            expression = result_func(*deduped.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    Apply boolean absorption and elimination to an AND/OR connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place with `replace`; the (possibly mutated)
    `expression` is returned.
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    # Operands of an AND are inspected for nested ORs, and vice versa.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And

    operands = tuple(expression.flatten())

    # Every top-level operand; used to spot complements for absorption.
    all_operands = set()
    # member -> operand sets containing it; used to spot subsets for absorption.
    subsets_by_member = defaultdict(list)
    # frozenset({X, Y}) -> [(inner operand, survivor)] for operands like X <kind> NOT Y;
    # used for elimination.
    complement_pairs = defaultdict(list)

    for operand in operands:
        all_operands.add(operand)

        if not isinstance(operand, inner_kind):
            # e.g. the lone A in: A OR (A AND B)
            subsets_by_member[operand].append({operand})
            continue

        # e.g. both groups in: (A AND B) OR (A AND B AND C)
        members = set(operand.flatten())
        for member in members:
            subsets_by_member[member].append(members)

        first, second = operand.unnest_operands()
        if isinstance(first, exp.Not):
            complement_pairs[frozenset((first.this, second))].append((operand, second))
        if isinstance(second, exp.Not):
            complement_pairs[frozenset((first, second.this))].append((operand, first))

    for operand in operands:
        if not isinstance(operand, inner_kind):
            continue

        first, second = operand.unnest_operands()

        # Absorb a negated member whose positive form is a top-level operand
        if isinstance(first, exp.Not) and first.this in all_operands:
            first.replace(exp.true() if inner_kind == exp.And else exp.false())
            continue
        if isinstance(second, exp.Not) and second.this in all_operands:
            second.replace(exp.true() if inner_kind == exp.And else exp.false())
            continue

        # Absorb an operand that is a strict superset of another operand
        members = set(operand.flatten())
        if any(subset < members for member in members for subset in subsets_by_member[member]):
            operand.replace(exp.false() if inner_kind == exp.And else exp.true())
            continue

        # Eliminate complementary pairs
        for other, survivor in complement_pairs[frozenset((first, second))]:
            operand.replace(survivor)
            other.replace(survivor)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
    """
    Substitute `column = literal` facts into the rest of a DNF conjunction.

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    IF nodes are not searched for equalities, and a column compared with
    `IS NULL` is never replaced.

    Reference: https://www.sqlite.org/optoverview.html
    """
    if not isinstance(expression, exp.And):
        return expression
    if not root and expression.same_parent:
        return expression
    if not sqlglot.optimizer.normalize.normalized(expression, dnf=True):
        return expression

    # column -> (id of the column node that defines it, literal it equals)
    constants = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        lhs, rhs = node.left, node.right
        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            constants[lhs] = (id(lhs), rhs)

    if not constants:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        defining_id, literal = constants.get(column) or (None, None)
        # Skip unknown columns and the defining occurrence itself
        if defining_id is None or id(column) == defining_id:
            continue
        parent = column.parent
        # Leave `column IS NULL` untouched
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue
        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
154        def wrapped(expression, *args, **kwargs):
155            try:
156                return func(expression, *args, **kwargs)
157            except exceptions:
158                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
    """Simplify literal-bearing expressions.

    - Non-connector binary nodes are simplified pairwise with `_simplify_binary`
      (through `_flat_simplify`).
    - A double negation `-(-x)` collapses to `x`.
    - Nodes whose type is in INVERSE_DATE_OPS are passed to `_simplify_binary`
      together with their interval; the node is kept if that returns nothing.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        inner = expression.this
        if isinstance(inner, exp.Neg):
            return inner.this

    if type(expression) in INVERSE_DATE_OPS:
        simplified = _simplify_binary(expression, expression.this, expression.interval())
        return simplified or expression

    return expression
def simplify_parens(expression):
    """Drop parentheses that do not change how the wrapped expression binds.

    Parentheses are always kept around a SELECT and directly under a subquery
    predicate or a bracket access. Otherwise they are removed when the parent is
    not an operator, the child is not a binary operation (NOT/IS inside a
    predicate excepted), the child is a predicate outside a predicate, or the
    child's operator is associative with / binds tighter than the parent's.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent

    if isinstance(inner, exp.Select):
        return expression
    if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    parent_is_predicate = isinstance(parent, exp.Predicate)
    removable = (
        # parent is not an operator, or is itself a paren
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        # a non-binary child, unless it is NOT/IS nested inside a predicate
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate child whose parent is not a predicate
        or (isinstance(inner, exp.Predicate) and not parent_is_predicate)
        # associative chains: (a + b) + c, (a * b) * c
        or (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(parent, exp.Mul))
        # multiplication binds tighter than + / -
        or (isinstance(inner, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Simplify COALESCE calls and comparisons against them.

    - `COALESCE(x)`, or a COALESCE whose first argument is a non-null constant,
      becomes `x` (unless the COALESCE is a Spark partitioning hint).
    - A comparison between a constant `k` and `COALESCE(x, ..., c, ...)`, where `c`
      is the first constant argument, is rewritten to
      `(NOT t IS NULL AND <comparison with t>) OR (t IS NULL AND <c compared with k>)`,
      where `t` is the COALESCE truncated before `c` (or just `x`). The original
      operand order is kept in both branches. Skipped for Redshift.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example, if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # When the truncated COALESCE is NULL the original evaluates to `arg`. Compare `arg`
    # with `other` on the same side the COALESCE occupied, so asymmetric comparisons keep
    # their meaning: `5 < COALESCE(x, 1)` must fall back to `5 < 1`, not `1 < 5`.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string-literal arguments of CONCAT, `||` and CONCAT_WS (when its
    separator is a string literal) are merged into one literal, joined with the
    separator. If a single literal remains, it replaces the whole call.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    # (a CONCAT_WS with no arguments at all has no separator either).
    if is_concat_ws and not (expression.expressions and expression.expressions[0].is_string):
        return expression

    if is_concat_ws:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # `a || b || c` is a nested binary tree with no `expressions`, so flatten it into
    # its operands. Function forms use their own argument list. They must not fall
    # back to flatten() when the list is empty: for CONCAT_WS('-') that would yield
    # the separator itself and wrongly collapse the call to '-'.
    operands = expression.flatten() if isinstance(expression, exp.DPipe) else expressions

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
      - a branch whose condition is always false is dropped; if no branches remain,
        the CASE becomes its ELSE value (or NULL);
      - the first branch whose condition is always true replaces the whole CASE, but
        only if every branch before it was dropped. An earlier branch with an unknown
        condition could still match first, so such a branch is left in place.
    IF (not directly inside a CASE):
      - replaced by its THEN value when the condition is always true, and by its
        ELSE value (or NULL) when it is always false.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a branch with a statically unknown condition has been kept.
        seen_undecided_branch = False
        # Iterate over a copy: `case.pop()` removes the branch from the live "ifs"
        # list, and mutating a list while iterating over it skips the next element.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond) and not seen_undecided_branch:
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                seen_undecided_branch = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold STARTSWITH(string, prefix) to TRUE/FALSE when both arguments are
    string literals; any other expression is returned untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    text, prefix = expression.this, expression.expression
    if not (text.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.LT'>}
def simplify_datetrunc(expression, *args, **kwargs):
154        def wrapped(expression, *args, **kwargs):
155            try:
156                return func(expression, *args, **kwargs)
157            except exceptions:
158                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison in a canonical order.

    Columns go on the left and constants on the right; otherwise operands are
    ordered by their `gen` string. Comparisons with a subquery predicate on the
    right are left alone. When operands are swapped, the operator class is mapped
    through INVERSE_COMPARISONS, falling back to the same class if it has no entry.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not needs_swap:
        return expression

    mirrored = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return mirrored(this=right, expression=left)
JOINS = {('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', '')}
def remove_where_true(expression):
    """Remove always-true filters from `expression`, in place.

    - WHERE clauses whose condition is always true are popped.
    - Joins with an always-true ON, no USING, no join method and a (side, kind)
      listed in JOINS are turned into CROSS joins.

    Returns None; callers observe the mutation of `expression`.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """Return a truthy value when `expression` is TRUE or a non-zero literal."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and not is_zero(expression)
def always_false(expression):
    """Return True when `expression` is FALSE, NULL or a zero literal."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
def is_zero(expression):
    """Return True when `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly the FALSE boolean node (subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a NULL node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """Evaluate the comparison node `expression` on the Python values `a` and `b`.

    Supports EQ/IS, NEQ, GT, GTE, LT and LTE, returning a TRUE/FALSE literal;
    returns None for any other node type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Interpret `value` as a date.

    Datetimes are truncated to their date part and dates are returned unchanged.
    Anything else is parsed with `datetime.datetime.fromisoformat`.

    Returns:
        The date, or None when `value` cannot be parsed. This covers malformed
        strings and non-string values such as None or numbers.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: fromisoformat only accepts strings; honour the Optional contract.
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Interpret `value` as a datetime.

    Datetimes are returned unchanged and dates become midnight of that day.
    Anything else is parsed with `datetime.datetime.fromisoformat`.

    Returns:
        The datetime, or None when `value` cannot be parsed. This covers malformed
        strings and non-string values such as None or numbers.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: fromisoformat only accepts strings; honour the Optional contract.
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically cast `value` to the temporal type `to`.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` for any other
        temporal type, or None when `value` is falsy, `to` is not temporal, or the
        value cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a temporal cast of a literal.

    Handles `CAST(<literal> AS <type>)` and `TsOrDsToDate(<literal>)` without a
    format, which is treated as a cast to DATE. It recurses through nested casts
    of those two kinds.

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for other temporal
        targets, or None when the value can't be determined statically.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast, e.g. CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP)
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """Convert an INTERVAL node into a `relativedelta` via `interval`.

    Returns None when converting the amount with int() raises ValueError, when the
    unit is unsupported, or when `dateutil` is not installed.
    """
    failures = (UnsupportedUnit, ModuleNotFoundError, ValueError)
    try:
        quantity = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), quantity)
    except failures:
        return None
def extract_type(*expressions):
    """Return the first truthy type among `expressions`.

    A Cast contributes its target type (`.to`); any other node contributes `.type`.
    If none is truthy, the last examined value (None when there are no
    expressions) is returned.
    """
    found = None
    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found
    return found
def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is a temporal type. Otherwise DATETIME is chosen
    for datetime values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute and
    second. `dateutil` is imported when the function is called, so a missing
    installation surfaces as ModuleNotFoundError at call time.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, multiplier = spans[unit]
    return relativedelta(**{keyword: multiplier * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` to the start of its enclosing `unit` period.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, the first weekday relative to Monday
            (weekday 0). For example, -1 (or 6) makes weeks start on Sunday.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        first_month_of_quarter = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month_of_quarter, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start. The modulo keeps this in
        # [0, 6]. Without it, a non-zero WEEK_OFFSET would step back a whole week,
        # e.g. a Sunday when weeks start on Sunday, or move the date forward.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the start of the next `unit` period, unless it already starts one."""
    start = date_floor(d, unit, dialect)
    return d if start == d else start + interval(unit)
def boolean_literal(condition):
    """Return a TRUE literal for a truthy `condition` and a FALSE literal otherwise."""
    factory = exp.true if condition else exp.false
    return factory()
def gen(expression: t.Any, comments: bool = False) -> str:
    """Render `expression` with the lightweight `Gen` pseudo-SQL generator.

    The output is only meant as a cheap key for sorting and deduplicating
    expressions during optimization; invoking the full SQL generator for that
    purpose would be too expensive.

    Args:
        expression: the expression to render.
        comments: whether the expression's comments are included in the output.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

Arguments:
  • expression: the expression to convert into a SQL string.
  • comments: whether to include the expression's comments.
class Gen:
    """Minimal pseudo-SQL generator used to build cheap sort/dedup keys.

    Rendering is iterative rather than recursive. Handlers push tokens onto
    ``self.stack`` in *reverse* output order, and :meth:`gen` pops them. It
    expands expressions and lists until only plain tokens remain, then
    stringifies those into ``self.sqls`` and joins them.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render ``expression`` into a compact string.

        Args:
            expression: the expression (or plain token / list) to render.
            comments: whether to append each node's comments as ``/*...*/``.

        Returns:
            The concatenated rendered string.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own tokens, so it is emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Interleave "," separators between non-None items, pushing in
                # reverse so they pop in their original order.
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the trailing separator only if one was actually pushed.
                # For an empty or all-None list, popping would discard a token
                # belonging to an enclosing node (or fail on an empty stack).
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        self._dotted(e.parts)

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case keyword, consistent with ILIKE / IN / IS / AND / OR.
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self._dotted(e.parts)

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _dotted(self, parts) -> None:
        """Push ``parts`` joined by ``"."``; push nothing when ``parts`` is empty."""
        if not parts:
            # Without this guard the trailing pop below would remove an
            # unrelated token already on the stack.
            return
        for p in reversed(parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def _binary(self, e: exp.Binary, op: str) -> None:
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push ``[":key", value]`` pairs for the node's non-None args.

        Args:
            node: the expression whose ``args`` are rendered.
            arg_index: skip the first ``arg_index`` entries of ``node.arg_types``.

        Returns:
            True if at least one argument was pushed, else False.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """Render ``expression`` into a compact pseudo-SQL string.

    Tokens are processed from ``self.stack``. Expressions are expanded via
    ``<key>_sql`` handlers (or ``_function`` / generic ``_args`` fallbacks),
    lists are expanded with "," separators, and all other non-None tokens are
    stringified into ``self.sqls``.

    Args:
        expression: the expression (or plain token / list) to render.
        comments: whether to append each node's comments as ``/*...*/``.

    Returns:
        The concatenated rendered string.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            if comments and node.comments:
                # Pushed before the node's own tokens, so it is emitted after them.
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Only drop the trailing "," if one was pushed. An all-None list
            # would otherwise pop a token that belongs to an enclosing node.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Render an addition as ``this + expression``."""
    operator = " + "
    self._binary(e, operator)
def alias_sql(self, e: exp.Alias) -> None:
    """Render ``this AS alias`` (tokens pushed in reverse for the LIFO stack)."""
    alias_name, target = e.args.get("alias"), e.args.get("this")
    self.stack.extend((alias_name, " AS ", target))
def and_sql(self, e: exp.And) -> None:
    """Render a conjunction as ``this AND expression``."""
    operator = " AND "
    self._binary(e, operator)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Render an anonymous function call as ``NAME(args...)``.

    A ``str`` name or an unquoted identifier is upper-cased. A quoted
    identifier keeps its case and is wrapped in double quotes.

    Raises:
        ValueError: if ``e.this`` is neither a ``str`` nor an ``exp.Identifier``.
    """
    func = e.this
    if isinstance(func, exp.Identifier):
        raw = func.this
        name = f'"{raw}"' if func.quoted else raw.upper()
    elif isinstance(func, str):
        name = func.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{func.__class__.__name__}'."
        )

    # Reverse order: the generator pops LIFO, so the name is emitted first.
    self.stack.extend((")", e.expressions, "(", name))
def between_sql(self, e: exp.Between) -> None:
    """Render ``this BETWEEN low AND high``."""
    low = e.args.get("low")
    high = e.args.get("high")
    # Reverse order because the generator pops from the end of the stack.
    self.stack.extend((high, " AND ", low, " BETWEEN ", e.this))
def boolean_sql(self, e: exp.Boolean) -> None:
    """Render ``TRUE`` or ``FALSE`` according to the truthiness of ``e.this``."""
    if e.this:
        self.stack.append("TRUE")
    else:
        self.stack.append("FALSE")
def bracket_sql(self, e: exp.Bracket) -> None:
    """Render a subscript as ``this[expressions...]``."""
    opening, closing = "[", "]"
    self.stack.extend((closing, e.expressions, opening, e.this))
def column_sql(self, e: exp.Column) -> None:
    """Render the column's ``parts`` joined by ``"."``.

    A column without parts renders nothing. The unconditional ``pop`` used
    previously would otherwise remove an unrelated token already on the stack.
    """
    parts = e.parts
    if not parts:
        return
    for part in reversed(parts):
        self.stack.extend((part, "."))
    # Drop the separator pushed after the first part.
    self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Render the type name followed by the remaining (non-``this``) args."""
    # Remaining args go on the stack first so they are emitted after the name.
    self._args(e, 1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Render a division as ``this / expression``."""
    operator = " / "
    self._binary(e, operator)
def dot_sql(self, e: exp.Dot) -> None:
    """Render member access as ``this.expression``."""
    separator = "."
    self._binary(e, separator)
def eq_sql(self, e: exp.EQ) -> None:
    """Render an equality as ``this = expression``."""
    operator = " = "
    self._binary(e, operator)
def from_sql(self, e: exp.From) -> None:
    """Render ``FROM this``."""
    keyword = "FROM "
    self.stack.extend((e.this, keyword))
def gt_sql(self, e: exp.GT) -> None:
    """Render ``this > expression``."""
    operator = " > "
    self._binary(e, operator)
def gte_sql(self, e: exp.GTE) -> None:
    """Render ``this >= expression``."""
    operator = " >= "
    self._binary(e, operator)
def identifier_sql(self, e: exp.Identifier) -> None:
    """Render the identifier text, double-quoted when ``e.quoted`` is set."""
    text = e.this
    if e.quoted:
        text = f'"{text}"'
    self.stack.append(text)
def ilike_sql(self, e: exp.ILike) -> None:
    """Render a case-insensitive match as ``this ILIKE expression``."""
    operator = " ILIKE "
    self._binary(e, operator)
def in_sql(self, e: exp.In) -> None:
    """Render ``this IN (args...)`` where args are every arg after ``this``."""
    self.stack.append(")")
    # Args from index 1 onward are rendered inside the parentheses.
    self._args(e, 1)
    for token in ("(", " IN ", e.this):
        self.stack.append(token)
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Render an integer division as ``this DIV expression``."""
    operator = " DIV "
    self._binary(e, operator)
def is_sql(self, e: exp.Is) -> None:
    """Render ``this IS expression``."""
    operator = " IS "
    self._binary(e, operator)
def like_sql(self, e: exp.Like) -> None:
    """Render a pattern match as ``this LIKE expression``.

    The keyword is upper-cased for consistency with the other operators
    (``ILIKE``, ``IN``, ``IS``, ``AND``, ``OR``).
    """
    self._binary(e, " LIKE ")
def literal_sql(self, e: exp.Literal) -> None:
    """Render the literal; string literals are wrapped in single quotes."""
    if e.is_string:
        self.stack.append(f"'{e.this}'")
    else:
        self.stack.append(e.this)
def lt_sql(self, e: exp.LT) -> None:
    """Render ``this < expression``."""
    operator = " < "
    self._binary(e, operator)
def lte_sql(self, e: exp.LTE) -> None:
    """Render ``this <= expression``."""
    operator = " <= "
    self._binary(e, operator)
def mod_sql(self, e: exp.Mod) -> None:
    """Render a modulo as ``this % expression``."""
    operator = " % "
    self._binary(e, operator)
def mul_sql(self, e: exp.Mul) -> None:
    """Render a multiplication as ``this * expression``."""
    operator = " * "
    self._binary(e, operator)
def neg_sql(self, e: exp.Neg) -> None:
    """Render a negation as ``-this``."""
    prefix = "-"
    self._unary(e, prefix)
def neq_sql(self, e: exp.NEQ) -> None:
    """Render an inequality as ``this <> expression``."""
    operator = " <> "
    self._binary(e, operator)
def not_sql(self, e: exp.Not) -> None:
    """Render a logical negation as ``NOT this``."""
    prefix = "NOT "
    self._unary(e, prefix)
def null_sql(self, e: exp.Null) -> None:
    """Render the ``NULL`` keyword; ``e`` itself is not inspected."""
    keyword = "NULL"
    self.stack.append(keyword)
def or_sql(self, e: exp.Or) -> None:
    """Render a disjunction as ``this OR expression``."""
    operator = " OR "
    self._binary(e, operator)
def paren_sql(self, e: exp.Paren) -> None:
    """Render ``(this)``; tokens are pushed closing-paren first for LIFO popping."""
    self.stack.extend((")", e.this, "("))
def sub_sql(self, e: exp.Sub) -> None:
    """Render a subtraction as ``this - expression``."""
    operator = " - "
    self._binary(e, operator)
def subquery_sql(self, e: exp.Subquery) -> None:
    """Render ``(this)`` followed by the alias (if any) and args from index 2 on."""
    self._args(e, 2)
    tokens = [")", e.this, "("]
    alias = e.args.get("alias")
    if alias:
        # The alias sits below the parenthesised body so it is emitted after it.
        tokens.insert(0, alias)
    self.stack.extend(tokens)
def table_sql(self, e: exp.Table) -> None:
    """Render the table's ``parts`` joined by ``"."``, then its alias and extra args.

    Args from index 4 onward and the alias are pushed first, so they are
    emitted after the dotted name. A table without parts renders no name. The
    unconditional ``pop`` used previously would otherwise remove an unrelated
    token (e.g. the alias just pushed).
    """
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)
    parts = e.parts
    if not parts:
        return
    for part in reversed(parts):
        self.stack.extend((part, "."))
    # Drop the separator pushed after the first part.
    self.stack.pop()
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Render `` AS this`` followed by ``(columns...)`` when columns are present."""
    tokens = []
    columns = e.columns
    if columns:
        tokens += [")", columns, "("]
    tokens += [e.this, " AS "]
    self.stack.extend(tokens)
def var_sql(self, e: exp.Var) -> None:
    """Render the variable's raw ``this`` value unchanged."""
    name = e.this
    self.stack.append(name)