Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce, wraps
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.annotate_types import TypeAnnotator
  15from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  16from sqlglot.schema import ensure_schema
  17
  18if t.TYPE_CHECKING:
  19    from sqlglot.dialects.dialect import DialectType
  20
  21    DateRange = t.Tuple[datetime.date, datetime.date]
  22    DateTruncBinaryTransform = t.Callable[
  23        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  24    ]
  25
  26
  27logger = logging.getLogger("sqlglot")
  28
  29
  30# Final means that an expression should not be simplified
  31FINAL = "final"
  32
  33SIMPLIFIABLE = (
  34    exp.Binary,
  35    exp.Func,
  36    exp.Lambda,
  37    exp.Predicate,
  38    exp.Unary,
  39)
  40
  41
  42def simplify(
  43    expression: exp.Expression,
  44    constant_propagation: bool = False,
  45    coalesce_simplification: bool = False,
  46    dialect: DialectType = None,
  47):
  48    """
  49    Rewrite sqlglot AST to simplify expressions.
  50
  51    Example:
  52        >>> import sqlglot
  53        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  54        >>> simplify(expression).sql()
  55        'TRUE'
  56
  57    Args:
  58        expression: expression to simplify
  59        constant_propagation: whether the constant propagation rule should be used
  60        coalesce_simplification: whether the simplify coalesce rule should be used.
  61            This rule tries to remove coalesce functions, which can be useful in certain analyses but
  62            can leave the query more verbose.
  63    Returns:
  64        sqlglot.Expression: simplified expression
  65    """
  66    return Simplifier(dialect=dialect).simplify(
  67        expression,
  68        constant_propagation=constant_propagation,
  69        coalesce_simplification=coalesce_simplification,
  70    )
  71
  72
  73class UnsupportedUnit(Exception):
  74    pass
  75
  76
  77def catch(*exceptions):
  78    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
  79
  80    def decorator(func):
  81        def wrapped(expression, *args, **kwargs):
  82            try:
  83                return func(expression, *args, **kwargs)
  84            except exceptions:
  85                return expression
  86
  87        return wrapped
  88
  89    return decorator
  90
  91
  92def annotate_types_on_change(func):
  93    @wraps(func)
  94    def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]:
  95        new_expression = func(self, expression, *args, **kwargs)
  96
  97        if new_expression is None:
  98            return new_expression
  99
 100        if self.annotate_new_expressions and expression != new_expression:
 101            self._annotator.clear()
 102
 103            # We annotate this to ensure new children nodes are also annotated
 104            new_expression = self._annotator.annotate(
 105                expression=new_expression,
 106                annotate_scope=False,
 107            )
 108
 109            # Whatever expression the original expression is transformed into needs to preserve
 110            # the original type, otherwise the simplification could result in a different schema
 111            new_expression.type = expression.type
 112
 113        return new_expression
 114
 115    return _func
 116
 117
 118def flatten(expression):
 119    """
 120    A AND (B AND C) -> A AND B AND C
 121    A OR (B OR C) -> A OR B OR C
 122    """
 123    if isinstance(expression, exp.Connector):
 124        for node in expression.args.values():
 125            child = node.unnest()
 126            if isinstance(child, expression.__class__):
 127                node.replace(child)
 128    return expression
 129
 130
 131def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
 132    if not isinstance(expression, exp.Paren):
 133        return expression
 134
 135    this = expression.this
 136    parent = expression.parent
 137    parent_is_predicate = isinstance(parent, exp.Predicate)
 138
 139    if isinstance(this, exp.Select):
 140        return expression
 141
 142    if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
 143        return expression
 144
 145    if (
 146        Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
 147        and isinstance(parent, exp.Dot)
 148        and (isinstance(parent.right, (exp.Identifier, exp.Star)))
 149    ):
 150        return expression
 151
 152    if (
 153        not isinstance(parent, (exp.Condition, exp.Binary))
 154        or isinstance(parent, exp.Paren)
 155        or (
 156            not isinstance(this, exp.Binary)
 157            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
 158        )
 159        or (
 160            isinstance(this, exp.Predicate)
 161            and not (parent_is_predicate or isinstance(parent, exp.Neg))
 162        )
 163        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 164        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 165        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 166    ):
 167        return this
 168
 169    return expression
 170
 171
 172def propagate_constants(expression, root=True):
 173    """
 174    Propagate constants for conjunctions in DNF:
 175
 176    SELECT * FROM t WHERE a = b AND b = 5 becomes
 177    SELECT * FROM t WHERE a = 5 AND b = 5
 178
 179    Reference: https://www.sqlite.org/optoverview.html
 180    """
 181
 182    if (
 183        isinstance(expression, exp.And)
 184        and (root or not expression.same_parent)
 185        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 186    ):
 187        constant_mapping = {}
 188        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 189            if isinstance(expr, exp.EQ):
 190                l, r = expr.left, expr.right
 191
 192                # TODO: create a helper that can be used to detect nested literal expressions such
 193                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 194                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 195                    constant_mapping[l] = (id(l), r)
 196
 197        if constant_mapping:
 198            for column in find_all_in_scope(expression, exp.Column):
 199                parent = column.parent
 200                column_id, constant = constant_mapping.get(column) or (None, None)
 201                if (
 202                    column_id is not None
 203                    and id(column) != column_id
 204                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 205                ):
 206                    column.replace(constant.copy())
 207
 208    return expression
 209
 210
 211def _is_number(expression: exp.Expression) -> bool:
 212    return expression.is_number
 213
 214
 215def _is_interval(expression: exp.Expression) -> bool:
 216    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 217
 218
 219def _is_nonnull_constant(expression: exp.Expression) -> bool:
 220    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 221
 222
 223def _is_constant(expression: exp.Expression) -> bool:
 224    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 225
 226
 227def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 228    """
 229    Get the date range for a DATE_TRUNC equality comparison:
 230
 231    Example:
 232        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 233    Returns:
 234        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 235    """
 236    floor = date_floor(date, unit, dialect)
 237
 238    if date != floor:
 239        # This will always be False, except for NULL values.
 240        return None
 241
 242    return floor, floor + interval(unit)
 243
 244
 245def _datetrunc_eq_expression(
 246    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 247) -> exp.Expression:
 248    """Get the logical expression for a date range"""
 249    return exp.and_(
 250        left >= date_literal(drange[0], target_type),
 251        left < date_literal(drange[1], target_type),
 252        copy=False,
 253    )
 254
 255
 256def _datetrunc_eq(
 257    left: exp.Expression,
 258    date: datetime.date,
 259    unit: str,
 260    dialect: Dialect,
 261    target_type: t.Optional[exp.DataType],
 262) -> t.Optional[exp.Expression]:
 263    drange = _datetrunc_range(date, unit, dialect)
 264    if not drange:
 265        return None
 266
 267    return _datetrunc_eq_expression(left, drange, target_type)
 268
 269
 270def _datetrunc_neq(
 271    left: exp.Expression,
 272    date: datetime.date,
 273    unit: str,
 274    dialect: Dialect,
 275    target_type: t.Optional[exp.DataType],
 276) -> t.Optional[exp.Expression]:
 277    drange = _datetrunc_range(date, unit, dialect)
 278    if not drange:
 279        return None
 280
 281    return exp.and_(
 282        left < date_literal(drange[0], target_type),
 283        left >= date_literal(drange[1], target_type),
 284        copy=False,
 285    )
 286
 287
 288def always_true(expression):
 289    return (isinstance(expression, exp.Boolean) and expression.this) or (
 290        isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression)
 291    )
 292
 293
 294def always_false(expression):
 295    return is_false(expression) or is_null(expression) or is_zero(expression)
 296
 297
 298def is_zero(expression):
 299    return isinstance(expression, exp.Literal) and expression.to_py() == 0
 300
 301
 302def is_complement(a, b):
 303    return isinstance(b, exp.Not) and b.this == a
 304
 305
 306def is_false(a: exp.Expression) -> bool:
 307    return type(a) is exp.Boolean and not a.this
 308
 309
 310def is_null(a: exp.Expression) -> bool:
 311    return type(a) is exp.Null
 312
 313
 314def eval_boolean(expression, a, b):
 315    if isinstance(expression, (exp.EQ, exp.Is)):
 316        return boolean_literal(a == b)
 317    if isinstance(expression, exp.NEQ):
 318        return boolean_literal(a != b)
 319    if isinstance(expression, exp.GT):
 320        return boolean_literal(a > b)
 321    if isinstance(expression, exp.GTE):
 322        return boolean_literal(a >= b)
 323    if isinstance(expression, exp.LT):
 324        return boolean_literal(a < b)
 325    if isinstance(expression, exp.LTE):
 326        return boolean_literal(a <= b)
 327    return None
 328
 329
 330def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 331    if isinstance(value, datetime.datetime):
 332        return value.date()
 333    if isinstance(value, datetime.date):
 334        return value
 335    try:
 336        return datetime.datetime.fromisoformat(value).date()
 337    except ValueError:
 338        return None
 339
 340
 341def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 342    if isinstance(value, datetime.datetime):
 343        return value
 344    if isinstance(value, datetime.date):
 345        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 346    try:
 347        return datetime.datetime.fromisoformat(value)
 348    except ValueError:
 349        return None
 350
 351
 352def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 353    if not value:
 354        return None
 355    if to.is_type(exp.DataType.Type.DATE):
 356        return cast_as_date(value)
 357    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 358        return cast_as_datetime(value)
 359    return None
 360
 361
 362def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 363    if isinstance(cast, exp.Cast):
 364        to = cast.to
 365    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
 366        to = exp.DataType.build(exp.DataType.Type.DATE)
 367    else:
 368        return None
 369
 370    if isinstance(cast.this, exp.Literal):
 371        value: t.Any = cast.this.name
 372    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 373        value = extract_date(cast.this)
 374    else:
 375        return None
 376    return cast_value(value, to)
 377
 378
 379def _is_date_literal(expression: exp.Expression) -> bool:
 380    return extract_date(expression) is not None
 381
 382
 383def extract_interval(expression):
 384    try:
 385        n = int(expression.this.to_py())
 386        unit = expression.text("unit").lower()
 387        return interval(unit, n)
 388    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
 389        return None
 390
 391
 392def extract_type(*expressions):
 393    target_type = None
 394    for expression in expressions:
 395        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
 396        if target_type:
 397            break
 398
 399    return target_type
 400
 401
 402def date_literal(date, target_type=None):
 403    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
 404        target_type = (
 405            exp.DataType.Type.DATETIME
 406            if isinstance(date, datetime.datetime)
 407            else exp.DataType.Type.DATE
 408        )
 409
 410    return exp.cast(exp.Literal.string(date), target_type)
 411
 412
 413def interval(unit: str, n: int = 1):
 414    from dateutil.relativedelta import relativedelta
 415
 416    if unit == "year":
 417        return relativedelta(years=1 * n)
 418    if unit == "quarter":
 419        return relativedelta(months=3 * n)
 420    if unit == "month":
 421        return relativedelta(months=1 * n)
 422    if unit == "week":
 423        return relativedelta(weeks=1 * n)
 424    if unit == "day":
 425        return relativedelta(days=1 * n)
 426    if unit == "hour":
 427        return relativedelta(hours=1 * n)
 428    if unit == "minute":
 429        return relativedelta(minutes=1 * n)
 430    if unit == "second":
 431        return relativedelta(seconds=1 * n)
 432
 433    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 434
 435
 436def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
 437    if unit == "year":
 438        return d.replace(month=1, day=1)
 439    if unit == "quarter":
 440        if d.month <= 3:
 441            return d.replace(month=1, day=1)
 442        elif d.month <= 6:
 443            return d.replace(month=4, day=1)
 444        elif d.month <= 9:
 445            return d.replace(month=7, day=1)
 446        else:
 447            return d.replace(month=10, day=1)
 448    if unit == "month":
 449        return d.replace(month=d.month, day=1)
 450    if unit == "week":
 451        # Assuming week starts on Monday (0) and ends on Sunday (6)
 452        return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
 453    if unit == "day":
 454        return d
 455
 456    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 457
 458
 459def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
 460    floor = date_floor(d, unit, dialect)
 461
 462    if floor == d:
 463        return d
 464
 465    return floor + interval(unit)
 466
 467
 468def boolean_literal(condition):
 469    return exp.true() if condition else exp.false()
 470
 471
 472class Simplifier:
 473    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
 474        self.dialect = Dialect.get_or_raise(dialect)
 475        self.annotate_new_expressions = annotate_new_expressions
 476
 477        self._annotator: TypeAnnotator = TypeAnnotator(
 478            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
 479        )
 480
 481    # Value ranges for byte-sized signed/unsigned integers
 482    TINYINT_MIN = -128
 483    TINYINT_MAX = 127
 484    UTINYINT_MIN = 0
 485    UTINYINT_MAX = 255
 486
 487    COMPLEMENT_COMPARISONS = {
 488        exp.LT: exp.GTE,
 489        exp.GT: exp.LTE,
 490        exp.LTE: exp.GT,
 491        exp.GTE: exp.LT,
 492        exp.EQ: exp.NEQ,
 493        exp.NEQ: exp.EQ,
 494    }
 495
 496    COMPLEMENT_SUBQUERY_PREDICATES = {
 497        exp.All: exp.Any,
 498        exp.Any: exp.All,
 499    }
 500
 501    LT_LTE = (exp.LT, exp.LTE)
 502    GT_GTE = (exp.GT, exp.GTE)
 503
 504    COMPARISONS = (
 505        *LT_LTE,
 506        *GT_GTE,
 507        exp.EQ,
 508        exp.NEQ,
 509        exp.Is,
 510    )
 511
 512    INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 513        exp.LT: exp.GT,
 514        exp.GT: exp.LT,
 515        exp.LTE: exp.GTE,
 516        exp.GTE: exp.LTE,
 517    }
 518
 519    NONDETERMINISTIC = (exp.Rand, exp.Randn)
 520    AND_OR = (exp.And, exp.Or)
 521
 522    INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 523        exp.DateAdd: exp.Sub,
 524        exp.DateSub: exp.Add,
 525        exp.DatetimeAdd: exp.Sub,
 526        exp.DatetimeSub: exp.Add,
 527    }
 528
 529    INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 530        **INVERSE_DATE_OPS,
 531        exp.Add: exp.Sub,
 532        exp.Sub: exp.Add,
 533    }
 534
 535    NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 536
 537    CONCATS = (exp.Concat, exp.DPipe)
 538
 539    DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 540        exp.LT: lambda l, dt, u, d, t: l
 541        < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 542        exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 543        exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 544        exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 545        exp.EQ: _datetrunc_eq,
 546        exp.NEQ: _datetrunc_neq,
 547    }
 548
 549    DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 550    DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 551
 552    SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean)
 553
 554    # CROSS joins result in an empty table if the right table is empty.
 555    # So we can only simplify certain types of joins to CROSS.
 556    # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 557    JOINS = {
 558        ("", ""),
 559        ("", "INNER"),
 560        ("RIGHT", ""),
 561        ("RIGHT", "OUTER"),
 562    }
 563
 564    def simplify(
 565        self,
 566        expression: exp.Expression,
 567        constant_propagation: bool = False,
 568        coalesce_simplification: bool = False,
 569    ):
 570        wheres = []
 571        joins = []
 572
 573        for node in expression.walk(
 574            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
 575        ):
 576            if node.meta.get(FINAL):
 577                continue
 578
 579            # group by expressions cannot be simplified, for example
 580            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 581            # the projection must exactly match the group by key
 582            group = node.args.get("group")
 583
 584            if group and hasattr(node, "selects"):
 585                groups = set(group.expressions)
 586                group.meta[FINAL] = True
 587
 588                for s in node.selects:
 589                    for n in s.walk(FINAL):
 590                        if n in groups:
 591                            s.meta[FINAL] = True
 592                            break
 593
 594                having = node.args.get("having")
 595
 596                if having:
 597                    for n in having.walk():
 598                        if n in groups:
 599                            having.meta[FINAL] = True
 600                            break
 601
 602            if isinstance(node, exp.Condition):
 603                simplified = while_changing(
 604                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
 605                )
 606
 607                if node is expression:
 608                    expression = simplified
 609            elif isinstance(node, exp.Where):
 610                wheres.append(node)
 611            elif isinstance(node, exp.Join):
 612                # snowflake match_conditions have very strict ordering rules
 613                if match := node.args.get("match_condition"):
 614                    match.meta[FINAL] = True
 615
 616                joins.append(node)
 617
 618        for where in wheres:
 619            if always_true(where.this):
 620                where.pop()
 621        for join in joins:
 622            if (
 623                always_true(join.args.get("on"))
 624                and not join.args.get("using")
 625                and not join.args.get("method")
 626                and (join.side, join.kind) in self.JOINS
 627            ):
 628                join.args["on"].pop()
 629                join.set("side", None)
 630                join.set("kind", "CROSS")
 631
 632        return expression
 633
 634    def _simplify(
 635        self, expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool
 636    ):
 637        pre_transformation_stack = [expression]
 638        post_transformation_stack = []
 639
 640        while pre_transformation_stack:
 641            original = pre_transformation_stack.pop()
 642            node = original
 643
 644            if not isinstance(node, SIMPLIFIABLE):
 645                if isinstance(node, exp.Query):
 646                    self.simplify(node, constant_propagation, coalesce_simplification)
 647                continue
 648
 649            parent = node.parent
 650            root = node is expression
 651
 652            node = self.rewrite_between(node)
 653            node = self.uniq_sort(node, root)
 654            node = self.absorb_and_eliminate(node, root)
 655            node = self.simplify_concat(node)
 656            node = self.simplify_conditionals(node)
 657
 658            if constant_propagation:
 659                node = propagate_constants(node, root)
 660
 661            if node is not original:
 662                original.replace(node)
 663
 664            for n in node.iter_expressions(reverse=True):
 665                if n.meta.get(FINAL):
 666                    raise
 667            pre_transformation_stack.extend(
 668                n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 669            )
 670            post_transformation_stack.append((node, parent))
 671
 672        while post_transformation_stack:
 673            original, parent = post_transformation_stack.pop()
 674            root = original is expression
 675
 676            # Resets parent, arg_key, index pointers– this is needed because some of the
 677            # previous transformations mutate the AST, leading to an inconsistent state
 678            for k, v in tuple(original.args.items()):
 679                original.set(k, v)
 680
 681            # Post-order transformations
 682            node = self.simplify_not(original)
 683            node = flatten(node)
 684            node = self.simplify_connectors(node, root)
 685            node = self.remove_complements(node, root)
 686
 687            if coalesce_simplification:
 688                node = self.simplify_coalesce(node)
 689            node.parent = parent
 690
 691            node = self.simplify_literals(node, root)
 692            node = self.simplify_equality(node)
 693            node = simplify_parens(node, dialect=self.dialect)
 694            node = self.simplify_datetrunc(node)
 695            node = self.sort_comparison(node)
 696            node = self.simplify_startswith(node)
 697
 698            if node is not original:
 699                original.replace(node)
 700
 701        return node
 702
 703    @annotate_types_on_change
 704    def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
 705        """Rewrite x between y and z to x >= y AND x <= z.
 706
 707        This is done because comparison simplification is only done on lt/lte/gt/gte.
 708        """
 709        if isinstance(expression, exp.Between):
 710            negate = isinstance(expression.parent, exp.Not)
 711
 712            expression = exp.and_(
 713                exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 714                exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 715                copy=False,
 716            )
 717
 718            if negate:
 719                expression = exp.paren(expression, copy=False)
 720
 721        return expression
 722
 723    @annotate_types_on_change
 724    def simplify_not(self, expression: exp.Expression) -> exp.Expression:
 725        """
 726        Demorgan's Law
 727        NOT (x OR y) -> NOT x AND NOT y
 728        NOT (x AND y) -> NOT x OR NOT y
 729        """
 730        if isinstance(expression, exp.Not):
 731            this = expression.this
 732            if is_null(this):
 733                return exp.and_(exp.null(), exp.true(), copy=False)
 734            if this.__class__ in self.COMPLEMENT_COMPARISONS:
 735                right = this.expression
 736                complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
 737                    right.__class__
 738                )
 739                if complement_subquery_predicate:
 740                    right = complement_subquery_predicate(this=right.this)
 741
 742                return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 743            if isinstance(this, exp.Paren):
 744                condition = this.unnest()
 745                if isinstance(condition, exp.And):
 746                    return exp.paren(
 747                        exp.or_(
 748                            exp.not_(condition.left, copy=False),
 749                            exp.not_(condition.right, copy=False),
 750                            copy=False,
 751                        ),
 752                        copy=False,
 753                    )
 754                if isinstance(condition, exp.Or):
 755                    return exp.paren(
 756                        exp.and_(
 757                            exp.not_(condition.left, copy=False),
 758                            exp.not_(condition.right, copy=False),
 759                            copy=False,
 760                        ),
 761                        copy=False,
 762                    )
 763                if is_null(condition):
 764                    return exp.and_(exp.null(), exp.true(), copy=False)
 765            if always_true(this):
 766                return exp.false()
 767            if is_false(this):
 768                return exp.true()
 769            if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
 770                inner = this.this
 771                if inner.is_type(exp.DataType.Type.BOOLEAN):
 772                    # double negation
 773                    # NOT NOT x -> x, if x is BOOLEAN type
 774                    return inner
 775        return expression
 776
 777    @annotate_types_on_change
 778    def simplify_connectors(self, expression, root=True):
 779        def _simplify_connectors(expression, left, right):
 780            if isinstance(expression, exp.And):
 781                if is_false(left) or is_false(right):
 782                    return exp.false()
 783                if is_zero(left) or is_zero(right):
 784                    return exp.false()
 785                if (
 786                    (is_null(left) and is_null(right))
 787                    or (is_null(left) and always_true(right))
 788                    or (always_true(left) and is_null(right))
 789                ):
 790                    return exp.null()
 791                if always_true(left) and always_true(right):
 792                    return exp.true()
 793                if always_true(left):
 794                    return right
 795                if always_true(right):
 796                    return left
 797                return self._simplify_comparison(expression, left, right)
 798            elif isinstance(expression, exp.Or):
 799                if always_true(left) or always_true(right):
 800                    return exp.true()
 801                if (
 802                    (is_null(left) and is_null(right))
 803                    or (is_null(left) and always_false(right))
 804                    or (always_false(left) and is_null(right))
 805                ):
 806                    return exp.null()
 807                if is_false(left):
 808                    return right
 809                if is_false(right):
 810                    return left
 811                return self._simplify_comparison(expression, left, right, or_=True)
 812
 813        if isinstance(expression, exp.Connector):
 814            original_parent = expression.parent
 815            expression = self._flat_simplify(expression, _simplify_connectors, root)
 816
 817            # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
 818            # to ensure that the resulting type is boolean. We know this is true only for connectors,
 819            # boolean values and columns that are essentially operands to a connector:
 820            #
 821            # A AND (((B)))
 822            #          ~ this is safe to keep because it will eventually be part of another connector
 823            if not isinstance(
 824                expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
 825            ) and not expression.is_type(exp.DataType.Type.BOOLEAN):
 826                while True:
 827                    if isinstance(original_parent, exp.Connector):
 828                        break
 829                    if not isinstance(original_parent, exp.Paren):
 830                        expression = expression.and_(exp.true(), copy=False)
 831                        break
 832
 833                    original_parent = original_parent.parent
 834
 835        return expression
 836
 837    @annotate_types_on_change
 838    def _simplify_comparison(self, expression, left, right, or_=False):
 839        if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
 840            ll, lr = left.args.values()
 841            rl, rr = right.args.values()
 842
 843            largs = {ll, lr}
 844            rargs = {rl, rr}
 845
 846            matching = largs & rargs
 847            columns = {
 848                m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
 849            }
 850
 851            if matching and columns:
 852                try:
 853                    l = first(largs - columns)
 854                    r = first(rargs - columns)
 855                except StopIteration:
 856                    return expression
 857
 858                if l.is_number and r.is_number:
 859                    l = l.to_py()
 860                    r = r.to_py()
 861                elif l.is_string and r.is_string:
 862                    l = l.name
 863                    r = r.name
 864                else:
 865                    l = extract_date(l)
 866                    if not l:
 867                        return None
 868                    r = extract_date(r)
 869                    if not r:
 870                        return None
 871                    # python won't compare date and datetime, but many engines will upcast
 872                    l, r = cast_as_datetime(l), cast_as_datetime(r)
 873
 874                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 875                    if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
 876                        return left if (av > bv if or_ else av <= bv) else right
 877                    if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
 878                        return left if (av < bv if or_ else av >= bv) else right
 879
 880                    # we can't ever shortcut to true because the column could be null
 881                    if not or_:
 882                        if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
 883                            if av <= bv:
 884                                return exp.false()
 885                        elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
 886                            if av >= bv:
 887                                return exp.false()
 888                        elif isinstance(a, exp.EQ):
 889                            if isinstance(b, exp.LT):
 890                                return exp.false() if av >= bv else a
 891                            if isinstance(b, exp.LTE):
 892                                return exp.false() if av > bv else a
 893                            if isinstance(b, exp.GT):
 894                                return exp.false() if av <= bv else a
 895                            if isinstance(b, exp.GTE):
 896                                return exp.false() if av < bv else a
 897                            if isinstance(b, exp.NEQ):
 898                                return exp.false() if av == bv else a
 899        return None
 900
 901    @annotate_types_on_change
 902    def remove_complements(self, expression, root=True):
 903        """
 904        Removing complements.
 905
 906        A AND NOT A -> FALSE (only for non-NULL A)
 907        A OR NOT A -> TRUE (only for non-NULL A)
 908        """
 909        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
 910            ops = set(expression.flatten())
 911            for op in ops:
 912                if isinstance(op, exp.Not) and op.this in ops:
 913                    if expression.meta.get("nonnull") is True:
 914                        return exp.false() if isinstance(expression, exp.And) else exp.true()
 915
 916        return expression
 917
 918    @annotate_types_on_change
 919    def uniq_sort(self, expression, root=True):
 920        """
 921        Uniq and sort a connector.
 922
 923        C AND A AND B AND B -> A AND B AND C
 924        """
 925        if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 926            flattened = tuple(expression.flatten())
 927
 928            if isinstance(expression, exp.Xor):
 929                result_func = exp.xor
 930                # Do not deduplicate XOR as A XOR A != A if A == True
 931                deduped = None
 932                arr = tuple((gen(e), e) for e in flattened)
 933            else:
 934                result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 935                deduped = {gen(e): e for e in flattened}
 936                arr = tuple(deduped.items())
 937
 938            # check if the operands are already sorted, if not sort them
 939            # A AND C AND B -> A AND B AND C
 940            for i, (sql, e) in enumerate(arr[1:]):
 941                if sql < arr[i][0]:
 942                    expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 943                    break
 944            else:
 945                # we didn't have to sort but maybe we need to dedup
 946                if deduped and len(deduped) < len(flattened):
 947                    unique_operand = flattened[0]
 948                    if len(deduped) == 1:
 949                        expression = unique_operand.and_(exp.true(), copy=False)
 950                    else:
 951                        expression = result_func(*deduped.values(), copy=False)
 952
 953        return expression
 954
    @annotate_types_on_change
    def absorb_and_eliminate(self, expression, root=True):
        """
        absorption:
            A AND (A OR B) -> A
            A OR (A AND B) -> A
            A AND (NOT A OR B) -> A AND B
            A OR (NOT A AND B) -> A OR B
        elimination:
            (A AND B) OR (A AND NOT B) -> A
            (A OR B) AND (A OR NOT B) -> A

        Only AND/OR connectors are processed, and only at the root or at the top of a
        chain of same-type connectors. Operands are rewritten in place with ``replace``:
        - Absorbed sub-connectors become TRUE/FALSE (e.g. A AND (A OR B) turns into
          A AND TRUE).
        - Eliminated pairs become their shared operand.
        These leftovers are presumably folded away by other rules. ``expression`` itself
        is returned.

        Args:
            expression: the expression to simplify.
            root: whether ``expression`` is the root of the current simplification.

        Returns:
            ``expression``, possibly with rewritten operands.
        """
        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
            # The dual connector type to look for among the operands (OR inside an AND, AND inside an OR).
            kind = exp.Or if isinstance(expression, exp.And) else exp.And

            ops = tuple(expression.flatten())

            # Initialize lookup tables:
            # Set of all operands, used to find complements for absorption.
            op_set = set()
            # Sub-operands, used to find subsets for absorption.
            subops = defaultdict(list)
            # Pairs of complements, used for elimination.
            pairs = defaultdict(list)

            # Populate the lookup tables
            for op in ops:
                op_set.add(op)

                if not isinstance(op, kind):
                    # In cases like: A OR (A AND B)
                    # Subop will be: ^
                    subops[op].append({op})
                    continue

                # In cases like: (A AND B) OR (A AND B AND C)
                # Subops will be: ^     ^
                subset = set(op.flatten())
                for i in subset:
                    subops[i].append(subset)

                # For a sub-connector shaped X <kind> NOT Y (either order), record it under the
                # key {X, Y} together with X, the operand that survives elimination.
                a, b = op.unnest_operands()
                if isinstance(a, exp.Not):
                    pairs[frozenset((a.this, b))].append((op, b))
                if isinstance(b, exp.Not):
                    pairs[frozenset((a, b.this))].append((op, a))

            for op in ops:
                if not isinstance(op, kind):
                    continue

                a, b = op.unnest_operands()

                # Absorb
                # NOT X inside the sub-connector while X is a sibling operand: replace NOT X
                # with the identity element of `kind` (TRUE for AND, FALSE for OR).
                if isinstance(a, exp.Not) and a.this in op_set:
                    a.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                if isinstance(b, exp.Not) and b.this in op_set:
                    b.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                # Another operand (or sub-connector) is a strict subset of this one, so this
                # sub-connector is redundant. Replace it with the identity of the outer connector.
                superset = set(op.flatten())
                if any(any(subset < superset for subset in subops[i]) for i in superset):
                    op.replace(exp.false() if kind == exp.And else exp.true())
                    continue

                # Eliminate
                # X <kind> Y alongside X <kind> NOT Y: both collapse to X.
                for other, complement in pairs[frozenset((a, b))]:
                    op.replace(complement)
                    other.replace(complement)

        return expression
1026
1027    @annotate_types_on_change
1028    @catch(ModuleNotFoundError, UnsupportedUnit)
1029    def simplify_equality(self, expression: exp.Expression) -> exp.Expression:
1030        """
1031        Use the subtraction and addition properties of equality to simplify expressions:
1032
1033            x + 1 = 3 becomes x = 2
1034
1035        There are two binary operations in the above expression: + and =
1036        Here's how we reference all the operands in the code below:
1037
1038            l     r
1039            x + 1 = 3
1040            a   b
1041        """
1042        if isinstance(expression, self.COMPARISONS):
1043            l, r = expression.left, expression.right
1044
1045            if l.__class__ not in self.INVERSE_OPS:
1046                return expression
1047
1048            if r.is_number:
1049                a_predicate = _is_number
1050                b_predicate = _is_number
1051            elif _is_date_literal(r):
1052                a_predicate = _is_date_literal
1053                b_predicate = _is_interval
1054            else:
1055                return expression
1056
1057            if l.__class__ in self.INVERSE_DATE_OPS:
1058                l = t.cast(exp.IntervalOp, l)
1059                a = l.this
1060                b = l.interval()
1061            else:
1062                l = t.cast(exp.Binary, l)
1063                a, b = l.left, l.right
1064
1065            if not a_predicate(a) and b_predicate(b):
1066                pass
1067            elif not a_predicate(b) and b_predicate(a):
1068                a, b = b, a
1069            else:
1070                return expression
1071
1072            return expression.__class__(
1073                this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
1074            )
1075        return expression
1076
1077    @annotate_types_on_change
1078    def simplify_literals(self, expression, root=True):
1079        if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
1080            return self._flat_simplify(expression, self._simplify_binary, root)
1081
1082        if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
1083            return expression.this.this
1084
1085        if type(expression) in self.INVERSE_DATE_OPS:
1086            return (
1087                self._simplify_binary(expression, expression.this, expression.interval())
1088                or expression
1089            )
1090
1091        return expression
1092
1093    def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression:
1094        if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
1095            this = self._simplify_integer_cast(expr.this)
1096        else:
1097            this = expr.this
1098
1099        if isinstance(expr, exp.Cast) and this.is_int:
1100            num = this.to_py()
1101
1102            # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
1103            # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
1104            # engine-dependent
1105            if (
1106                self.TINYINT_MIN <= num <= self.TINYINT_MAX
1107                and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
1108            ) or (
1109                self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
1110                and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
1111            ):
1112                return this
1113
1114        return expr
1115
1116    def _simplify_binary(self, expression, a, b):
1117        if isinstance(expression, self.COMPARISONS):
1118            a = self._simplify_integer_cast(a)
1119            b = self._simplify_integer_cast(b)
1120
1121        if isinstance(expression, exp.Is):
1122            if isinstance(b, exp.Not):
1123                c = b.this
1124                not_ = True
1125            else:
1126                c = b
1127                not_ = False
1128
1129            if is_null(c):
1130                if isinstance(a, exp.Literal):
1131                    return exp.true() if not_ else exp.false()
1132                if is_null(a):
1133                    return exp.false() if not_ else exp.true()
1134        elif isinstance(expression, self.NULL_OK):
1135            return None
1136        elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
1137            return exp.null()
1138
1139        if a.is_number and b.is_number:
1140            num_a = a.to_py()
1141            num_b = b.to_py()
1142
1143            if isinstance(expression, exp.Add):
1144                return exp.Literal.number(num_a + num_b)
1145            if isinstance(expression, exp.Mul):
1146                return exp.Literal.number(num_a * num_b)
1147
1148            # We only simplify Sub, Div if a and b have the same parent because they're not associative
1149            if isinstance(expression, exp.Sub):
1150                return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
1151            if isinstance(expression, exp.Div):
1152                # engines have differing int div behavior so intdiv is not safe
1153                if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
1154                    return None
1155                return exp.Literal.number(num_a / num_b)
1156
1157            boolean = eval_boolean(expression, num_a, num_b)
1158
1159            if boolean:
1160                return boolean
1161        elif a.is_string and b.is_string:
1162            boolean = eval_boolean(expression, a.this, b.this)
1163
1164            if boolean:
1165                return boolean
1166        elif _is_date_literal(a) and isinstance(b, exp.Interval):
1167            date, b = extract_date(a), extract_interval(b)
1168            if date and b:
1169                if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
1170                    return date_literal(date + b, extract_type(a))
1171                if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
1172                    return date_literal(date - b, extract_type(a))
1173        elif isinstance(a, exp.Interval) and _is_date_literal(b):
1174            a, date = extract_interval(a), extract_date(b)
1175            # you cannot subtract a date from an interval
1176            if a and b and isinstance(expression, exp.Add):
1177                return date_literal(a + date, extract_type(b))
1178        elif _is_date_literal(a) and _is_date_literal(b):
1179            if isinstance(expression, exp.Predicate):
1180                a, b = extract_date(a), extract_date(b)
1181                boolean = eval_boolean(expression, a, b)
1182                if boolean:
1183                    return boolean
1184
1185        return None
1186
1187    @annotate_types_on_change
1188    def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
1189        # COALESCE(x) -> x
1190        if (
1191            isinstance(expression, exp.Coalesce)
1192            and (not expression.expressions or _is_nonnull_constant(expression.this))
1193            # COALESCE is also used as a Spark partitioning hint
1194            and not isinstance(expression.parent, exp.Hint)
1195        ):
1196            return expression.this
1197
1198        if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
1199            return expression
1200
1201        if not isinstance(expression, self.COMPARISONS):
1202            return expression
1203
1204        if isinstance(expression.left, exp.Coalesce):
1205            coalesce = expression.left
1206            other = expression.right
1207        elif isinstance(expression.right, exp.Coalesce):
1208            coalesce = expression.right
1209            other = expression.left
1210        else:
1211            return expression
1212
1213        # This transformation is valid for non-constants,
1214        # but it really only does anything if they are both constants.
1215        if not _is_constant(other):
1216            return expression
1217
1218        # Find the first constant arg
1219        for arg_index, arg in enumerate(coalesce.expressions):
1220            if _is_constant(arg):
1221                break
1222        else:
1223            return expression
1224
1225        coalesce.set("expressions", coalesce.expressions[:arg_index])
1226
1227        # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
1228        # since we already remove COALESCE at the top of this function.
1229        coalesce = coalesce if coalesce.expressions else coalesce.this
1230
1231        # This expression is more complex than when we started, but it will get simplified further
1232        return exp.paren(
1233            exp.or_(
1234                exp.and_(
1235                    coalesce.is_(exp.null()).not_(copy=False),
1236                    expression.copy(),
1237                    copy=False,
1238                ),
1239                exp.and_(
1240                    coalesce.is_(exp.null()),
1241                    type(expression)(this=arg.copy(), expression=other.copy()),
1242                    copy=False,
1243                ),
1244                copy=False,
1245            ),
1246            copy=False,
1247        )
1248
1249    @annotate_types_on_change
1250    def simplify_concat(self, expression):
1251        """Reduces all groups that contain string literals by concatenating them."""
1252        if not isinstance(expression, self.CONCATS) or (
1253            # We can't reduce a CONCAT_WS call if we don't statically know the separator
1254            isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
1255        ):
1256            return expression
1257
1258        if isinstance(expression, exp.ConcatWs):
1259            sep_expr, *expressions = expression.expressions
1260            sep = sep_expr.name
1261            concat_type = exp.ConcatWs
1262            args = {}
1263        else:
1264            expressions = expression.expressions
1265            sep = ""
1266            concat_type = exp.Concat
1267            args = {
1268                "safe": expression.args.get("safe"),
1269                "coalesce": expression.args.get("coalesce"),
1270            }
1271
1272        new_args = []
1273        for is_string_group, group in itertools.groupby(
1274            expressions or expression.flatten(), lambda e: e.is_string
1275        ):
1276            if is_string_group:
1277                new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
1278            else:
1279                new_args.extend(group)
1280
1281        if len(new_args) == 1 and new_args[0].is_string:
1282            return new_args[0]
1283
1284        if concat_type is exp.ConcatWs:
1285            new_args = [sep_expr] + new_args
1286        elif isinstance(expression, exp.DPipe):
1287            return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
1288
1289        return concat_type(expressions=new_args, **args)
1290
1291    @annotate_types_on_change
1292    def simplify_conditionals(self, expression):
1293        """Simplifies expressions like IF, CASE if their condition is statically known."""
1294        if isinstance(expression, exp.Case):
1295            this = expression.this
1296            for case in expression.args["ifs"]:
1297                cond = case.this
1298                if this:
1299                    # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
1300                    cond = cond.replace(this.pop().eq(cond))
1301
1302                if always_true(cond):
1303                    return case.args["true"]
1304
1305                if always_false(cond):
1306                    case.pop()
1307                    if not expression.args["ifs"]:
1308                        return expression.args.get("default") or exp.null()
1309        elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
1310            if always_true(expression.this):
1311                return expression.args["true"]
1312            if always_false(expression.this):
1313                return expression.args.get("false") or exp.null()
1314
1315        return expression
1316
1317    @annotate_types_on_change
1318    def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
1319        """
1320        Reduces a prefix check to either TRUE or FALSE if both the string and the
1321        prefix are statically known.
1322
1323        Example:
1324            >>> from sqlglot import parse_one
1325            >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
1326            'TRUE'
1327        """
1328        if (
1329            isinstance(expression, exp.StartsWith)
1330            and expression.this.is_string
1331            and expression.expression.is_string
1332        ):
1333            return exp.convert(expression.name.startswith(expression.expression.name))
1334
1335        return expression
1336
1337    def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool:
1338        return isinstance(left, self.DATETRUNCS) and _is_date_literal(right)
1339
1340    @annotate_types_on_change
1341    @catch(ModuleNotFoundError, UnsupportedUnit)
1342    def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression:
1343        """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
1344        comparison = expression.__class__
1345
1346        if isinstance(expression, self.DATETRUNCS):
1347            this = expression.this
1348            trunc_type = extract_type(this)
1349            date = extract_date(this)
1350            if date and expression.unit:
1351                return date_literal(
1352                    date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type
1353                )
1354        elif comparison not in self.DATETRUNC_COMPARISONS:
1355            return expression
1356
1357        if isinstance(expression, exp.Binary):
1358            l, r = expression.left, expression.right
1359
1360            if not self._is_datetrunc_predicate(l, r):
1361                return expression
1362
1363            l = t.cast(exp.DateTrunc, l)
1364            trunc_arg = l.this
1365            unit = l.unit.name.lower()
1366            date = extract_date(r)
1367
1368            if not date:
1369                return expression
1370
1371            return (
1372                self.DATETRUNC_BINARY_COMPARISONS[comparison](
1373                    trunc_arg, date, unit, self.dialect, extract_type(r)
1374                )
1375                or expression
1376            )
1377
1378        if isinstance(expression, exp.In):
1379            l = expression.this
1380            rs = expression.expressions
1381
1382            if rs and all(self._is_datetrunc_predicate(l, r) for r in rs):
1383                l = t.cast(exp.DateTrunc, l)
1384                unit = l.unit.name.lower()
1385
1386                ranges = []
1387                for r in rs:
1388                    date = extract_date(r)
1389                    if not date:
1390                        return expression
1391                    drange = _datetrunc_range(date, unit, self.dialect)
1392                    if drange:
1393                        ranges.append(drange)
1394
1395                if not ranges:
1396                    return expression
1397
1398                ranges = merge_ranges(ranges)
1399                target_type = extract_type(*rs)
1400
1401                return exp.or_(
1402                    *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges],
1403                    copy=False,
1404                )
1405
1406        return expression
1407
1408    @annotate_types_on_change
1409    def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
1410        if expression.__class__ in self.COMPLEMENT_COMPARISONS:
1411            l, r = expression.this, expression.expression
1412            l_column = isinstance(l, exp.Column)
1413            r_column = isinstance(r, exp.Column)
1414            l_const = _is_constant(l)
1415            r_const = _is_constant(r)
1416
1417            if (
1418                (l_column and not r_column)
1419                or (r_const and not l_const)
1420                or isinstance(r, exp.SubqueryPredicate)
1421            ):
1422                return expression
1423            if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1424                return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1425                    this=r, expression=l
1426                )
1427        return expression
1428
    def _flat_simplify(self, expression, simplifier, root=True):
        """
        Apply a pairwise ``simplifier`` across the flattened operands of ``expression``.

        The operands of the chain of same-type binary nodes are flattened without
        unnesting parentheses and placed in a queue. Each operand ``a`` is tried against
        every remaining operand ``b``. When ``simplifier(expression, a, b)`` returns a
        truthy result that is not ``expression`` itself, ``b`` is removed and the result
        is pushed to the front of the queue so it can be combined again. Operands that
        combine with nothing are kept in order.

        Args:
            expression: the binary expression whose operand chain is simplified.
            simplifier: callable ``(expression, a, b)`` returning a combined expression,
                or a falsy value / ``expression`` when ``a`` and ``b`` cannot be combined.
            root: whether ``expression`` is the root of the current simplification.
                Non-root nodes are only processed at the top of a same-parent chain.

        Returns:
            If any pair was combined, a new left-deep chain of ``expression``'s class
            rebuilt from the remaining operands; otherwise ``expression`` unchanged.
        """
        if root or not expression.same_parent:
            operands = []
            queue = deque(expression.flatten(unnest=False))
            size = len(queue)

            while queue:
                a = queue.popleft()

                for b in queue:
                    result = simplifier(expression, a, b)

                    if result and result is not expression:
                        # Mutating the deque mid-iteration is only safe because we break out
                        # of the loop immediately afterwards.
                        queue.remove(b)
                        # Retry the combined result against the remaining operands first.
                        queue.appendleft(result)
                        break
                else:
                    # `a` could not be combined with any remaining operand.
                    operands.append(a)

            # Only rebuild when at least one pair was combined.
            if len(operands) < size:
                return functools.reduce(
                    lambda a, b: expression.__class__(this=a, expression=b), operands
                )
        return expression
1453
1454
def gen(expression: t.Any, comments: bool = False) -> str:
    """Render ``expression`` as lightweight pseudo-SQL for sorting and deduplication.

    Sorting and deduping sql is a necessary step for optimization. The real SQL
    generator is comparatively expensive, so a bare-minimum, stack-based renderer
    (`Gen`) is used whenever only a stable string key is needed.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.

    Returns:
        The pseudo-SQL string.
    """
    renderer = Gen()
    return renderer.gen(expression, comments=comments)
1466
1467
class Gen:
    """
    Minimal, stack-based pseudo-SQL renderer used to build sort/dedup keys.

    Work items are pushed onto ``stack`` in reverse output order and popped one at a
    time:
    - Expressions are expanded by a ``<key>_sql`` handler when one exists. Otherwise
      functions use `_function`, and other nodes render as ``KEY`` followed by their
      ``:arg value`` pairs.
    - Lists expand into comma-separated items (None entries are skipped).
    - Any other non-None value is stringified into ``sqls``.
    """

    def __init__(self):
        # Pending work items; the top of the stack is emitted next.
        self.stack = []
        # Output fragments, joined once the stack is exhausted.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """
        Render ``expression`` to pseudo-SQL.

        Args:
            expression: the expression to render.
            comments: whether to include the expression's comments.

        Returns:
            The rendered string.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator pushed before the first item. Only do so when something
                # was pushed: for an empty or all-None list this would otherwise discard an
                # unrelated pending item, such as the closing ")" of an enclosing call.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the leading "." (only if something was pushed, see `gen`).
        if parts:
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case like every other operator keyword rendered here.
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the leading "." only if something was pushed; otherwise the alias or the
        # argument list pushed above would be lost.
        if parts:
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed in reverse so the output reads: this <op> expression.
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the node's non-None arguments as a list of ``[":key", value]`` pairs.

        Args:
            node: the expression whose arguments are rendered.
            arg_index: number of leading ``arg_types`` entries to skip (they are
                rendered separately by the caller).

        Returns:
            Whether any argument was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, coalesce_simplification: bool = False, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None):
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Simplify a sqlglot AST by applying the optimizer's rewrite rules.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: the expression to simplify.
        constant_propagation: whether to apply the constant propagation rule.
        coalesce_simplification: whether to apply the COALESCE simplification rule. It tries
            to remove COALESCE calls, which can help certain analyses but can leave the query
            more verbose.
        dialect: dialect used to construct the underlying `Simplifier`.

    Returns:
        sqlglot.Expression: the simplified expression.
    """
    simplifier = Simplifier(dialect=dialect)
    return simplifier.simplify(
        expression,
        constant_propagation=constant_propagation,
        coalesce_simplification=coalesce_simplification,
    )

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:

sqlglot.Expression: simplified expression

class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit is not supported by the date helpers.

    Raised by, for example, `interval` and `date_floor`. Its message has the form
    ``"Unsupported unit: <unit>"``.
    """

Raised when a date/time unit is not supported by the date helpers (for example `interval` and `date_floor`).

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as ``func(expression, *args, **kwargs)``. If it raises
    one of ``exceptions``, the first positional argument (``expression``) is returned
    unchanged instead; any other exception propagates. The wrapper keeps the wrapped
    function's name and docstring.

    NOTE(review): because the *first* positional argument is returned on failure, applying
    this to an instance method would return ``self`` — confirm it is only used on plain
    functions or handled accordingly.
    """

    def decorator(func):
        @wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that makes a simplification function return its input expression unchanged if any of `exceptions` is raised.

def annotate_types_on_change(func):
 93def annotate_types_on_change(func):
 94    @wraps(func)
 95    def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]:
 96        new_expression = func(self, expression, *args, **kwargs)
 97
 98        if new_expression is None:
 99            return new_expression
100
101        if self.annotate_new_expressions and expression != new_expression:
102            self._annotator.clear()
103
104            # We annotate this to ensure new children nodes are also annotated
105            new_expression = self._annotator.annotate(
106                expression=new_expression,
107                annotate_scope=False,
108            )
109
110            # Whatever expression the original expression is transformed into needs to preserve
111            # the original type, otherwise the simplification could result in a different schema
112            new_expression.type = expression.type
113
114        return new_expression
115
116    return _func
def flatten(expression):
def flatten(expression):
    """
    Inline directly nested connectors of the same type into their parent:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Non-connector input is returned unchanged; otherwise the node is modified in place.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for arg in expression.args.values():
        inner = arg.unnest()
        if isinstance(inner, connector_type):
            arg.replace(inner)
    return expression

Flatten nested connectors of the same type: A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_parens( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Remove a redundant `exp.Paren` wrapper when the surrounding context allows it.

    Args:
        expression: the node to inspect; anything other than `exp.Paren` is returned as-is.
        dialect: dialect whose `REQUIRES_PARENTHESIZED_STRUCT_ACCESS` flag is consulted.

    Returns:
        `expression.this` if the parentheses are judged unnecessary, else `expression`.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # A parenthesized SELECT is always kept.
    if isinstance(this, exp.Select):
        return expression

    # Parens directly under a subquery predicate or a bracket are always kept.
    if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    # Dialects that require it keep the parens around the operand of a `.field` / `.*` access.
    if (
        Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
        and isinstance(parent, exp.Dot)
        and (isinstance(parent.right, (exp.Identifier, exp.Star)))
    ):
        return expression

    # Unwrap when any of the following holds:
    #  - the parent is neither a Condition nor a Binary node
    #  - the parent is itself a Paren (doubled parentheses)
    #  - the inner node is not a Binary, unless it is a NOT/IS directly under a predicate
    #  - the inner node is a predicate and the parent is neither a predicate nor a Neg
    #  - an Add directly inside an Add, or a Mul directly inside a Mul
    #  - a Mul directly inside an Add/Sub (multiplication already binds tighter)
    if (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (
            isinstance(this, exp.Predicate)
            and not (parent_is_predicate or isinstance(parent, exp.Neg))
        )
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this

    return expression
def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only applies to an `exp.And` that is the root (``root=True``) or whose parent is not a
    node of the same type (``same_parent``), and only if it is already in DNF. The
    expression is mutated in place and returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the Column node that defines it, literal). Equalities inside IF
        # nodes are not collected; a later `col = literal` for the same column overwrites an
        # earlier one.
        constant_mapping = {}
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            # Replace every other occurrence of a mapped column in this scope with a copy of
            # its literal, skipping the defining occurrence itself and `<column> IS NULL`.
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def always_true(expression):
def always_true(expression):
    """Return a truthy value for a TRUE boolean or a numeric literal that isn't zero."""
    true_boolean = isinstance(expression, exp.Boolean) and expression.this
    if true_boolean:
        return true_boolean
    return isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression)
def always_false(expression):
def always_false(expression):
    """Return True for a FALSE boolean, a NULL, or a literal whose value equals 0."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
def is_zero(expression):
def is_zero(expression):
    """Return True if `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a, b):
def is_complement(a, b):
    """Return True if `b` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Boolean` (no subclass) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison type of `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE boolean literal for EQ/IS, NEQ, GT, GTE, LT and LTE nodes, and
    None for any other node type.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None
    return boolean_literal(outcome)
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce `value` to a `datetime.date`, or return None if it can't be parsed.

    A datetime is truncated to its date, a date is returned as-is, and anything else is
    parsed with `datetime.datetime.fromisoformat` (a ValueError yields None).
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if not isinstance(value, datetime.date):
        try:
            value = datetime.datetime.fromisoformat(value).date()
        except ValueError:
            return None
    return value
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce `value` to a naive `datetime.datetime`, or return None if it can't be parsed.

    A datetime is returned as-is, a date becomes midnight of that day, and anything else is
    parsed with `datetime.datetime.fromisoformat` (a ValueError yields None).
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Interpret `value` as the temporal type `to`.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` for any other temporal
        type, or None when `value` is falsy, can't be parsed, or `to` is not temporal.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a temporal cast of a literal, if possible.

    Supports `exp.Cast` and `exp.TsOrDsToDate` without a `format` (treated as a cast to
    DATE). The operand must be a literal or, recursively, another such cast.

    Returns:
        A `datetime.date` for DATE targets, a `datetime.datetime` for other temporal
        targets, or None if the expression can't be evaluated.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Convert an interval node into a `relativedelta`, or return None if that fails.

    None is returned when the amount can't be converted with ``int()``, the unit is not
    supported, or `dateutil` is not installed.
    """
    try:
        amount = int(expression.this.to_py())
        return interval(expression.text("unit").lower(), amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def extract_type(*expressions):
def extract_type(*expressions):
    """Return the type of the first expression that has one.

    `exp.Cast` nodes contribute their target type (`.to`); other nodes their `.type`. If no
    candidate yields a truthy type, the last (falsy) value inspected is returned, or None
    when called with no expressions.
    """
    found = None
    for candidate in expressions:
        found = candidate.to if isinstance(candidate, exp.Cast) else candidate.type
        if found:
            return found
    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """Build ``CAST('<date>' AS <type>)`` for a Python date or datetime.

    `target_type` is used when it is a temporal type. Otherwise DATETIME is chosen for
    `datetime.datetime` values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_type = target_type
    elif isinstance(date, datetime.datetime):
        cast_type = exp.DataType.Type.DATETIME
    else:
        cast_type = exp.DataType.Type.DATE
    return exp.cast(exp.Literal.string(date), cast_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.
    `dateutil` is imported lazily, so it is only required when this function is called.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword one `unit` spans)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }
    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Truncate `d` to the start of the `unit` period that contains it.

    Supported units: year, quarter, month, week, day. For weeks, ``dialect.WEEK_OFFSET`` is
    the first day of the week as an offset from Monday (0 = Monday, -1 = Sunday, ...). The
    result is always on or before `d` and less than one week before it.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 (Monday) .. 6 (Sunday). Taking the distance modulo 7 keeps it in
        # [0, 6]. Without it, a Sunday with a Sunday-start week (-1) would step back a full
        # week, and a positive offset could land after `d`.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the start of the next `unit` period, unless it already starts one."""
    period_start = date_floor(d, unit, dialect)
    return d if period_start == d else period_start + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return a TRUE literal node if `condition` is truthy, otherwise a FALSE literal node."""
    if condition:
        return exp.true()
    return exp.false()
class Simplifier:
 473class Simplifier:
 474    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
 475        self.dialect = Dialect.get_or_raise(dialect)
 476        self.annotate_new_expressions = annotate_new_expressions
 477
 478        self._annotator: TypeAnnotator = TypeAnnotator(
 479            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
 480        )
 481
    # Value ranges for byte-sized signed/unsigned integers
    TINYINT_MIN = -128
    TINYINT_MAX = 127
    UTINYINT_MIN = 0
    UTINYINT_MAX = 255

    # Logical negation of each comparison, e.g. NOT (a < b) -> a >= b (used by simplify_not)
    COMPLEMENT_COMPARISONS = {
        exp.LT: exp.GTE,
        exp.GT: exp.LTE,
        exp.LTE: exp.GT,
        exp.GTE: exp.LT,
        exp.EQ: exp.NEQ,
        exp.NEQ: exp.EQ,
    }

    # ALL <-> ANY swap applied to a quantified right operand when simplify_not negates a
    # comparison
    COMPLEMENT_SUBQUERY_PREDICATES = {
        exp.All: exp.Any,
        exp.Any: exp.All,
    }

    # Upper-bound and lower-bound comparison groups (used by _simplify_comparison)
    LT_LTE = (exp.LT, exp.LTE)
    GT_GTE = (exp.GT, exp.GTE)

    # Comparison node types that _simplify_comparison will try to merge
    COMPARISONS = (
        *LT_LTE,
        *GT_GTE,
        exp.EQ,
        exp.NEQ,
        exp.Is,
    )

    # Operator obtained by swapping a comparison's operands: a < b <=> b > a
    INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
        exp.LT: exp.GT,
        exp.GT: exp.LT,
        exp.LTE: exp.GTE,
        exp.GTE: exp.LTE,
    }

    # Random-valued functions: operands containing them are never treated as a shared
    # column by _simplify_comparison
    NONDETERMINISTIC = (exp.Rand, exp.Randn)
    AND_OR = (exp.And, exp.Or)

    # Date-arithmetic node -> plain arithmetic operator that undoes it
    # (NOTE(review): consumers are outside this view)
    INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
        exp.DateAdd: exp.Sub,
        exp.DateSub: exp.Add,
        exp.DatetimeAdd: exp.Sub,
        exp.DatetimeSub: exp.Add,
    }

    # INVERSE_DATE_OPS plus ordinary addition/subtraction
    INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
        **INVERSE_DATE_OPS,
        exp.Add: exp.Sub,
        exp.Sub: exp.Add,
    }

    # Null-safe and property equality node types (consumers are outside this view)
    NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)

    # String concatenation node types (presumably used by simplify_concat; not in this view)
    CONCATS = (exp.Concat, exp.DPipe)

    # Per-comparison rewrite functions called as f(l, dt, unit, dialect, target_type). Each
    # returns a comparison of `l` against a date literal built from dt's `unit` boundaries:
    #   LT:  l <  dt if dt is already on a boundary, else the next boundary after dt
    #   GT:  l >= start of the period after floor(dt)
    #   LTE: l <  start of the period after floor(dt)
    #   GTE: l >= date_ceil(dt)
    # EQ/NEQ delegate to _datetrunc_eq/_datetrunc_neq (defined elsewhere). Presumably `l` is
    # the untruncated operand of a DATE_TRUNC comparison (see simplify_datetrunc) - TODO confirm.
    DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
        exp.LT: lambda l, dt, u, d, t: l
        < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
        exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
        exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
        exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
        exp.EQ: _datetrunc_eq,
        exp.NEQ: _datetrunc_neq,
    }

    # Comparison types presumably eligible for DATE_TRUNC rewriting: the ones above plus IN
    DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
    DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)

    # Results of simplify_connectors that need no extra `AND TRUE` to stay boolean-typed
    SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean)

    # CROSS joins result in an empty table if the right table is empty.
    # So we can only simplify certain types of joins to CROSS.
    # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
    JOINS = {
        ("", ""),
        ("", "INNER"),
        ("RIGHT", ""),
        ("RIGHT", "OUTER"),
    }
 564
 565    def simplify(
 566        self,
 567        expression: exp.Expression,
 568        constant_propagation: bool = False,
 569        coalesce_simplification: bool = False,
 570    ):
 571        wheres = []
 572        joins = []
 573
 574        for node in expression.walk(
 575            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
 576        ):
 577            if node.meta.get(FINAL):
 578                continue
 579
 580            # group by expressions cannot be simplified, for example
 581            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 582            # the projection must exactly match the group by key
 583            group = node.args.get("group")
 584
 585            if group and hasattr(node, "selects"):
 586                groups = set(group.expressions)
 587                group.meta[FINAL] = True
 588
 589                for s in node.selects:
 590                    for n in s.walk(FINAL):
 591                        if n in groups:
 592                            s.meta[FINAL] = True
 593                            break
 594
 595                having = node.args.get("having")
 596
 597                if having:
 598                    for n in having.walk():
 599                        if n in groups:
 600                            having.meta[FINAL] = True
 601                            break
 602
 603            if isinstance(node, exp.Condition):
 604                simplified = while_changing(
 605                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
 606                )
 607
 608                if node is expression:
 609                    expression = simplified
 610            elif isinstance(node, exp.Where):
 611                wheres.append(node)
 612            elif isinstance(node, exp.Join):
 613                # snowflake match_conditions have very strict ordering rules
 614                if match := node.args.get("match_condition"):
 615                    match.meta[FINAL] = True
 616
 617                joins.append(node)
 618
 619        for where in wheres:
 620            if always_true(where.this):
 621                where.pop()
 622        for join in joins:
 623            if (
 624                always_true(join.args.get("on"))
 625                and not join.args.get("using")
 626                and not join.args.get("method")
 627                and (join.side, join.kind) in self.JOINS
 628            ):
 629                join.args["on"].pop()
 630                join.set("side", None)
 631                join.set("kind", "CROSS")
 632
 633        return expression
 634
 635    def _simplify(
 636        self, expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool
 637    ):
 638        pre_transformation_stack = [expression]
 639        post_transformation_stack = []
 640
 641        while pre_transformation_stack:
 642            original = pre_transformation_stack.pop()
 643            node = original
 644
 645            if not isinstance(node, SIMPLIFIABLE):
 646                if isinstance(node, exp.Query):
 647                    self.simplify(node, constant_propagation, coalesce_simplification)
 648                continue
 649
 650            parent = node.parent
 651            root = node is expression
 652
 653            node = self.rewrite_between(node)
 654            node = self.uniq_sort(node, root)
 655            node = self.absorb_and_eliminate(node, root)
 656            node = self.simplify_concat(node)
 657            node = self.simplify_conditionals(node)
 658
 659            if constant_propagation:
 660                node = propagate_constants(node, root)
 661
 662            if node is not original:
 663                original.replace(node)
 664
 665            for n in node.iter_expressions(reverse=True):
 666                if n.meta.get(FINAL):
 667                    raise
 668            pre_transformation_stack.extend(
 669                n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 670            )
 671            post_transformation_stack.append((node, parent))
 672
 673        while post_transformation_stack:
 674            original, parent = post_transformation_stack.pop()
 675            root = original is expression
 676
 677            # Resets parent, arg_key, index pointers– this is needed because some of the
 678            # previous transformations mutate the AST, leading to an inconsistent state
 679            for k, v in tuple(original.args.items()):
 680                original.set(k, v)
 681
 682            # Post-order transformations
 683            node = self.simplify_not(original)
 684            node = flatten(node)
 685            node = self.simplify_connectors(node, root)
 686            node = self.remove_complements(node, root)
 687
 688            if coalesce_simplification:
 689                node = self.simplify_coalesce(node)
 690            node.parent = parent
 691
 692            node = self.simplify_literals(node, root)
 693            node = self.simplify_equality(node)
 694            node = simplify_parens(node, dialect=self.dialect)
 695            node = self.simplify_datetrunc(node)
 696            node = self.sort_comparison(node)
 697            node = self.simplify_startswith(node)
 698
 699            if node is not original:
 700                original.replace(node)
 701
 702        return node
 703
 704    @annotate_types_on_change
 705    def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
 706        """Rewrite x between y and z to x >= y AND x <= z.
 707
 708        This is done because comparison simplification is only done on lt/lte/gt/gte.
 709        """
 710        if isinstance(expression, exp.Between):
 711            negate = isinstance(expression.parent, exp.Not)
 712
 713            expression = exp.and_(
 714                exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 715                exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 716                copy=False,
 717            )
 718
 719            if negate:
 720                expression = exp.paren(expression, copy=False)
 721
 722        return expression
 723
    @annotate_types_on_change
    def simplify_not(self, expression: exp.Expression) -> exp.Expression:
        """
        Demorgan's Law
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

        Other rewrites applied to a NOT node:
        NOT NULL -> NULL AND TRUE
        NOT (a <cmp> b) -> a <complement cmp> b (an ANY/ALL right operand is swapped too)
        NOT <always true> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x, if x is BOOLEAN and the dialect marks this as safe

        De Morgan is only applied when the NOT's operand is parenthesized. Anything else is
        returned unchanged.
        """
        if isinstance(expression, exp.Not):
            this = expression.this
            if is_null(this):
                return exp.and_(exp.null(), exp.true(), copy=False)
            if this.__class__ in self.COMPLEMENT_COMPARISONS:
                right = this.expression
                # e.g. NOT (x > ALL (...)) -> x <= ANY (...)
                complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
                    right.__class__
                )
                if complement_subquery_predicate:
                    right = complement_subquery_predicate(this=right.this)

                return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
            if isinstance(this, exp.Paren):
                condition = this.unnest()
                # NOT (a AND b) -> (NOT a OR NOT b)
                if isinstance(condition, exp.And):
                    return exp.paren(
                        exp.or_(
                            exp.not_(condition.left, copy=False),
                            exp.not_(condition.right, copy=False),
                            copy=False,
                        ),
                        copy=False,
                    )
                # NOT (a OR b) -> (NOT a AND NOT b)
                if isinstance(condition, exp.Or):
                    return exp.paren(
                        exp.and_(
                            exp.not_(condition.left, copy=False),
                            exp.not_(condition.right, copy=False),
                            copy=False,
                        ),
                        copy=False,
                    )
                if is_null(condition):
                    return exp.and_(exp.null(), exp.true(), copy=False)
            if always_true(this):
                return exp.false()
            if is_false(this):
                return exp.true()
            if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
                inner = this.this
                if inner.is_type(exp.DataType.Type.BOOLEAN):
                    # double negation
                    # NOT NOT x -> x, if x is BOOLEAN type
                    return inner
        return expression
 777
    @annotate_types_on_change
    def simplify_connectors(self, expression, root=True):
        """Fold TRUE/FALSE/NULL/zero operands out of AND and OR connectors.

        AND: a FALSE or zero operand gives FALSE. NULL together with NULL or an always-true
        operand gives NULL. Always-true operands are dropped. Otherwise the pair is handed to
        `_simplify_comparison`.

        OR: an always-true operand gives TRUE. NULL together with NULL or an always-false
        operand gives NULL. FALSE operands are dropped. Otherwise the pair is handed to
        `_simplify_comparison` with ``or_=True``.

        The pairwise rule is applied across the connector by `_flat_simplify` (defined
        elsewhere in the class). Non-connector input is returned unchanged.
        """

        def _simplify_connectors(expression, left, right):
            # Pairwise rule for one (left, right) operand pair of `expression`. Returns None
            # when no rule applies, including for connectors other than AND/OR (presumably
            # `_flat_simplify` then keeps both operands - TODO confirm).
            if isinstance(expression, exp.And):
                if is_false(left) or is_false(right):
                    return exp.false()
                if is_zero(left) or is_zero(right):
                    return exp.false()
                if (
                    (is_null(left) and is_null(right))
                    or (is_null(left) and always_true(right))
                    or (always_true(left) and is_null(right))
                ):
                    return exp.null()
                if always_true(left) and always_true(right):
                    return exp.true()
                if always_true(left):
                    return right
                if always_true(right):
                    return left
                return self._simplify_comparison(expression, left, right)
            elif isinstance(expression, exp.Or):
                if always_true(left) or always_true(right):
                    return exp.true()
                if (
                    (is_null(left) and is_null(right))
                    or (is_null(left) and always_false(right))
                    or (always_false(left) and is_null(right))
                ):
                    return exp.null()
                if is_false(left):
                    return right
                if is_false(right):
                    return left
                return self._simplify_comparison(expression, left, right, or_=True)

        if isinstance(expression, exp.Connector):
            original_parent = expression.parent
            expression = self._flat_simplify(expression, _simplify_connectors, root)

            # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
            # to ensure that the resulting type is boolean. We know this is true only for connectors,
            # boolean values and columns that are essentially operands to a connector:
            #
            # A AND (((B)))
            #          ~ this is safe to keep because it will eventually be part of another connector
            if not isinstance(
                expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
            ) and not expression.is_type(exp.DataType.Type.BOOLEAN):
                # Climb through enclosing parentheses: if the first non-Paren ancestor is a
                # connector, keep the result as-is; otherwise force it to `<result> AND TRUE`.
                while True:
                    if isinstance(original_parent, exp.Connector):
                        break
                    if not isinstance(original_parent, exp.Paren):
                        expression = expression.and_(exp.true(), copy=False)
                        break

                    original_parent = original_parent.parent

        return expression
 837
    @annotate_types_on_change
    def _simplify_comparison(self, expression, left, right, or_=False):
        """Try to merge two comparisons that constrain the same operand.

        `left` and `right` are operands of `expression`, an AND (or an OR when `or_` is set).
        The merge applies when both compare a shared, deterministic, non-constant operand
        against comparable values (two numbers, two strings, or two statically evaluable
        dates).

        Returns:
            The single equivalent comparison, FALSE for contradictory AND terms, `expression`
            if the non-shared operands can't be isolated, or None when no rule applies.
        """
        if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
            ll, lr = left.args.values()
            rl, rr = right.args.values()

            largs = {ll, lr}
            rargs = {rl, rr}

            # Operands shared by both comparisons that are neither constants nor contain a
            # nondeterministic call, e.g. the `x` in `x > 1 AND x < 5`.
            matching = largs & rargs
            columns = {
                m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
            }

            if matching and columns:
                # The remaining operand of each comparison is the value it is compared to.
                try:
                    l = first(largs - columns)
                    r = first(rargs - columns)
                except StopIteration:
                    return expression

                # Turn both values into comparable Python objects, or give up.
                if l.is_number and r.is_number:
                    l = l.to_py()
                    r = r.to_py()
                elif l.is_string and r.is_string:
                    l = l.name
                    r = r.name
                else:
                    l = extract_date(l)
                    if not l:
                        return None
                    r = extract_date(r)
                    if not r:
                        return None
                    # python won't compare date and datetime, but many engines will upcast
                    l, r = cast_as_datetime(l), cast_as_datetime(r)

                # NOTE(review): each comparison is read as `<shared operand> <op> <value>`,
                # i.e. this assumes the shared operand is on the left (presumably ensured by
                # sort_comparison) - confirm. Also, when av == bv and the operators differ only
                # in strictness (e.g. `x <= 5 AND x < 5`), the choice between `left` and
                # `right` depends on operand order, not on which operator is stricter.
                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                    # Two upper bounds: AND keeps the smaller, OR keeps the larger.
                    if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
                        return left if (av > bv if or_ else av <= bv) else right
                    # Two lower bounds: AND keeps the larger, OR keeps the smaller.
                    if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
                        return left if (av < bv if or_ else av >= bv) else right

                    # we can't ever shortcut to true because the column could be null
                    if not or_:
                        # Disjoint ranges make the conjunction FALSE.
                        if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
                            if av <= bv:
                                return exp.false()
                        elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
                            if av >= bv:
                                return exp.false()
                        # An equality either satisfies the other constraint (keep the
                        # equality) or contradicts it (FALSE).
                        elif isinstance(a, exp.EQ):
                            if isinstance(b, exp.LT):
                                return exp.false() if av >= bv else a
                            if isinstance(b, exp.LTE):
                                return exp.false() if av > bv else a
                            if isinstance(b, exp.GT):
                                return exp.false() if av <= bv else a
                            if isinstance(b, exp.GTE):
                                return exp.false() if av < bv else a
                            if isinstance(b, exp.NEQ):
                                return exp.false() if av == bv else a
        return None
 901
 902    @annotate_types_on_change
 903    def remove_complements(self, expression, root=True):
 904        """
 905        Removing complements.
 906
 907        A AND NOT A -> FALSE (only for non-NULL A)
 908        A OR NOT A -> TRUE (only for non-NULL A)
 909        """
 910        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
 911            ops = set(expression.flatten())
 912            for op in ops:
 913                if isinstance(op, exp.Not) and op.this in ops:
 914                    if expression.meta.get("nonnull") is True:
 915                        return exp.false() if isinstance(expression, exp.And) else exp.true()
 916
 917        return expression
 918
 919    @annotate_types_on_change
 920    def uniq_sort(self, expression, root=True):
 921        """
 922        Uniq and sort a connector.
 923
 924        C AND A AND B AND B -> A AND B AND C
 925        """
 926        if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 927            flattened = tuple(expression.flatten())
 928
 929            if isinstance(expression, exp.Xor):
 930                result_func = exp.xor
 931                # Do not deduplicate XOR as A XOR A != A if A == True
 932                deduped = None
 933                arr = tuple((gen(e), e) for e in flattened)
 934            else:
 935                result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 936                deduped = {gen(e): e for e in flattened}
 937                arr = tuple(deduped.items())
 938
 939            # check if the operands are already sorted, if not sort them
 940            # A AND C AND B -> A AND B AND C
 941            for i, (sql, e) in enumerate(arr[1:]):
 942                if sql < arr[i][0]:
 943                    expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 944                    break
 945            else:
 946                # we didn't have to sort but maybe we need to dedup
 947                if deduped and len(deduped) < len(flattened):
 948                    unique_operand = flattened[0]
 949                    if len(deduped) == 1:
 950                        expression = unique_operand.and_(exp.true(), copy=False)
 951                    else:
 952                        expression = result_func(*deduped.values(), copy=False)
 953
 954        return expression
 955
    @annotate_types_on_change
    def absorb_and_eliminate(self, expression, root=True):
        """
        absorption:
            A AND (A OR B) -> A
            A OR (A AND B) -> A
            A AND (NOT A OR B) -> A AND B
            A OR (NOT A AND B) -> A OR B
        elimination:
            (A AND B) OR (A AND NOT B) -> A
            (A OR B) AND (A OR NOT B) -> A

        Operands are rewritten in place with ``replace`` (substituting TRUE/FALSE identity
        values or the common operand); the same ``expression`` object is returned.
        Only applies to AND/OR nodes, and to nested ones only when ``root`` is True or the
        node's parent is not of the same type.
        """
        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
            # Type of the nested groups we look at: OR groups inside an AND, AND groups
            # inside an OR.
            kind = exp.Or if isinstance(expression, exp.And) else exp.And

            ops = tuple(expression.flatten())

            # Initialize lookup tables:
            # Set of all operands, used to find complements for absorption.
            op_set = set()
            # Sub-operands, used to find subsets for absorption.
            subops = defaultdict(list)
            # Pairs of complements, used for elimination.
            pairs = defaultdict(list)

            # Populate the lookup tables
            for op in ops:
                op_set.add(op)

                if not isinstance(op, kind):
                    # In cases like: A OR (A AND B)
                    # Subop will be: ^
                    subops[op].append({op})
                    continue

                # In cases like: (A AND B) OR (A AND B AND C)
                # Subops will be: ^     ^
                subset = set(op.flatten())
                for i in subset:
                    subops[i].append(subset)

                # Record two-operand groups containing a negation, keyed by the
                # un-negated operand pair, together with the operand they share.
                # NOTE(review): assumes unnest_operands() yields exactly two operands here.
                a, b = op.unnest_operands()
                if isinstance(a, exp.Not):
                    pairs[frozenset((a.this, b))].append((op, b))
                if isinstance(b, exp.Not):
                    pairs[frozenset((a, b.this))].append((op, a))

            for op in ops:
                if not isinstance(op, kind):
                    continue

                a, b = op.unnest_operands()

                # Absorb
                # NOT A inside a group while A is a top-level operand: replace NOT A with the
                # identity of the group (FALSE inside OR, TRUE inside AND), e.g.
                # A AND (NOT A OR B) -> A AND (FALSE OR B).
                # NOTE(review): under three-valued logic (A NULL, B FALSE) the original is NULL
                # and the rewrite is FALSE; equivalent only where NULL and FALSE are treated
                # alike (e.g. filters) — confirm this is intended.
                if isinstance(a, exp.Not) and a.this in op_set:
                    a.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                if isinstance(b, exp.Not) and b.this in op_set:
                    b.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                # If some operand (or another group's operand set) is a strict subset of this
                # group's operands, this group is redundant: replace it with the identity of
                # the outer connector (FALSE in an OR, TRUE in an AND).
                superset = set(op.flatten())
                if any(any(subset < superset for subset in subops[i]) for i in superset):
                    op.replace(exp.false() if kind == exp.And else exp.true())
                    continue

                # Eliminate
                # (A AND B) next to (A AND NOT B): both groups are replaced by the shared A,
                # leaving A OR A (duplicates are not removed here).
                for other, complement in pairs[frozenset((a, b))]:
                    op.replace(complement)
                    other.replace(complement)

        return expression
1027
1028    @annotate_types_on_change
1029    @catch(ModuleNotFoundError, UnsupportedUnit)
1030    def simplify_equality(self, expression: exp.Expression) -> exp.Expression:
1031        """
1032        Use the subtraction and addition properties of equality to simplify expressions:
1033
1034            x + 1 = 3 becomes x = 2
1035
1036        There are two binary operations in the above expression: + and =
1037        Here's how we reference all the operands in the code below:
1038
1039            l     r
1040            x + 1 = 3
1041            a   b
1042        """
1043        if isinstance(expression, self.COMPARISONS):
1044            l, r = expression.left, expression.right
1045
1046            if l.__class__ not in self.INVERSE_OPS:
1047                return expression
1048
1049            if r.is_number:
1050                a_predicate = _is_number
1051                b_predicate = _is_number
1052            elif _is_date_literal(r):
1053                a_predicate = _is_date_literal
1054                b_predicate = _is_interval
1055            else:
1056                return expression
1057
1058            if l.__class__ in self.INVERSE_DATE_OPS:
1059                l = t.cast(exp.IntervalOp, l)
1060                a = l.this
1061                b = l.interval()
1062            else:
1063                l = t.cast(exp.Binary, l)
1064                a, b = l.left, l.right
1065
1066            if not a_predicate(a) and b_predicate(b):
1067                pass
1068            elif not a_predicate(b) and b_predicate(a):
1069                a, b = b, a
1070            else:
1071                return expression
1072
1073            return expression.__class__(
1074                this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
1075            )
1076        return expression
1077
1078    @annotate_types_on_change
1079    def simplify_literals(self, expression, root=True):
1080        if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
1081            return self._flat_simplify(expression, self._simplify_binary, root)
1082
1083        if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
1084            return expression.this.this
1085
1086        if type(expression) in self.INVERSE_DATE_OPS:
1087            return (
1088                self._simplify_binary(expression, expression.this, expression.interval())
1089                or expression
1090            )
1091
1092        return expression
1093
1094    def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression:
1095        if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
1096            this = self._simplify_integer_cast(expr.this)
1097        else:
1098            this = expr.this
1099
1100        if isinstance(expr, exp.Cast) and this.is_int:
1101            num = this.to_py()
1102
1103            # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
1104            # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
1105            # engine-dependent
1106            if (
1107                self.TINYINT_MIN <= num <= self.TINYINT_MAX
1108                and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
1109            ) or (
1110                self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
1111                and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
1112            ):
1113                return this
1114
1115        return expr
1116
1117    def _simplify_binary(self, expression, a, b):
1118        if isinstance(expression, self.COMPARISONS):
1119            a = self._simplify_integer_cast(a)
1120            b = self._simplify_integer_cast(b)
1121
1122        if isinstance(expression, exp.Is):
1123            if isinstance(b, exp.Not):
1124                c = b.this
1125                not_ = True
1126            else:
1127                c = b
1128                not_ = False
1129
1130            if is_null(c):
1131                if isinstance(a, exp.Literal):
1132                    return exp.true() if not_ else exp.false()
1133                if is_null(a):
1134                    return exp.false() if not_ else exp.true()
1135        elif isinstance(expression, self.NULL_OK):
1136            return None
1137        elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
1138            return exp.null()
1139
1140        if a.is_number and b.is_number:
1141            num_a = a.to_py()
1142            num_b = b.to_py()
1143
1144            if isinstance(expression, exp.Add):
1145                return exp.Literal.number(num_a + num_b)
1146            if isinstance(expression, exp.Mul):
1147                return exp.Literal.number(num_a * num_b)
1148
1149            # We only simplify Sub, Div if a and b have the same parent because they're not associative
1150            if isinstance(expression, exp.Sub):
1151                return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
1152            if isinstance(expression, exp.Div):
1153                # engines have differing int div behavior so intdiv is not safe
1154                if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
1155                    return None
1156                return exp.Literal.number(num_a / num_b)
1157
1158            boolean = eval_boolean(expression, num_a, num_b)
1159
1160            if boolean:
1161                return boolean
1162        elif a.is_string and b.is_string:
1163            boolean = eval_boolean(expression, a.this, b.this)
1164
1165            if boolean:
1166                return boolean
1167        elif _is_date_literal(a) and isinstance(b, exp.Interval):
1168            date, b = extract_date(a), extract_interval(b)
1169            if date and b:
1170                if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
1171                    return date_literal(date + b, extract_type(a))
1172                if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
1173                    return date_literal(date - b, extract_type(a))
1174        elif isinstance(a, exp.Interval) and _is_date_literal(b):
1175            a, date = extract_interval(a), extract_date(b)
1176            # you cannot subtract a date from an interval
1177            if a and b and isinstance(expression, exp.Add):
1178                return date_literal(a + date, extract_type(b))
1179        elif _is_date_literal(a) and _is_date_literal(b):
1180            if isinstance(expression, exp.Predicate):
1181                a, b = extract_date(a), extract_date(b)
1182                boolean = eval_boolean(expression, a, b)
1183                if boolean:
1184                    return boolean
1185
1186        return None
1187
1188    @annotate_types_on_change
1189    def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
1190        # COALESCE(x) -> x
1191        if (
1192            isinstance(expression, exp.Coalesce)
1193            and (not expression.expressions or _is_nonnull_constant(expression.this))
1194            # COALESCE is also used as a Spark partitioning hint
1195            and not isinstance(expression.parent, exp.Hint)
1196        ):
1197            return expression.this
1198
1199        if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
1200            return expression
1201
1202        if not isinstance(expression, self.COMPARISONS):
1203            return expression
1204
1205        if isinstance(expression.left, exp.Coalesce):
1206            coalesce = expression.left
1207            other = expression.right
1208        elif isinstance(expression.right, exp.Coalesce):
1209            coalesce = expression.right
1210            other = expression.left
1211        else:
1212            return expression
1213
1214        # This transformation is valid for non-constants,
1215        # but it really only does anything if they are both constants.
1216        if not _is_constant(other):
1217            return expression
1218
1219        # Find the first constant arg
1220        for arg_index, arg in enumerate(coalesce.expressions):
1221            if _is_constant(arg):
1222                break
1223        else:
1224            return expression
1225
1226        coalesce.set("expressions", coalesce.expressions[:arg_index])
1227
1228        # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
1229        # since we already remove COALESCE at the top of this function.
1230        coalesce = coalesce if coalesce.expressions else coalesce.this
1231
1232        # This expression is more complex than when we started, but it will get simplified further
1233        return exp.paren(
1234            exp.or_(
1235                exp.and_(
1236                    coalesce.is_(exp.null()).not_(copy=False),
1237                    expression.copy(),
1238                    copy=False,
1239                ),
1240                exp.and_(
1241                    coalesce.is_(exp.null()),
1242                    type(expression)(this=arg.copy(), expression=other.copy()),
1243                    copy=False,
1244                ),
1245                copy=False,
1246            ),
1247            copy=False,
1248        )
1249
1250    @annotate_types_on_change
1251    def simplify_concat(self, expression):
1252        """Reduces all groups that contain string literals by concatenating them."""
1253        if not isinstance(expression, self.CONCATS) or (
1254            # We can't reduce a CONCAT_WS call if we don't statically know the separator
1255            isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
1256        ):
1257            return expression
1258
1259        if isinstance(expression, exp.ConcatWs):
1260            sep_expr, *expressions = expression.expressions
1261            sep = sep_expr.name
1262            concat_type = exp.ConcatWs
1263            args = {}
1264        else:
1265            expressions = expression.expressions
1266            sep = ""
1267            concat_type = exp.Concat
1268            args = {
1269                "safe": expression.args.get("safe"),
1270                "coalesce": expression.args.get("coalesce"),
1271            }
1272
1273        new_args = []
1274        for is_string_group, group in itertools.groupby(
1275            expressions or expression.flatten(), lambda e: e.is_string
1276        ):
1277            if is_string_group:
1278                new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
1279            else:
1280                new_args.extend(group)
1281
1282        if len(new_args) == 1 and new_args[0].is_string:
1283            return new_args[0]
1284
1285        if concat_type is exp.ConcatWs:
1286            new_args = [sep_expr] + new_args
1287        elif isinstance(expression, exp.DPipe):
1288            return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
1289
1290        return concat_type(expressions=new_args, **args)
1291
1292    @annotate_types_on_change
1293    def simplify_conditionals(self, expression):
1294        """Simplifies expressions like IF, CASE if their condition is statically known."""
1295        if isinstance(expression, exp.Case):
1296            this = expression.this
1297            for case in expression.args["ifs"]:
1298                cond = case.this
1299                if this:
1300                    # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
1301                    cond = cond.replace(this.pop().eq(cond))
1302
1303                if always_true(cond):
1304                    return case.args["true"]
1305
1306                if always_false(cond):
1307                    case.pop()
1308                    if not expression.args["ifs"]:
1309                        return expression.args.get("default") or exp.null()
1310        elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
1311            if always_true(expression.this):
1312                return expression.args["true"]
1313            if always_false(expression.this):
1314                return expression.args.get("false") or exp.null()
1315
1316        return expression
1317
1318    @annotate_types_on_change
1319    def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
1320        """
1321        Reduces a prefix check to either TRUE or FALSE if both the string and the
1322        prefix are statically known.
1323
1324        Example:
1325            >>> from sqlglot import parse_one
1326            >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
1327            'TRUE'
1328        """
1329        if (
1330            isinstance(expression, exp.StartsWith)
1331            and expression.this.is_string
1332            and expression.expression.is_string
1333        ):
1334            return exp.convert(expression.name.startswith(expression.expression.name))
1335
1336        return expression
1337
1338    def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool:
1339        return isinstance(left, self.DATETRUNCS) and _is_date_literal(right)
1340
    @annotate_types_on_change
    @catch(ModuleNotFoundError, UnsupportedUnit)
    def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression:
        """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

        Handles three shapes:
        - a truncation of a date literal, folded to the floored date literal;
        - ``<trunc>(unit, x) <cmp> <date literal>`` for the comparisons in
          ``DATETRUNC_BINARY_COMPARISONS``, delegated to the matching handler;
        - ``<trunc>(unit, x) IN (<date literals>)``, rewritten to an OR of range checks.
        Anything else is returned unchanged.
        """
        comparison = expression.__class__

        if isinstance(expression, self.DATETRUNCS):
            this = expression.this
            trunc_type = extract_type(this)
            date = extract_date(this)
            if date and expression.unit:
                return date_literal(
                    date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type
                )
            # Not foldable: falls through; a truncation is neither Binary nor In, so it
            # reaches the final ``return expression``.
        elif comparison not in self.DATETRUNC_COMPARISONS:
            return expression

        if isinstance(expression, exp.Binary):
            l, r = expression.left, expression.right

            if not self._is_datetrunc_predicate(l, r):
                return expression

            l = t.cast(exp.DateTrunc, l)
            trunc_arg = l.this
            unit = l.unit.name.lower()
            date = extract_date(r)

            if not date:
                return expression

            # Each comparison has its own rewrite; a falsy result means "no rewrite".
            return (
                self.DATETRUNC_BINARY_COMPARISONS[comparison](
                    trunc_arg, date, unit, self.dialect, extract_type(r)
                )
                or expression
            )

        if isinstance(expression, exp.In):
            l = expression.this
            rs = expression.expressions

            if rs and all(self._is_datetrunc_predicate(l, r) for r in rs):
                l = t.cast(exp.DateTrunc, l)
                unit = l.unit.name.lower()

                ranges = []
                for r in rs:
                    date = extract_date(r)
                    if not date:
                        return expression
                    # NOTE(review): a falsy range drops this IN value from the rewrite;
                    # presumably _datetrunc_range returns None only when the value can never
                    # equal a truncated date — confirm.
                    drange = _datetrunc_range(date, unit, self.dialect)
                    if drange:
                        ranges.append(drange)

                if not ranges:
                    return expression

                # Adjacent/overlapping ranges are merged so fewer OR terms are produced
                ranges = merge_ranges(ranges)
                target_type = extract_type(*rs)

                return exp.or_(
                    *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges],
                    copy=False,
                )

        return expression
1408
1409    @annotate_types_on_change
1410    def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
1411        if expression.__class__ in self.COMPLEMENT_COMPARISONS:
1412            l, r = expression.this, expression.expression
1413            l_column = isinstance(l, exp.Column)
1414            r_column = isinstance(r, exp.Column)
1415            l_const = _is_constant(l)
1416            r_const = _is_constant(r)
1417
1418            if (
1419                (l_column and not r_column)
1420                or (r_const and not l_const)
1421                or isinstance(r, exp.SubqueryPredicate)
1422            ):
1423                return expression
1424            if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1425                return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1426                    this=r, expression=l
1427                )
1428        return expression
1429
1430    def _flat_simplify(self, expression, simplifier, root=True):
1431        if root or not expression.same_parent:
1432            operands = []
1433            queue = deque(expression.flatten(unnest=False))
1434            size = len(queue)
1435
1436            while queue:
1437                a = queue.popleft()
1438
1439                for b in queue:
1440                    result = simplifier(expression, a, b)
1441
1442                    if result and result is not expression:
1443                        queue.remove(b)
1444                        queue.appendleft(result)
1445                        break
1446                else:
1447                    operands.append(a)
1448
1449            if len(operands) < size:
1450                return functools.reduce(
1451                    lambda a, b: expression.__class__(this=a, expression=b), operands
1452                )
1453        return expression
Simplifier( dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None, annotate_new_expressions: bool = True)
474    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
475        self.dialect = Dialect.get_or_raise(dialect)
476        self.annotate_new_expressions = annotate_new_expressions
477
478        self._annotator: TypeAnnotator = TypeAnnotator(
479            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
480        )
dialect
annotate_new_expressions
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
COMPLEMENT_SUBQUERY_PREDICATES = {<class 'sqlglot.expressions.All'>: <class 'sqlglot.expressions.Any'>, <class 'sqlglot.expressions.Any'>: <class 'sqlglot.expressions.All'>}
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
AND_OR = (<class 'sqlglot.expressions.And'>, <class 'sqlglot.expressions.Or'>)
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.GT'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.LT'>}
SAFE_CONNECTOR_ELIMINATION_RESULT = (<class 'sqlglot.expressions.Connector'>, <class 'sqlglot.expressions.Boolean'>)
JOINS = {('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', '')}
def simplify( self, expression: sqlglot.expressions.Expression, constant_propagation: bool = False, coalesce_simplification: bool = False):
565    def simplify(
566        self,
567        expression: exp.Expression,
568        constant_propagation: bool = False,
569        coalesce_simplification: bool = False,
570    ):
571        wheres = []
572        joins = []
573
574        for node in expression.walk(
575            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
576        ):
577            if node.meta.get(FINAL):
578                continue
579
580            # group by expressions cannot be simplified, for example
581            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
582            # the projection must exactly match the group by key
583            group = node.args.get("group")
584
585            if group and hasattr(node, "selects"):
586                groups = set(group.expressions)
587                group.meta[FINAL] = True
588
589                for s in node.selects:
590                    for n in s.walk(FINAL):
591                        if n in groups:
592                            s.meta[FINAL] = True
593                            break
594
595                having = node.args.get("having")
596
597                if having:
598                    for n in having.walk():
599                        if n in groups:
600                            having.meta[FINAL] = True
601                            break
602
603            if isinstance(node, exp.Condition):
604                simplified = while_changing(
605                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
606                )
607
608                if node is expression:
609                    expression = simplified
610            elif isinstance(node, exp.Where):
611                wheres.append(node)
612            elif isinstance(node, exp.Join):
613                # snowflake match_conditions have very strict ordering rules
614                if match := node.args.get("match_condition"):
615                    match.meta[FINAL] = True
616
617                joins.append(node)
618
619        for where in wheres:
620            if always_true(where.this):
621                where.pop()
622        for join in joins:
623            if (
624                always_true(join.args.get("on"))
625                and not join.args.get("using")
626                and not join.args.get("method")
627                and (join.side, join.kind) in self.JOINS
628            ):
629                join.args["on"].pop()
630                join.set("side", None)
631                join.set("kind", "CROSS")
632
633        return expression
@annotate_types_on_change
def rewrite_between( self, expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
704    @annotate_types_on_change
705    def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
706        """Rewrite x between y and z to x >= y AND x <= z.
707
708        This is done because comparison simplification is only done on lt/lte/gt/gte.
709        """
710        if isinstance(expression, exp.Between):
711            negate = isinstance(expression.parent, exp.Not)
712
713            expression = exp.and_(
714                exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
715                exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
716                copy=False,
717            )
718
719            if negate:
720                expression = exp.paren(expression, copy=False)
721
722        return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

@annotate_types_on_change
def simplify_not( self, expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
724    @annotate_types_on_change
725    def simplify_not(self, expression: exp.Expression) -> exp.Expression:
726        """
727        Demorgan's Law
728        NOT (x OR y) -> NOT x AND NOT y
729        NOT (x AND y) -> NOT x OR NOT y
730        """
731        if isinstance(expression, exp.Not):
732            this = expression.this
733            if is_null(this):
734                return exp.and_(exp.null(), exp.true(), copy=False)
735            if this.__class__ in self.COMPLEMENT_COMPARISONS:
736                right = this.expression
737                complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
738                    right.__class__
739                )
740                if complement_subquery_predicate:
741                    right = complement_subquery_predicate(this=right.this)
742
743                return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
744            if isinstance(this, exp.Paren):
745                condition = this.unnest()
746                if isinstance(condition, exp.And):
747                    return exp.paren(
748                        exp.or_(
749                            exp.not_(condition.left, copy=False),
750                            exp.not_(condition.right, copy=False),
751                            copy=False,
752                        ),
753                        copy=False,
754                    )
755                if isinstance(condition, exp.Or):
756                    return exp.paren(
757                        exp.and_(
758                            exp.not_(condition.left, copy=False),
759                            exp.not_(condition.right, copy=False),
760                            copy=False,
761                        ),
762                        copy=False,
763                    )
764                if is_null(condition):
765                    return exp.and_(exp.null(), exp.true(), copy=False)
766            if always_true(this):
767                return exp.false()
768            if is_false(this):
769                return exp.true()
770            if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
771                inner = this.this
772                if inner.is_type(exp.DataType.Type.BOOLEAN):
773                    # double negation
774                    # NOT NOT x -> x, if x is BOOLEAN type
775                    return inner
776        return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

@annotate_types_on_change
def simplify_connectors(self, expression, root=True):
    """Fold AND/OR connectors whose operands are statically TRUE, FALSE, NULL or zero.

    Operand pairs are reduced through ``self._flat_simplify``. Pairs that cannot be folded
    from literal values are handed to ``self._simplify_comparison``. If the connector
    collapses to a non-boolean operand whose nearest non-parenthesis ancestor is not a
    connector, the result is wrapped as ``<operand> AND TRUE`` so it stays boolean.
    """

    def _fold_and(node, left, right):
        if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
            return exp.false()
        left_true = always_true(left)
        right_true = always_true(right)
        if (is_null(left) and (is_null(right) or right_true)) or (left_true and is_null(right)):
            return exp.null()
        if left_true and right_true:
            return exp.true()
        if left_true:
            return right
        if right_true:
            return left
        return self._simplify_comparison(node, left, right)

    def _fold_or(node, left, right):
        if always_true(left) or always_true(right):
            return exp.true()
        if (is_null(left) and (is_null(right) or always_false(right))) or (
            always_false(left) and is_null(right)
        ):
            return exp.null()
        if is_false(left):
            return right
        if is_false(right):
            return left
        return self._simplify_comparison(node, left, right, or_=True)

    def _fold(node, left, right):
        if isinstance(node, exp.And):
            return _fold_and(node, left, right)
        if isinstance(node, exp.Or):
            return _fold_or(node, left, right)
        # Other connectors (e.g. XOR) are not folded here.
        return None

    if not isinstance(expression, exp.Connector):
        return expression

    parent = expression.parent
    expression = self._flat_simplify(expression, _fold, root)

    # Connectors, booleans and other "safe" results need no extra wrapping:
    #
    # A AND (((B)))
    #          ~ this is safe to keep because it will eventually be part of another connector
    if isinstance(expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT) or expression.is_type(
        exp.DataType.Type.BOOLEAN
    ):
        return expression

    # Look through any parentheses to find the real enclosing context.
    while isinstance(parent, exp.Paren):
        parent = parent.parent

    if not isinstance(parent, exp.Connector):
        expression = expression.and_(exp.true(), copy=False)

    return expression
@annotate_types_on_change
def remove_complements(self, expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE (only for non-NULL A)
    A OR NOT A -> TRUE (only for non-NULL A)

    The rewrite only fires when the connector is marked ``meta["nonnull"] is True``.
    """
    if not isinstance(expression, self.AND_OR) or (not root and expression.same_parent):
        return expression

    # The nullability flag does not depend on the operands, so check it up front.
    if expression.meta.get("nonnull") is not True:
        return expression

    operands = set(expression.flatten())
    if any(isinstance(o, exp.Not) and o.this in operands for o in operands):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Remove complements.

A AND NOT A -> FALSE (only for non-NULL A); A OR NOT A -> TRUE (only for non-NULL A)

@annotate_types_on_change
def uniq_sort(self, expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by their pseudo-SQL (see ``gen``). AND/OR operands are also
    deduplicated; XOR operands are not, because A XOR A != A.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            # Keyed by pseudo-SQL: a later duplicate overwrites the stored expression but keeps
            # the first occurrence's position.
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, _) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                # Sort on the SQL key only. Sorting the (sql, expression) tuples directly
                # would fall back to comparing Expression objects whenever two operands
                # render identically (possible for XOR, which keeps duplicates).
                ordered = sorted(arr, key=lambda pair: pair[0])
                expression = result_func(*(e for _, e in ordered), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                unique_operand = flattened[0]
                if len(deduped) == 1:
                    # Every operand was identical. Wrap with AND TRUE rather than returning
                    # the bare operand (presumably to keep the result boolean-typed).
                    expression = unique_operand.and_(exp.true(), copy=False)
                else:
                    expression = result_func(*deduped.values(), copy=False)

    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

@annotate_types_on_change
def absorb_and_eliminate(self, expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, self.AND_OR) or (not root and expression.same_parent):
        return expression

    # Operands of the outer connector that are themselves the "other" connector are the
    # interesting ones: OR-terms inside an AND, or AND-terms inside an OR.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And
    # Replacement for a negated operand that is absorbed (identity of the inner connector).
    absorbed_value = exp.true if inner_kind == exp.And else exp.false
    # Replacement for a whole inner term that is absorbed.
    dropped_value = exp.false if inner_kind == exp.And else exp.true

    operands = tuple(expression.flatten())

    # Every operand, used to find complements for absorption.
    seen = set()
    # Maps an expression to the operand sets that contain it, used to find subsets.
    containing_sets = defaultdict(list)
    # Maps {X, Y} to (term, survivor) pairs where the term is X op NOT Y (or NOT X op Y).
    complement_pairs = defaultdict(list)

    for operand in operands:
        seen.add(operand)

        if not isinstance(operand, inner_kind):
            # In cases like: A OR (A AND B), the plain A is its own singleton set.
            containing_sets[operand].append({operand})
            continue

        # In cases like: (A AND B) OR (A AND B AND C), each member maps to its term.
        members = set(operand.flatten())
        for member in members:
            containing_sets[member].append(members)

        first, second = operand.unnest_operands()
        for maybe_not, survivor in ((first, second), (second, first)):
            if isinstance(maybe_not, exp.Not):
                complement_pairs[frozenset((maybe_not.this, survivor))].append(
                    (operand, survivor)
                )

    for operand in operands:
        if not isinstance(operand, inner_kind):
            continue

        first, second = operand.unnest_operands()

        # Absorb
        negated = next(
            (x for x in (first, second) if isinstance(x, exp.Not) and x.this in seen), None
        )
        if negated is not None:
            negated.replace(absorbed_value())
            continue

        members = set(operand.flatten())
        if any(smaller < members for m in members for smaller in containing_sets[m]):
            operand.replace(dropped_value())
            continue

        # Eliminate
        for partner, survivor in complement_pairs[frozenset((first, second))]:
            operand.replace(survivor)
            partner.replace(survivor)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_equality(expression, *args, **kwargs):
82        def wrapped(expression, *args, **kwargs):
83            try:
84                return func(expression, *args, **kwargs)
85            except exceptions:
86                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b

That is, l is the left operand of = (x + 1) and r is its right operand (3); a and b are the operands of + (x and 1).
@annotate_types_on_change
def simplify_literals(self, expression, root=True):
    """Fold literal arithmetic and related unary/date patterns.

    - Non-connector binary nodes are reduced pairwise via ``self._simplify_binary``.
    - A double negation ``-(-x)`` collapses to ``x``.
    - Node types listed in ``self.INVERSE_DATE_OPS`` are offered to ``_simplify_binary``
      with their interval; the node is kept unchanged if that yields nothing.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return self._flat_simplify(expression, self._simplify_binary, root)

    if isinstance(expression, exp.Neg):
        inner = expression.this
        if isinstance(inner, exp.Neg):
            return inner.this

    if type(expression) in self.INVERSE_DATE_OPS:
        folded = self._simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
    """Remove or expand COALESCE calls.

    - ``COALESCE(x)`` and ``COALESCE(<non-null constant>, ...)`` become their first
      argument, except when the COALESCE is a (Spark) partitioning hint.
    - A comparison between a COALESCE and a constant is expanded into
      ``(c IS NOT NULL AND <cmp>) OR (c IS NULL AND <first constant arg> <op> <other>)``.
      Here ``c`` is the COALESCE truncated before its first constant argument. This is
      skipped for dialects with ``COALESCE_COMPARISON_NON_STANDARD``.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD or not isinstance(
        expression, self.COMPARISONS
    ):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Locate the first constant argument among the trailing COALESCE args.
    found = next(
        (
            (position, candidate)
            for position, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if found is None:
        return expression
    arg_index, arg = found

    # Truncate the COALESCE in place, just before its first constant argument.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False), copy=False)
@annotate_types_on_change
def simplify_concat(self, expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals are merged (for CONCAT_WS, joined with the literal
    separator). If a single string literal remains, it replaces the whole call.
    CONCAT_WS calls whose separator is not a string literal are left untouched.
    """
    if not isinstance(expression, self.CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    # Binary concatenations such as DPipe (a || b) have no `expressions` list, so their
    # operands come from flattening the chain. CONCAT/CONCAT_WS must always use their
    # argument list, even when it is empty. Flattening CONCAT_WS(',') would yield the
    # separator literal itself as a value and wrongly fold the call to ','.
    if isinstance(expression, exp.Concat) or expressions:
        operands = expressions
    else:
        operands = expression.flatten()

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)

Reduces each run of adjacent string-literal arguments by concatenating them into a single literal.

@annotate_types_on_change
def simplify_conditionals(self, expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    - A CASE branch whose condition is always true returns that branch's value.
    - Branches whose condition is always false are dropped. If none remain, the ELSE
      value (or NULL) is returned.
    - ``CASE x WHEN v ...`` is rewritten to ``CASE WHEN x = v ...`` while scanning.
    - A standalone IF with a statically known condition returns the chosen branch.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot. `case.pop()` below detaches a branch from its parent,
        # which removes it from the `ifs` list being walked. Mutating that list during
        # iteration would silently skip the branch that follows each removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    subject, prefix = expression.this, expression.expression
    if not (subject.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def simplify_datetrunc(expression, *args, **kwargs):
82        def wrapped(expression, *args, **kwargs):
83            try:
84                return func(expression, *args, **kwargs)
85            except exceptions:
86                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

@annotate_types_on_change
def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
    """Put comparisons into a canonical operand order.

    Columns are preferred on the left and constants on the right. Subquery predicates on
    the right are left alone. Otherwise, operands are ordered by their pseudo-SQL. When
    operands are swapped, the comparison class is replaced by its inverse from
    ``self.INVERSE_COMPARISONS`` (falling back to the same class).
    """
    comparison_cls = expression.__class__
    if comparison_cls not in self.COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_col = isinstance(left, exp.Column)
    right_is_col = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_canonical = (
        (left_is_col and not right_is_col)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_canonical:
        return expression

    should_swap = (
        (right_is_col and not left_is_col)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not should_swap:
        return expression

    flipped_cls = self.INVERSE_COMPARISONS.get(comparison_cls, comparison_cls)
    return flipped_cls(this=right, expression=left)
def gen(expression: t.Any, comments: bool = False) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    # A fresh Gen per call: it keeps per-run state in its stack and output buffer.
    generator = Gen()
    return generator.gen(expression, comments=comments)

Simple pseudo-SQL generator for quickly producing sortable, unique strings.

Sorting and deduplicating SQL is a necessary step for optimization. Calling the real generator is expensive, so this module provides a bare-minimum SQL generator instead.

Arguments:
  • expression: the expression to convert into a SQL string.
  • comments: whether to include the expression's comments.
class Gen:
    """Bare-minimum, stack-based pseudo-SQL generator.

    Produces cheap, deterministic strings used to sort and deduplicate expressions. The
    output is not meant to be valid SQL.

    Handlers are looked up by name as ``<expression key>_sql``. Each handler pushes its
    pieces onto ``self.stack``; because the stack is LIFO, pieces are pushed in reverse
    output order.
    """

    def __init__(self):
        # Pending work items: Expressions, lists of items, or literal text fragments.
        self.stack = []
        # Text fragments emitted so far; joined by `gen` once the stack is drained.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Render ``expression`` to pseudo-SQL, optionally including its comments."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, so it is emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator pushed after the first element. Only do this if
                # something was pushed: a list made only of None values must not pop an
                # unrelated item (e.g. a closing parenthesis) off the stack.
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        self._dotted(e.parts)

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Uppercase like every other keyword operator emitted here (e.g. " ILIKE ").
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self._dotted(e.parts)

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push ``:key,value`` pairs for the node's non-None args, skipping the first
        ``arg_index`` declared arg types. Returns whether anything was pushed."""
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False

    def _dotted(self, parts: t.List[exp.Expression]) -> None:
        """Push ``parts`` joined by '.', first part on top of the stack."""
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the separator pushed after the first part. Do nothing for empty parts,
        # so an unrelated stack item is never popped.
        if parts:
            self.stack.pop()
stack
sqls
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """Render ``expression`` to pseudo-SQL, optionally including its comments."""
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            if comments and node.comments:
                # Pushed before the node's own pieces, so it is emitted after them.
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the separator pushed after the first element, but only if something was
            # pushed: a list of only None values must not pop an unrelated stack item.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def add_sql(self, e: exp.Add) -> None:
    """Emit ``this + expression`` (pieces pushed in reverse because the stack is LIFO)."""
    self.stack.extend((e.expression, " + ", e.this))
def alias_sql(self, e: exp.Alias) -> None:
    """Emit ``this AS alias`` (pieces pushed in reverse, LIFO order)."""
    args = e.args
    target, alias = args.get("this"), args.get("alias")
    self.stack.extend((alias, " AS ", target))
def and_sql(self, e: exp.And) -> None:
    """Emit ``this AND expression`` (pieces pushed in reverse, LIFO order)."""
    self.stack.extend((e.expression, " AND ", e.this))
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Emit ``NAME(args)``; unquoted names are upper-cased, quoted identifiers keep case.

    Raises:
        ValueError: if ``e.this`` is neither a str nor an Identifier.
    """
    name_node = e.this
    if isinstance(name_node, exp.Identifier):
        raw = name_node.this
        name = f'"{raw}"' if name_node.quoted else raw.upper()
    elif isinstance(name_node, str):
        name = name_node.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{name_node.__class__.__name__}'."
        )

    pieces = (")", e.expressions, "(", name)
    self.stack.extend(pieces)
def between_sql(self, e: exp.Between) -> None:
    """Emit ``this BETWEEN low AND high`` (pieces pushed in reverse, LIFO order)."""
    low, high = e.args.get("low"), e.args.get("high")
    self.stack.extend((high, " AND ", low, " BETWEEN ", e.this))
def boolean_sql(self, e: exp.Boolean) -> None:
    """Emit TRUE or FALSE according to the truthiness of ``e.this``."""
    if e.this:
        self.stack.append("TRUE")
    else:
        self.stack.append("FALSE")
def bracket_sql(self, e: exp.Bracket) -> None:
    """Emit ``this[expressions]`` (pieces pushed in reverse, LIFO order)."""
    subject, index_list = e.this, e.expressions
    self.stack.extend(("]", index_list, "[", subject))
def column_sql(self, e: exp.Column) -> None:
    """Emit the column's parts joined by '.', first part on top of the stack."""
    parts = e.parts
    for p in reversed(parts):
        self.stack.extend((p, "."))
    # Drop the separator pushed after the first part. With no parts, nothing was pushed,
    # so popping would remove an unrelated stack item.
    if parts:
        self.stack.pop()
def datatype_sql(self, e: exp.DataType) -> None:
    """Emit ``<TYPE NAME> `` followed by the type's args after the first one."""
    # Args are pushed first so they come out after the type name (LIFO order).
    self._args(e, 1)
    type_name = e.this.name
    self.stack.append(f"{type_name} ")
def div_sql(self, e: exp.Div) -> None:
    """Emit ``this / expression`` (pieces pushed in reverse, LIFO order)."""
    self.stack.extend((e.expression, " / ", e.this))
def dot_sql(self, e: exp.Dot) -> None:
    """Emit ``this.expression`` (pieces pushed in reverse, LIFO order)."""
    self.stack.extend((e.expression, ".", e.this))
def eq_sql(self, e: exp.EQ) -> None:
    """Emit ``this = expression`` (pieces pushed in reverse, LIFO order)."""
    self.stack.extend((e.expression, " = ", e.this))
def from_sql(self, e: exp.From) -> None:
    """Emit ``FROM this`` (pieces pushed in reverse, LIFO order)."""
    source = e.this
    self.stack.extend((source, "FROM "))
def gt_sql(self, e: exp.GT) -> None:
    """Render a greater-than by delegating to ``_binary`` with ``" > "``."""
    operator = " > "
    self._binary(e, operator)
def gte_sql(self, e: exp.GTE) -> None:
    """Render a greater-or-equal by delegating to ``_binary`` with ``" >= "``."""
    operator = " >= "
    self._binary(e, operator)
def identifier_sql(self, e: exp.Identifier) -> None:
    """Push an identifier's name, wrapped in double quotes when ``e.quoted``.

    Embedded double quotes in a quoted name are doubled (standard SQL
    delimited-identifier escaping). Without this, the identifier ``a" = "b``
    renders as ``"a" = "b"``, which is malformed and indistinguishable from
    an equality between two quoted columns.
    """
    name = e.this
    if e.quoted:
        name = '"' + name.replace('"', '""') + '"'
    self.stack.append(name)
def ilike_sql(self, e: exp.ILike) -> None:
    """Render a case-insensitive LIKE by delegating to ``_binary`` with ``" ILIKE "``."""
    operator = " ILIKE "
    self._binary(e, operator)
def in_sql(self, e: exp.In) -> None:
    """Queue ``this IN (<args>)`` onto the render stack.

    The closing paren goes on first, then every argument from index 1 on
    (via ``_args``), and finally the opening paren, keyword and subject, so
    the stack holds the tokens in reverse reading order.
    """
    closing, opening = ")", "("
    self.stack.append(closing)
    self._args(e, 1)
    subject = e.this
    self.stack.extend([opening, " IN ", subject])
def intdiv_sql(self, e: exp.IntDiv) -> None:
    """Render an integer division by delegating to ``_binary`` with ``" DIV "``."""
    operator = " DIV "
    self._binary(e, operator)
def is_sql(self, e: exp.Is) -> None:
    """Render an ``IS`` comparison by delegating to ``_binary`` with ``" IS "``."""
    operator = " IS "
    self._binary(e, operator)
def like_sql(self, e: exp.Like) -> None:
    """Render a LIKE comparison by delegating to ``_binary`` with ``" LIKE "``.

    The keyword is upper-case, consistent with ``" ILIKE "`` and every other
    keyword this generator emits (it was previously the mixed-case ``" Like "``).
    """
    self._binary(e, " LIKE ")
def literal_sql(self, e: exp.Literal) -> None:
    """Push a literal's text; string literals are single-quoted.

    Embedded single quotes are doubled (standard SQL string escaping).
    Without this, the string literal ``x' = 'y`` renders as ``'x' = 'y'``,
    which is indistinguishable from a comparison of two literals.
    Non-string literals (numbers) are pushed verbatim.
    """
    if e.is_string:
        self.stack.append("'" + e.this.replace("'", "''") + "'")
    else:
        self.stack.append(e.this)
def lt_sql(self, e: exp.LT) -> None:
    """Render a less-than by delegating to ``_binary`` with ``" < "``."""
    operator = " < "
    self._binary(e, operator)
def lte_sql(self, e: exp.LTE) -> None:
    """Render a less-or-equal by delegating to ``_binary`` with ``" <= "``."""
    operator = " <= "
    self._binary(e, operator)
def mod_sql(self, e: exp.Mod) -> None:
    """Render a modulo by delegating to ``_binary`` with ``" % "``."""
    operator = " % "
    self._binary(e, operator)
def mul_sql(self, e: exp.Mul) -> None:
    """Render a multiplication by delegating to ``_binary`` with ``" * "``."""
    operator = " * "
    self._binary(e, operator)
def neg_sql(self, e: exp.Neg) -> None:
    """Render a negation by delegating to ``_unary`` with the ``"-"`` prefix."""
    prefix = "-"
    self._unary(e, prefix)
def neq_sql(self, e: exp.NEQ) -> None:
    """Render an inequality by delegating to ``_binary`` with ``" <> "``."""
    operator = " <> "
    self._binary(e, operator)
def not_sql(self, e: exp.Not) -> None:
    """Render a logical negation by delegating to ``_unary`` with ``"NOT "``."""
    prefix = "NOT "
    self._unary(e, prefix)
def null_sql(self, e: exp.Null) -> None:
    """Push the ``NULL`` keyword; the node itself carries nothing to render."""
    keyword = "NULL"
    self.stack.append(keyword)
def or_sql(self, e: exp.Or) -> None:
    """Render a disjunction by delegating to ``_binary`` with ``" OR "``."""
    operator = " OR "
    self._binary(e, operator)
def paren_sql(self, e: exp.Paren) -> None:
    """Queue ``(<inner>)``: closing paren first, opening paren last."""
    inner = e.this
    self.stack.extend([")", inner, "("])
def sub_sql(self, e: exp.Sub) -> None:
    """Render a subtraction by delegating to ``_binary`` with ``" - "``."""
    operator = " - "
    self._binary(e, operator)
def subquery_sql(self, e: exp.Subquery) -> None:
    """Queue ``(<query>) <alias>`` followed by the subquery's trailing arguments.

    Arguments from index 2 on are pushed first, then the alias (only when
    truthy), then the parenthesised query body in reverse reading order.
    """
    self._args(e, 2)
    alias = e.args.get("alias")
    pending = [alias] if alias else []
    pending += [")", e.this, "("]
    self.stack.extend(pending)
def table_sql(self, e: exp.Table) -> None:
    """Queue a table reference: dotted ``parts``, then alias, then trailing args.

    Arguments from index 4 on are pushed first (via ``_args``), then the
    alias when truthy, then the ``parts`` in reverse with ``"."`` between
    neighbours, so everything sits on the stack in reverse reading order.
    """
    self._args(e, 4)
    alias = e.args.get("alias")
    if alias:
        self.stack.append(alias)
    # Emit separators only *between* parts. A push-then-pop() scheme would
    # drop the alias (or an argument) whenever ``parts`` is empty.
    for index, part in enumerate(reversed(e.parts)):
        if index:
            self.stack.append(".")
        self.stack.append(part)
def tablealias_sql(self, e: exp.TableAlias) -> None:
    """Queue `` AS <name>`` plus an optional ``(<columns>)`` list.

    The column list (when non-empty) is pushed before the alias name so it
    renders after it; tokens sit on the stack in reverse reading order.
    """
    column_list = e.columns
    pending = [")", column_list, "("] if column_list else []
    pending.extend([e.this, " AS "])
    self.stack.extend(pending)
def var_sql(self, e: exp.Var) -> None:
    """Push the variable's raw text (``e.this``) unchanged."""
    raw_text = e.this
    self.stack.append(raw_text)