Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3from datetime import date, datetime, timedelta
   4import logging
   5import functools
   6import itertools
   7import typing as t
   8from collections import deque, defaultdict
   9from functools import reduce, wraps
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.annotate_types import TypeAnnotator
  15from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  16from sqlglot.schema import ensure_schema
  17
  18
  19if t.TYPE_CHECKING:
  20    from dateutil.relativedelta import relativedelta
  21    from sqlglot.dialects.dialect import DialectType
  22    from typing_extensions import TypeIs
  23
  24    DateRange = tuple[date, date]
  25    DateTruncBinaryTransform = t.Callable[
  26        [exp.Expr, date, str, Dialect, exp.DataType], t.Optional[exp.Expr]
  27    ]
  28
  29logger = logging.getLogger("sqlglot")
  30
  31
  32# Final means that an expression should not be simplified
  33FINAL = "final"
  34
  35SIMPLIFIABLE = (
  36    exp.Binary,
  37    exp.Func,
  38    exp.Lambda,
  39    exp.Predicate,
  40    exp.Unary,
  41)
  42
  43
  44def simplify(
  45    expression: exp.Expr,
  46    constant_propagation: bool = False,
  47    coalesce_simplification: bool = False,
  48    dialect: DialectType = None,
  49):
  50    """
  51    Rewrite sqlglot AST to simplify expressions.
  52
  53    Example:
  54        >>> import sqlglot
  55        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  56        >>> simplify(expression).sql()
  57        'TRUE'
  58
  59    Args:
  60        expression: expression to simplify
  61        constant_propagation: whether the constant propagation rule should be used
  62        coalesce_simplification: whether the simplify coalesce rule should be used.
  63            This rule tries to remove coalesce functions, which can be useful in certain analyses but
  64            can leave the query more verbose.
  65    Returns:
  66        sqlglot.Expr: simplified expression
  67    """
  68    return Simplifier(dialect=dialect).simplify(
  69        expression,
  70        constant_propagation=constant_propagation,
  71        coalesce_simplification=coalesce_simplification,
  72    )
  73
  74
  75class UnsupportedUnit(Exception):
  76    pass
  77
  78
  79def catch(*exceptions):
  80    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
  81
  82    def decorator(func):
  83        def wrapped(expression, *args, **kwargs):
  84            try:
  85                return func(expression, *args, **kwargs)
  86            except exceptions:
  87                return expression
  88
  89        return wrapped
  90
  91    return decorator
  92
  93
  94def annotate_types_on_change(func):
  95    @wraps(func)
  96    def _func(self, expression: exp.Expr, *args, **kwargs) -> exp.Expr | None:
  97        new_expression: exp.Expr | None = func(self, expression, *args, **kwargs)
  98
  99        if new_expression is None:
 100            return new_expression
 101
 102        if self.annotate_new_expressions and expression != new_expression:
 103            self._annotator.clear()
 104
 105            # We annotate this to ensure new children nodes are also annotated
 106            new_expression = self._annotator.annotate(
 107                expression=new_expression,
 108                annotate_scope=False,
 109            )
 110
 111            # Whatever expression the original expression is transformed into needs to preserve
 112            # the original type, otherwise the simplification could result in a different schema
 113            new_expression.type = expression.type
 114
 115        return new_expression
 116
 117    return _func
 118
 119
 120def flatten(expression: exp.Expr) -> exp.Expr:
 121    """
 122    A AND (B AND C) -> A AND B AND C
 123    A OR (B OR C) -> A OR B OR C
 124    """
 125    if isinstance(expression, exp.Connector):
 126        for node in expression.args.values():
 127            child = node.unnest()
 128            if isinstance(child, expression.__class__):
 129                node.replace(child)
 130    return expression
 131
 132
def simplify_parens(expression: exp.Expr, dialect: DialectType) -> exp.Expr:
    """
    Drop a Paren wrapper when removing it cannot change how the expression parses.

    Args:
        expression: any node; only `exp.Paren` nodes are considered.
        dialect: used to check `REQUIRES_PARENTHESIZED_STRUCT_ACCESS`.

    Returns:
        The inner expression if the parentheses are redundant, otherwise `expression`.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parenthesized SELECTs keep their parentheses
    if isinstance(this, exp.Select):
        return expression

    # Parens directly under a subquery predicate or a bracket access are kept
    if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    # Some dialects require (expr).field / (expr).* for struct access
    if (
        Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
        and isinstance(parent, exp.Dot)
        and (isinstance(parent.right, (exp.Identifier, exp.Star)))
    ):
        return expression

    # A parenthesized predicate can be unwrapped unless it is an operand of another
    # predicate, of a negation, or of a non-connector binary operator
    if isinstance(this, exp.Predicate) and (
        not (
            parent_is_predicate
            or isinstance(parent, exp.Neg)
            or (isinstance(parent, exp.Binary) and not isinstance(parent, exp.Connector))
        )
    ):
        return this

    # Otherwise unwrap when grouping is irrelevant:
    # - the parent is not an operator at all, or is itself a Paren;
    # - the inner node is not a binary operator (NOT/IS under a predicate keep their parens);
    # - associativity/precedence makes it safe: a + (b + c), a * (b * c), a +/- (b * c)
    if (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this

    return expression
 177
 178
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Args:
        expression: candidate node; only an `exp.And` that is normalized to DNF is rewritten.
        root: whether `expression` is the root of the condition being simplified. A non-root
            And is only processed when `expression.same_parent` is falsy (presumably: its
            parent is not another And, so each conjunction is handled once at its top).

    Returns:
        `expression`; matching columns are replaced in place.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node in its defining `column = literal`, literal).
        # If a column is equated to several literals, the last one walked wins.
        constant_mapping = {}
        # IF(...) children are pruned, presumably because equalities inside them only hold
        # conditionally
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Replace every equal column except the defining occurrence itself, and leave
                # `column IS NULL` checks untouched
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 216
 217
 218def _is_number(expression: exp.Expr) -> bool:
 219    return expression.is_number
 220
 221
 222def _is_interval(expression: exp.Expr) -> bool:
 223    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 224
 225
 226def _is_nonnull_constant(expression: exp.Expr) -> bool:
 227    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 228
 229
 230def _is_constant(expression: exp.Expr) -> bool:
 231    expr = expression.this if isinstance(expression, exp.Neg) else expression
 232    return isinstance(expr, exp.CONSTANTS) or _is_date_literal(expr)
 233
 234
 235def _datetrunc_range(date: date, unit: str, dialect: Dialect) -> DateRange | None:
 236    """
 237    Get the date range for a DATE_TRUNC equality comparison:
 238
 239    Example:
 240        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 241    Returns:
 242        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 243    """
 244    floor = date_floor(date, unit, dialect)
 245
 246    if date != floor:
 247        # This will always be False, except for NULL values.
 248        return None
 249
 250    return floor, floor + interval(unit)
 251
 252
 253def _datetrunc_eq_expression(
 254    left: exp.Expr, drange: DateRange, target_type: exp.DataType | None
 255) -> exp.Expr:
 256    """Get the logical expression for a date range"""
 257    return exp.and_(
 258        left >= date_literal(drange[0], target_type),
 259        left < date_literal(drange[1], target_type),
 260        copy=False,
 261    )
 262
 263
 264def _datetrunc_eq(
 265    left: exp.Expr,
 266    date: date,
 267    unit: str,
 268    dialect: Dialect,
 269    target_type: exp.DataType | None,
 270) -> exp.Expr | None:
 271    drange = _datetrunc_range(date, unit, dialect)
 272    if not drange:
 273        return None
 274
 275    return _datetrunc_eq_expression(left, drange, target_type)
 276
 277
 278def _datetrunc_neq(
 279    left: exp.Expr,
 280    date: date,
 281    unit: str,
 282    dialect: Dialect,
 283    target_type: exp.DataType | None,
 284) -> exp.Expr | None:
 285    drange = _datetrunc_range(date, unit, dialect)
 286    if not drange:
 287        return None
 288
 289    return exp.and_(
 290        left < date_literal(drange[0], target_type),
 291        left >= date_literal(drange[1], target_type),
 292        copy=False,
 293    )
 294
 295
 296def always_true(expression: object) -> bool:
 297    return (isinstance(expression, exp.Boolean) and expression.this) or (
 298        isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression)
 299    )
 300
 301
 302def always_false(expression: object) -> bool:
 303    return is_false(expression) or is_null(expression) or is_zero(expression)
 304
 305
 306def is_zero(expression: object) -> bool:
 307    return isinstance(expression, exp.Literal) and expression.to_py() == 0
 308
 309
 310def is_complement(a: object, b: object) -> bool:
 311    return isinstance(b, exp.Not) and b.this == a
 312
 313
 314def is_false(a: object) -> bool:
 315    return type(a) is exp.Boolean and not a.this
 316
 317
 318def is_null(a: object) -> bool:
 319    return type(a) is exp.Null
 320
 321
 322class SupportsComparison(t.Protocol):
 323    """Protocol for expressions or values that can be compared using <, <=, >, >=."""
 324
 325    def __lt__(self, other: t.Any) -> bool: ...
 326    def __le__(self, other: t.Any) -> bool: ...
 327    def __gt__(self, other: t.Any) -> bool: ...
 328    def __ge__(self, other: t.Any) -> bool: ...
 329
 330
 331def eval_boolean(
 332    expression: object, a: SupportsComparison, b: SupportsComparison
 333) -> exp.Boolean | None:
 334    if isinstance(expression, (exp.EQ, exp.Is)):
 335        return boolean_literal(a == b)
 336    if isinstance(expression, exp.NEQ):
 337        return boolean_literal(a != b)
 338    if isinstance(expression, exp.GT):
 339        return boolean_literal(a > b)
 340    if isinstance(expression, exp.GTE):
 341        return boolean_literal(a >= b)
 342    if isinstance(expression, exp.LT):
 343        return boolean_literal(a < b)
 344    if isinstance(expression, exp.LTE):
 345        return boolean_literal(a <= b)
 346    return None
 347
 348
 349def cast_as_date(value: datetime | date | str) -> date | None:
 350    if isinstance(value, datetime):
 351        return value.date()
 352    if isinstance(value, date):
 353        return value
 354    try:
 355        return datetime.fromisoformat(value).date()
 356    except ValueError:
 357        return None
 358
 359
 360def cast_as_datetime(
 361    value: datetime | date | str,
 362) -> datetime | None:
 363    if isinstance(value, datetime):
 364        return value
 365    if isinstance(value, date):
 366        return datetime(year=value.year, month=value.month, day=value.day)
 367    try:
 368        return datetime.fromisoformat(value)
 369    except ValueError:
 370        return None
 371
 372
 373def cast_value(value: datetime | date | str, to: exp.DataType) -> date | date | None:
 374    if not value:
 375        return None
 376    if to.is_type(exp.DType.DATE):
 377        return cast_as_date(value)
 378    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 379        return cast_as_datetime(value)
 380    return None
 381
 382
 383def extract_date(cast: exp.Expr) -> date | date | None:
 384    if isinstance(cast, exp.Cast):
 385        to = cast.to
 386    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
 387        to = exp.DType.DATE.into_expr()
 388    else:
 389        return None
 390
 391    if isinstance(cast.this, exp.Literal):
 392        value: t.Any = cast.this.name
 393    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 394        value = extract_date(cast.this)
 395    else:
 396        return None
 397    return cast_value(value, to)
 398
 399
 400def _is_date_literal(expression: exp.Expr) -> bool:
 401    return extract_date(expression) is not None
 402
 403
 404def extract_interval(expression: exp.Expr) -> relativedelta | None:
 405    try:
 406        n = int(expression.this.to_py())
 407        unit = expression.text("unit").lower()
 408        return interval(unit, n)
 409    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
 410        return None
 411
 412
 413def extract_type(*expressions: exp.Expr):
 414    target_type = None
 415    for expression in expressions:
 416        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
 417        if target_type:
 418            break
 419
 420    return target_type
 421
 422
 423def date_literal(date: object, target_type=None) -> exp.Expr:
 424    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
 425        target_type = exp.DType.DATETIME if isinstance(date, datetime) else exp.DType.DATE
 426
 427    return exp.cast(exp.Literal.string(date), target_type)
 428
 429
 430def interval(unit: str, n: int = 1) -> relativedelta:
 431    from dateutil.relativedelta import relativedelta
 432
 433    if unit == "year":
 434        return relativedelta(years=1 * n)
 435    if unit == "quarter":
 436        return relativedelta(months=3 * n)
 437    if unit == "month":
 438        return relativedelta(months=1 * n)
 439    if unit == "week":
 440        return relativedelta(weeks=1 * n)
 441    if unit == "day":
 442        return relativedelta(days=1 * n)
 443    if unit == "hour":
 444        return relativedelta(hours=1 * n)
 445    if unit == "minute":
 446        return relativedelta(minutes=1 * n)
 447    if unit == "second":
 448        return relativedelta(seconds=1 * n)
 449
 450    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 451
 452
 453def date_floor(d: date, unit: str, dialect: Dialect) -> date:
 454    if unit == "year":
 455        return d.replace(month=1, day=1)
 456    if unit == "quarter":
 457        if d.month <= 3:
 458            return d.replace(month=1, day=1)
 459        elif d.month <= 6:
 460            return d.replace(month=4, day=1)
 461        elif d.month <= 9:
 462            return d.replace(month=7, day=1)
 463        else:
 464            return d.replace(month=10, day=1)
 465    if unit == "month":
 466        return d.replace(month=d.month, day=1)
 467    if unit == "week":
 468        # Assuming week starts on Monday (0) and ends on Sunday (6)
 469        return d - timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
 470    if unit == "day":
 471        return d
 472
 473    raise UnsupportedUnit(f"Unsupported unit: {unit}")
 474
 475
 476def date_ceil(d: date, unit: str, dialect: Dialect) -> date:
 477    floor = date_floor(d, unit, dialect)
 478
 479    if floor == d:
 480        return d
 481
 482    return floor + interval(unit)
 483
 484
 485def boolean_literal(condition: bool | exp.Predicate) -> exp.Boolean:
 486    return exp.true() if condition else exp.false()
 487
 488
 489class Simplifier:
 490    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
 491        self.dialect = Dialect.get_or_raise(dialect)
 492        self.annotate_new_expressions = annotate_new_expressions
 493
 494        self._annotator: TypeAnnotator = TypeAnnotator(
 495            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
 496        )
 497
    # Value ranges for byte-sized signed/unsigned integers
    TINYINT_MIN: t.ClassVar = -128
    TINYINT_MAX: t.ClassVar = 127
    UTINYINT_MIN: t.ClassVar = 0
    UTINYINT_MAX: t.ClassVar = 255

    # Logical negation of each comparison; simplify_not uses it to push NOT inward,
    # e.g. NOT (a < b) -> a >= b
    COMPLEMENT_COMPARISONS: t.ClassVar = {
        exp.LT: exp.GTE,
        exp.GT: exp.LTE,
        exp.LTE: exp.GT,
        exp.GTE: exp.LT,
        exp.EQ: exp.NEQ,
        exp.NEQ: exp.EQ,
    }

    # When a negated comparison's right side is an ALL/ANY subquery predicate, simplify_not
    # swaps the quantifier too, e.g. NOT (a < ALL (...)) -> a >= ANY (...)
    COMPLEMENT_SUBQUERY_PREDICATES: t.ClassVar = {
        exp.All: exp.Any,
        exp.Any: exp.All,
    }

    # Upper-bound and lower-bound comparison families merged by _simplify_comparison
    LT_LTE: t.ClassVar = (exp.LT, exp.LTE)
    GT_GTE: t.ClassVar = (exp.GT, exp.GTE)

    # Comparisons _simplify_comparison considers when combining two connector operands
    COMPARISONS: t.ClassVar = (
        *LT_LTE,
        *GT_GTE,
        exp.EQ,
        exp.NEQ,
        exp.Is,
    )

    # The operator obtained by swapping the operands of an ordering comparison
    # (a < b is equivalent to b > a)
    INVERSE_COMPARISONS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
        exp.LT: exp.GT,
        exp.GT: exp.LT,
        exp.LTE: exp.GTE,
        exp.GTE: exp.LTE,
    }

    # Functions whose result can differ between evaluations; _simplify_comparison never
    # treats operands containing them as the same column
    NONDETERMINISTIC: t.ClassVar = (exp.Rand, exp.Randn)
    AND_OR: t.ClassVar = (exp.And, exp.Or)

    # Date arithmetic node -> plain arithmetic operator that undoes it
    # (presumably used to move date arithmetic across a comparison; TODO confirm)
    INVERSE_DATE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
        exp.DateAdd: exp.Sub,
        exp.DateSub: exp.Add,
        exp.DatetimeAdd: exp.Sub,
        exp.DatetimeSub: exp.Add,
    }

    # INVERSE_DATE_OPS extended with plain addition/subtraction
    INVERSE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
        **INVERSE_DATE_OPS,
        exp.Add: exp.Sub,
        exp.Sub: exp.Add,
    }

    # Null-safe comparisons and property assignments (presumably operators whose result is
    # not forced to NULL by a NULL operand; verify against the literal-simplification rules)
    NULL_OK: t.ClassVar = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)

    # CONCAT(...) and the || operator (exp.DPipe)
    CONCATS: t.ClassVar = (exp.Concat, exp.DPipe)

    # For DATE_TRUNC(u, l) <op> dt, builds an equivalent predicate on l alone.
    # Lambda params: l = the operand (presumably the one inside DATE_TRUNC), dt = the
    # compared date, u = unit, d = dialect, t = target type for the generated literals
    # (it shadows the `typing as t` alias inside the lambdas). For example:
    #   DATE_TRUNC(u, l) <  dt  <=>  l <  ceil(dt)
    #   DATE_TRUNC(u, l) >  dt  <=>  l >= floor(dt) + 1 unit
    #   DATE_TRUNC(u, l) <= dt  <=>  l <  floor(dt) + 1 unit
    #   DATE_TRUNC(u, l) >= dt  <=>  l >= ceil(dt)
    DATETRUNC_BINARY_COMPARISONS: t.ClassVar[dict[type[exp.Expr], DateTruncBinaryTransform]] = {
        exp.LT: lambda l, dt, u, d, t: (
            l
            < date_literal(
                dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t
            )
        ),
        exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
        exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
        exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
        exp.EQ: _datetrunc_eq,
        exp.NEQ: _datetrunc_neq,
    }

    # Node types (the binary comparisons above plus IN), presumably those eligible for
    # DATE_TRUNC rewriting
    DATETRUNC_COMPARISONS: t.ClassVar = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
    DATETRUNCS: t.ClassVar = (exp.DateTrunc, exp.TimestampTrunc)

    # simplify_connectors results that need no `AND TRUE` wrapper to stay boolean-typed
    SAFE_CONNECTOR_ELIMINATION_RESULT: t.ClassVar = (exp.Connector, exp.Boolean)

    # CROSS joins result in an empty table if the right table is empty.
    # So we can only simplify certain types of joins to CROSS.
    # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
    JOINS: t.ClassVar = {
        ("", ""),
        ("", "INNER"),
        ("RIGHT", ""),
        ("RIGHT", "OUTER"),
    }
 584
    def simplify(
        self,
        expression: exp.Expr,
        constant_propagation: bool = False,
        coalesce_simplification: bool = False,
    ) -> exp.Expr:
        """
        Simplify every condition in `expression`'s tree, then clean up WHERE/JOIN clauses.

        Each top-most Condition found while walking the tree is re-simplified through
        `while_changing` (presumably until it reaches a fixed point). Afterwards, WHERE
        clauses whose predicate is always true are removed. Eligible joins (see JOINS) whose
        ON predicate is always true are turned into CROSS joins.

        Args:
            expression: the root expression (e.g. a query or a standalone condition).
            constant_propagation: whether to apply `propagate_constants`.
            coalesce_simplification: whether to apply `simplify_coalesce`.

        Returns:
            The simplified root. It is a different object only if `expression` was itself a
            Condition that got replaced; otherwise the tree is modified in place.
        """
        wheres: list[exp.Where] = []
        joins: list[exp.Join] = []

        # Conditions are simplified as a whole below, so their subtrees are not walked;
        # FINAL nodes are never descended into
        for node in expression.walk(
            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
        ):
            if node.meta.get(FINAL):
                continue

            # group by expressions cannot be simplified, for example
            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
            # the projection must exactly match the group by key
            group = node.args.get("group")

            if group and hasattr(node, "selects"):
                groups = set(group.expressions)
                group.meta[FINAL] = True

                # Freeze any projection that contains a GROUP BY key
                for s in node.selects:
                    for n in s.walk():
                        if n in groups:
                            s.meta[FINAL] = True
                            break

                having = node.args.get("having")

                # Likewise freeze HAVING if it references a GROUP BY key
                if having:
                    for n in having.walk():
                        if n in groups:
                            having.meta[FINAL] = True
                            break

            if isinstance(node, exp.Condition):
                simplified = while_changing(
                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
                )

                # The root itself may have been replaced by a new node
                if node is expression:
                    expression = simplified
            elif isinstance(node, exp.Where):
                # WHERE/JOIN cleanup is deferred until all conditions have been simplified
                wheres.append(node)
            elif isinstance(node, exp.Join):
                # snowflake match_conditions have very strict ordering rules
                if match := node.args.get("match_condition"):
                    match.meta[FINAL] = True

                joins.append(node)

        # An always-true WHERE filters nothing, so the clause is dropped
        for where in wheres:
            if always_true(where.this):
                where.pop()
        # An always-true ON (without USING or a join method) on an eligible join kind is
        # equivalent to a CROSS JOIN
        for join in joins:
            if (
                always_true(join.args.get("on"))
                and not join.args.get("using")
                and not join.args.get("method")
                and (join.side, join.kind) in self.JOINS
            ):
                join.args["on"].pop()
                join.set("side", None)
                join.set("kind", "CROSS")

        return expression
 654
 655    def _simplify(
 656        self, expression: exp.Expr, constant_propagation: bool, coalesce_simplification: bool
 657    ):
 658        pre_transformation_stack = [expression]
 659        post_transformation_stack = []
 660
 661        while pre_transformation_stack:
 662            original = pre_transformation_stack.pop()
 663            node = original
 664
 665            if not isinstance(node, SIMPLIFIABLE):
 666                if isinstance(node, exp.Query):
 667                    self.simplify(node, constant_propagation, coalesce_simplification)
 668                continue
 669
 670            parent = node.parent
 671            root = node is expression
 672
 673            node = self.rewrite_between(node)
 674            node = self.uniq_sort(node, root)
 675            node = self.absorb_and_eliminate(node, root)
 676            node = self.simplify_concat(node)
 677            node = self.simplify_conditionals(node)
 678
 679            if constant_propagation:
 680                node = propagate_constants(node, root)
 681
 682            if node is not original:
 683                original.replace(node)
 684
 685            for n in node.iter_expressions(reverse=True):
 686                if n.meta.get(FINAL):
 687                    raise
 688            pre_transformation_stack.extend(
 689                n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 690            )
 691            post_transformation_stack.append((node, parent))
 692
 693        while post_transformation_stack:
 694            original, parent = post_transformation_stack.pop()
 695            root = original is expression
 696
 697            # Resets parent, arg_key, index pointers– this is needed because some of the
 698            # previous transformations mutate the AST, leading to an inconsistent state
 699            for k, v in tuple(original.args.items()):
 700                original.set(k, v)
 701
 702            # Post-order transformations
 703            node = self.simplify_not(original)
 704            node = flatten(node)
 705            node = self.simplify_connectors(node, root)
 706            node = self.remove_complements(node, root)
 707
 708            if coalesce_simplification:
 709                node = self.simplify_coalesce(node)
 710            node.parent = parent
 711
 712            node = self.simplify_literals(node, root)
 713            node = self.simplify_equality(node)
 714            node = simplify_parens(node, dialect=self.dialect)
 715            node = self.simplify_datetrunc(node)
 716            node = self.sort_comparison(node)
 717            node = self.simplify_startswith(node)
 718
 719            if node is not original:
 720                original.replace(node)
 721
 722        return node
 723
 724    @annotate_types_on_change
 725    def rewrite_between(self, expression: exp.Expr) -> exp.Expr:
 726        """Rewrite x between y and z to x >= y AND x <= z.
 727
 728        This is done because comparison simplification is only done on lt/lte/gt/gte.
 729        """
 730        if isinstance(expression, exp.Between):
 731            negate = isinstance(expression.parent, exp.Not)
 732
 733            expression = exp.and_(
 734                exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 735                exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 736                copy=False,
 737            )
 738
 739            if negate:
 740                expression = exp.paren(expression, copy=False)
 741
 742        return expression
 743
    @annotate_types_on_change
    def simplify_not(self, expression: exp.Expr) -> exp.Expr:
        """
        Demorgan's Law
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

        Also pushes NOT into comparisons (NOT a < b -> a >= b), folds NOT over boolean/NULL
        literals, and removes double negation of BOOLEAN operands when the dialect allows it.

        Args:
            expression: any node; only `exp.Not` nodes are rewritten.

        Returns:
            The rewritten expression, or `expression` unchanged.
        """
        if isinstance(expression, exp.Not):
            this = expression.this
            # NOT NULL is NULL; expressed as NULL AND TRUE (presumably so the result stays a
            # boolean connector for later rules)
            if is_null(this):
                return exp.and_(exp.null(), exp.true(), copy=False)
            # NOT (a <op> b) -> a <complement op> b, flipping an ALL/ANY quantifier on the
            # right side as well
            if this.__class__ in self.COMPLEMENT_COMPARISONS:
                right = this.expression
                complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
                    right.__class__
                )
                if complement_subquery_predicate:
                    right = complement_subquery_predicate(this=right.this)

                return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
            # De Morgan on a parenthesized AND/OR; the result is re-parenthesized
            if isinstance(this, exp.Paren):
                condition = this.unnest()
                if isinstance(condition, exp.And):
                    return exp.paren(
                        exp.or_(
                            exp.not_(condition.left, copy=False),
                            exp.not_(condition.right, copy=False),
                            copy=False,
                        ),
                        copy=False,
                    )
                if isinstance(condition, exp.Or):
                    return exp.paren(
                        exp.and_(
                            exp.not_(condition.left, copy=False),
                            exp.not_(condition.right, copy=False),
                            copy=False,
                        ),
                        copy=False,
                    )
                if is_null(condition):
                    return exp.and_(exp.null(), exp.true(), copy=False)
            # NOT of TRUE or of a non-zero number literal is FALSE
            if always_true(this):
                return exp.false()
            # NOT FALSE is TRUE
            if is_false(this):
                return exp.true()
            if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
                inner = this.this
                if inner.is_type(exp.DType.BOOLEAN):
                    # double negation
                    # NOT NOT x -> x, if x is BOOLEAN type
                    return inner
        return expression
 797
    @annotate_types_on_change
    def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
        """
        Fold AND/OR connectors over literal operands and merge comparisons on shared operands.

        Operand rules (presumably applied pairwise across the operands by `_flat_simplify`):
          AND: FALSE or 0 on either side -> FALSE; NULL with NULL or an always-true operand
               -> NULL; always-true operands are dropped; otherwise `_simplify_comparison`.
          OR:  an always-true operand -> TRUE; NULL with NULL or an always-false operand
               -> NULL; FALSE operands are dropped; otherwise `_simplify_comparison(or_=True)`.

        If the connector collapses to a single non-boolean operand that is not nested in
        another connector (looking through Paren wrappers), it is AND-ed with TRUE so the
        result stays boolean-typed.

        Args:
            expression: any node; only `exp.Connector` nodes are rewritten.
            root: forwarded to `_flat_simplify`.

        Returns:
            The simplified expression.
        """

        def _simplify_connectors(expression: exp.Expr, left: exp.Expr, right: exp.Expr):
            # Returns None for connector types other than And/Or
            if isinstance(expression, exp.And):
                if is_false(left) or is_false(right):
                    return exp.false()
                if is_zero(left) or is_zero(right):
                    return exp.false()
                if (
                    (is_null(left) and is_null(right))
                    or (is_null(left) and always_true(right))
                    or (always_true(left) and is_null(right))
                ):
                    return exp.null()
                if always_true(left) and always_true(right):
                    return exp.true()
                if always_true(left):
                    return right
                if always_true(right):
                    return left
                return self._simplify_comparison(expression, left, right)
            elif isinstance(expression, exp.Or):
                if always_true(left) or always_true(right):
                    return exp.true()
                if (
                    (is_null(left) and is_null(right))
                    or (is_null(left) and always_false(right))
                    or (always_false(left) and is_null(right))
                ):
                    return exp.null()
                if is_false(left):
                    return right
                if is_false(right):
                    return left
                return self._simplify_comparison(expression, left, right, or_=True)

        if isinstance(expression, exp.Connector):
            original_parent = expression.parent
            expression = self._flat_simplify(expression, _simplify_connectors, root)

            # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
            # to ensure that the resulting type is boolean. We know this is true only for connectors,
            # boolean values and columns that are essentially operands to a connector:
            #
            # A AND (((B)))
            #          ~ this is safe to keep because it will eventually be part of another connector
            if not isinstance(
                expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
            ) and not expression.is_type(exp.DType.BOOLEAN):
                # Walk up through Paren wrappers: reaching a connector means the result stays an
                # operand of it; anything else (including no parent) forces `AND TRUE`
                while True:
                    if isinstance(original_parent, exp.Connector):
                        break
                    if not isinstance(original_parent, exp.Paren):
                        expression = expression.and_(exp.true(), copy=False)
                        break

                    original_parent = original_parent.parent

        return expression
 857
 858    @annotate_types_on_change
 859    def _simplify_comparison(
 860        self, expression: exp.Expr, left: exp.Expr, right: exp.Expr, or_: bool = False
 861    ) -> exp.Expr | None:
 862        if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
 863            ll, lr = left.args.values()
 864            rl, rr = right.args.values()
 865
 866            largs = {ll, lr}
 867            rargs = {rl, rr}
 868
 869            matching = largs & rargs
 870            columns = {
 871                m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
 872            }
 873
 874            if matching and columns:
 875                try:
 876                    l = first(largs - columns)
 877                    r = first(rargs - columns)
 878                except StopIteration:
 879                    return expression
 880
 881                if l.is_number and r.is_number:
 882                    l = l.to_py()
 883                    r = r.to_py()
 884                elif l.is_string and r.is_string:
 885                    l = l.name
 886                    r = r.name
 887                else:
 888                    l = extract_date(l)
 889                    if not l:
 890                        return None
 891                    r = extract_date(r)
 892                    if not r:
 893                        return None
 894                    # python won't compare date and datetime, but many engines will upcast
 895                    l, r = cast_as_datetime(l), cast_as_datetime(r)
 896
 897                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 898                    if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
 899                        return left if (av > bv if or_ else av <= bv) else right
 900                    if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
 901                        return left if (av < bv if or_ else av >= bv) else right
 902
 903                    # we can't ever shortcut to true because the column could be null
 904                    if not or_:
 905                        if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
 906                            if av <= bv:
 907                                return exp.false()
 908                        elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
 909                            if av >= bv:
 910                                return exp.false()
 911                        elif isinstance(a, exp.EQ):
 912                            if isinstance(b, exp.LT):
 913                                return exp.false() if av >= bv else a
 914                            if isinstance(b, exp.LTE):
 915                                return exp.false() if av > bv else a
 916                            if isinstance(b, exp.GT):
 917                                return exp.false() if av <= bv else a
 918                            if isinstance(b, exp.GTE):
 919                                return exp.false() if av < bv else a
 920                            if isinstance(b, exp.NEQ):
 921                                return exp.false() if av == bv else a
 922        return None
 923
 924    @annotate_types_on_change
 925    def remove_complements(self, expression: object, root: bool = True) -> object:
 926        """
 927        Removing complements.
 928
 929        A AND NOT A -> FALSE (only for non-NULL A)
 930        A OR NOT A -> TRUE (only for non-NULL A)
 931        """
 932        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
 933            ops = set(expression.flatten())
 934            for op in ops:
 935                if isinstance(op, exp.Not) and op.this in ops:
 936                    if expression.meta.get("nonnull") is True:
 937                        return exp.false() if isinstance(expression, exp.And) else exp.true()
 938
 939        return expression
 940
 941    @annotate_types_on_change
 942    def uniq_sort(self, expression: object, root: bool = True) -> object:
 943        """
 944        Uniq and sort a connector.
 945
 946        C AND A AND B AND B -> A AND B AND C
 947        """
 948        if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 949            flattened = tuple(expression.flatten())
 950
 951            if isinstance(expression, exp.Xor):
 952                result_func = exp.xor
 953                # Do not deduplicate XOR as A XOR A != A if A == True
 954                deduped = None
 955                arr = tuple((gen(e), e) for e in flattened)
 956            else:
 957                result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 958                deduped = {gen(e): e for e in flattened}
 959                arr = tuple(deduped.items())
 960
 961            # check if the operands are already sorted, if not sort them
 962            # A AND C AND B -> A AND B AND C
 963            for i, (sql, e) in enumerate(arr[1:]):
 964                if sql < arr[i][0]:
 965                    expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 966                    break
 967            else:
 968                # we didn't have to sort but maybe we need to dedup
 969                if deduped and len(deduped) < len(flattened):
 970                    unique_operand = flattened[0]
 971                    if len(deduped) == 1:
 972                        expression = unique_operand.and_(exp.true(), copy=False)
 973                    else:
 974                        expression = result_func(*deduped.values(), copy=False)
 975
 976        return expression
 977
    @annotate_types_on_change
    def absorb_and_eliminate(self, expression: object, root: bool = True) -> object:
        """
        absorption:
            A AND (A OR B) -> A
            A OR (A AND B) -> A
            A AND (NOT A OR B) -> A AND B
            A OR (NOT A AND B) -> A OR B
        elimination:
            (A AND B) OR (A AND NOT B) -> A
            (A OR B) AND (A OR NOT B) -> A

        The rewrites are done in place: operand sub-trees are replaced with TRUE/FALSE or with
        the surviving operand, and the (mutated) `expression` is returned. Only AND/OR nodes
        that are the root, or whose parent is not of the same class, are processed.
        """
        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
            # The opposite connector: nested ORs inside an AND, or nested ANDs inside an OR
            kind = exp.Or if isinstance(expression, exp.And) else exp.And

            ops = tuple(expression.flatten())

            # Initialize lookup tables:
            # Set of all operands, used to find complements for absorption.
            op_set = set()
            # Sub-operands, used to find subsets for absorption.
            subops = defaultdict(list)
            # Pairs of complements, used for elimination.
            pairs = defaultdict(list)

            # Populate the lookup tables
            for op in ops:
                op_set.add(op)

                if not isinstance(op, kind):
                    # In cases like: A OR (A AND B)
                    # Subop will be: ^
                    subops[op].append({op})
                    continue

                # In cases like: (A AND B) OR (A AND B AND C)
                # Subops will be: ^     ^
                subset = set(op.flatten())
                for i in subset:
                    subops[i].append(subset)

                # For `X <kind> NOT Y`, record (op, X) under the key {X, Y} so that a sibling
                # `X <kind> Y` can later find it (elimination)
                a, b = op.unnest_operands()
                if isinstance(a, exp.Not):
                    pairs[frozenset((a.this, b))].append((op, b))
                if isinstance(b, exp.Not):
                    pairs[frozenset((a, b.this))].append((op, a))

            for op in ops:
                if not isinstance(op, kind):
                    continue

                a, b = op.unnest_operands()

                # Absorb
                # NOT A nested here while A is a sibling operand: replace NOT A with the
                # identity value of the nested connector (TRUE for AND, FALSE for OR)
                if isinstance(a, exp.Not) and a.this in op_set:
                    a.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                if isinstance(b, exp.Not) and b.this in op_set:
                    b.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                # Another recorded operand set is a strict subset of this one, so this nested
                # connector is redundant: replace it with the identity of the outer connector
                superset = set(op.flatten())
                if any(any(subset < superset for subset in subops[i]) for i in superset):
                    op.replace(exp.false() if kind == exp.And else exp.true())
                    continue

                # Eliminate
                # `X <kind> Y` paired with `X <kind> NOT Y`: both collapse to the shared X
                for other, complement in pairs[frozenset((a, b))]:
                    op.replace(complement)
                    other.replace(complement)

        return expression
1049
1050    @annotate_types_on_change
1051    @catch(ModuleNotFoundError, UnsupportedUnit)
1052    def simplify_equality(self, expression: exp.Expr) -> exp.Expr:
1053        """
1054        Use the subtraction and addition properties of equality to simplify expressions:
1055
1056            x + 1 = 3 becomes x = 2
1057
1058        There are two binary operations in the above expression: + and =
1059        Here's how we reference all the operands in the code below:
1060
1061            l     r
1062            x + 1 = 3
1063            a   b
1064        """
1065        if isinstance(expression, self.COMPARISONS):
1066            l, r = expression.left, expression.right
1067
1068            if l.__class__ not in self.INVERSE_OPS:
1069                return expression
1070
1071            if r.is_number:
1072                a_predicate = _is_number
1073                b_predicate = _is_number
1074            elif _is_date_literal(r):
1075                a_predicate = _is_date_literal
1076                b_predicate = _is_interval
1077            else:
1078                return expression
1079
1080            if l.__class__ in self.INVERSE_DATE_OPS:
1081                l = t.cast(exp.IntervalOp, l)
1082                a: exp.Expr = l.this
1083                b: exp.Expr = l.interval()
1084            else:
1085                l = t.cast(exp.Binary, l)
1086                a, b = l.left, l.right
1087
1088            if not a_predicate(a) and b_predicate(b):
1089                pass
1090            elif not a_predicate(b) and b_predicate(a):
1091                a, b = b, a
1092            else:
1093                return expression
1094
1095            return expression.__class__(
1096                this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
1097            )
1098        return expression
1099
1100    def _is_inverse_date_op(self, expression: exp.Expr) -> TypeIs[exp.IntervalOp]:
1101        return type(expression) in self.INVERSE_DATE_OPS
1102
1103    @annotate_types_on_change
1104    def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
1105        if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
1106            return self._flat_simplify(expression, self._simplify_binary, root)
1107
1108        if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
1109            return expression.this.this
1110
1111        if self._is_inverse_date_op(expression):
1112            return (
1113                self._simplify_binary(expression, expression.this, expression.interval())
1114                or expression
1115            )
1116
1117        return expression
1118
1119    def _simplify_integer_cast(self, expr: t.Any) -> t.Any:
1120        if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
1121            this = self._simplify_integer_cast(expr.this)
1122        else:
1123            this = expr.this
1124
1125        if isinstance(expr, exp.Cast) and this.is_int:
1126            num = this.to_py()
1127
1128            # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
1129            # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
1130            # engine-dependent
1131            if (
1132                self.TINYINT_MIN <= num <= self.TINYINT_MAX
1133                and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
1134            ) or (
1135                self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
1136                and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
1137            ):
1138                return this
1139
1140        return expr
1141
1142    def _simplify_binary(self, expression, a, b):
1143        if isinstance(expression, self.COMPARISONS):
1144            a = self._simplify_integer_cast(a)
1145            b = self._simplify_integer_cast(b)
1146
1147        if isinstance(expression, exp.Is):
1148            if isinstance(b, exp.Not):
1149                c = b.this
1150                not_ = True
1151            else:
1152                c = b
1153                not_ = False
1154
1155            if is_null(c):
1156                if isinstance(a, exp.Literal):
1157                    return exp.true() if not_ else exp.false()
1158                if is_null(a):
1159                    return exp.false() if not_ else exp.true()
1160        elif isinstance(expression, self.NULL_OK):
1161            return None
1162        elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
1163            return exp.null()
1164
1165        if a.is_number and b.is_number:
1166            num_a = a.to_py()
1167            num_b = b.to_py()
1168
1169            if isinstance(expression, exp.Add):
1170                return exp.Literal.number(num_a + num_b)
1171            if isinstance(expression, exp.Mul):
1172                return exp.Literal.number(num_a * num_b)
1173
1174            # We only simplify Sub, Div if a and b have the same parent because they're not associative
1175            if isinstance(expression, exp.Sub):
1176                return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
1177            if isinstance(expression, exp.Div):
1178                # engines have differing int div behavior so intdiv is not safe
1179                if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
1180                    return None
1181                return exp.Literal.number(num_a / num_b)
1182
1183            boolean = eval_boolean(expression, num_a, num_b)
1184
1185            if boolean:
1186                return boolean
1187        elif a.is_string and b.is_string:
1188            boolean = eval_boolean(expression, a.this, b.this)
1189
1190            if boolean:
1191                return boolean
1192        elif _is_date_literal(a) and isinstance(b, exp.Interval):
1193            date, b = extract_date(a), extract_interval(b)
1194            if date and b:
1195                if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
1196                    return date_literal(date + b, extract_type(a))
1197                if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
1198                    return date_literal(date - b, extract_type(a))
1199        elif isinstance(a, exp.Interval) and _is_date_literal(b):
1200            a, date = extract_interval(a), extract_date(b)
1201            # you cannot subtract a date from an interval
1202            if a and b and isinstance(expression, exp.Add):
1203                return date_literal(a + date, extract_type(b))
1204        elif _is_date_literal(a) and _is_date_literal(b):
1205            if isinstance(expression, exp.Predicate):
1206                a, b = extract_date(a), extract_date(b)
1207                boolean = eval_boolean(expression, a, b)
1208                if boolean:
1209                    return boolean
1210
1211        return None
1212
    @annotate_types_on_change
    def simplify_coalesce(self, expression: exp.Expr) -> exp.Expr:
        """
        Simplify COALESCE calls and comparisons against a COALESCE.

        - COALESCE(x), and a COALESCE whose first argument is a non-null constant, reduce to
          that first argument (unless the COALESCE is used as a Spark partitioning hint).
        - Unless `dialect.COALESCE_COMPARISON_NON_STANDARD` is set, a comparison between a
          COALESCE and a constant `k` is rewritten when some argument after the first is a
          constant `c` (the first such one is used):

              (y IS NOT NULL AND <comparison>) OR (y IS NULL AND c <op> k)

          Here `y` is the COALESCE truncated right before `c`, or its first argument when
          nothing else is left. The COALESCE node is truncated in place, so the copied
          `<comparison>` also uses the truncated COALESCE.
        """
        # COALESCE(x) -> x
        if (
            isinstance(expression, exp.Coalesce)
            and (not expression.expressions or _is_nonnull_constant(expression.this))
            # COALESCE is also used as a Spark partitioning hint
            and not isinstance(expression.parent, exp.Hint)
        ):
            return expression.this

        if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
            return expression

        if not isinstance(expression, self.COMPARISONS):
            return expression

        if isinstance(expression.left, exp.Coalesce):
            coalesce = expression.left
            other = expression.right
        elif isinstance(expression.right, exp.Coalesce):
            coalesce = expression.right
            other = expression.left
        else:
            return expression

        # This transformation is valid for non-constants,
        # but it really only does anything if they are both constants.
        if not _is_constant(other):
            return expression

        # Find the first constant arg
        for arg_index, arg in enumerate(coalesce.expressions):
            if _is_constant(arg):
                break
        else:
            return expression

        # In-place truncation: keep only the arguments before the first constant
        coalesce.set("expressions", coalesce.expressions[:arg_index])

        # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
        # since we already remove COALESCE at the top of this function.
        this: exp.Expr = coalesce if coalesce.expressions else coalesce.this

        # This expression is more complex than when we started, but it will get simplified further
        return exp.paren(
            exp.or_(
                exp.and_(
                    this.is_(exp.null()).not_(copy=False),
                    expression.copy(),
                    copy=False,
                ),
                exp.and_(
                    this.is_(exp.null()),
                    type(expression)(this=arg.copy(), expression=other.copy()),
                    copy=False,
                ),
                copy=False,
            ),
            copy=False,
        )
1274
1275    @annotate_types_on_change
1276    def simplify_concat(self, expression):
1277        """Reduces all groups that contain string literals by concatenating them."""
1278        if not isinstance(expression, self.CONCATS) or (
1279            # We can't reduce a CONCAT_WS call if we don't statically know the separator
1280            isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
1281        ):
1282            return expression
1283
1284        if isinstance(expression, exp.ConcatWs):
1285            sep_expr, *expressions = expression.expressions
1286            sep = sep_expr.name
1287            concat_type = exp.ConcatWs
1288            args = {}
1289        else:
1290            expressions = expression.expressions
1291            sep = ""
1292            concat_type = exp.Concat
1293            args = {
1294                "safe": expression.args.get("safe"),
1295                "coalesce": expression.args.get("coalesce"),
1296            }
1297
1298        new_args = []
1299        for is_string_group, group in itertools.groupby(
1300            expressions or expression.flatten(), lambda e: e.is_string
1301        ):
1302            if is_string_group:
1303                new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
1304            else:
1305                new_args.extend(group)
1306
1307        if len(new_args) == 1 and new_args[0].is_string:
1308            return new_args[0]
1309
1310        if concat_type is exp.ConcatWs:
1311            new_args = [sep_expr] + new_args
1312        elif isinstance(expression, exp.DPipe):
1313            return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
1314
1315        return concat_type(expressions=new_args, **args)
1316
1317    @annotate_types_on_change
1318    def simplify_conditionals(self, expression):
1319        """Simplifies expressions like IF, CASE if their condition is statically known."""
1320        if isinstance(expression, exp.Case):
1321            this = expression.this
1322            for case in expression.args["ifs"]:
1323                cond = case.this
1324                if this:
1325                    # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
1326                    cond = cond.replace(this.pop().eq(cond))
1327
1328                if always_true(cond):
1329                    return case.args["true"]
1330
1331                if always_false(cond):
1332                    case.pop()
1333                    if not expression.args["ifs"]:
1334                        return expression.args.get("default") or exp.null()
1335        elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
1336            if always_true(expression.this):
1337                return expression.args["true"]
1338            if always_false(expression.this):
1339                return expression.args.get("false") or exp.null()
1340
1341        return expression
1342
1343    @annotate_types_on_change
1344    def simplify_startswith(self, expression: exp.Expr) -> exp.Expr:
1345        """
1346        Reduces a prefix check to either TRUE or FALSE if both the string and the
1347        prefix are statically known.
1348
1349        Example:
1350            >>> from sqlglot import parse_one
1351            >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
1352            'TRUE'
1353        """
1354        if (
1355            isinstance(expression, exp.StartsWith)
1356            and expression.this.is_string
1357            and expression.expression.is_string
1358        ):
1359            return exp.convert(expression.name.startswith(expression.expression.name))
1360
1361        return expression
1362
1363    def _is_datetrunc_predicate(self, left: exp.Expr, right: exp.Expr) -> bool:
1364        return isinstance(left, self.DATETRUNCS) and _is_date_literal(right)
1365
    @annotate_types_on_change
    @catch(ModuleNotFoundError, UnsupportedUnit)
    def simplify_datetrunc(self, expression: exp.Expr) -> exp.Expr:
        """
        Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

        Three cases are handled; otherwise `expression` is returned unchanged:

        - a truncation (one of `DATETRUNCS`) whose argument `extract_date` can read is folded
          into a date literal;
        - a binary comparison whose class is in `DATETRUNC_COMPARISONS`, with a DATE_TRUNC /
          TIMESTAMP_TRUNC on the left and a date literal on the right, is handed to the
          matching `DATETRUNC_BINARY_COMPARISONS` transform (kept as-is if that returns None);
        - `DATE_TRUNC(unit, x) IN (<date literals>)` becomes an OR of
          `_datetrunc_eq_expression` terms, one per range left after `merge_ranges`.
        """
        comparison = expression.__class__

        if isinstance(expression, self.DATETRUNCS):
            # Fold the truncation of a date literal into a date literal
            this = expression.this
            trunc_type = extract_type(this)
            date = extract_date(this)
            if date and expression.unit:
                return date_literal(
                    date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type
                )
        elif comparison not in self.DATETRUNC_COMPARISONS:
            return expression

        if isinstance(expression, exp.Binary):
            l, r = expression.left, expression.right

            if not self._is_datetrunc_predicate(l, r) or not isinstance(
                l, (exp.DateTrunc, exp.TimestampTrunc)
            ):
                return expression

            trunc_arg = l.this
            unit = l.args["unit"].name.lower()
            date = extract_date(r)

            if not date:
                return expression

            return (
                self.DATETRUNC_BINARY_COMPARISONS[comparison](
                    trunc_arg, date, unit, self.dialect, extract_type(r)
                )
                or expression
            )

        if isinstance(expression, exp.In):
            l = expression.this
            rs = expression.expressions

            if (
                rs
                and all(self._is_datetrunc_predicate(l, r) for r in rs)
                and isinstance(l, (exp.DateTrunc, exp.TimestampTrunc))
            ):
                unit = l.args["unit"].name.lower()

                ranges = []
                for r in rs:
                    date = extract_date(r)
                    if not date:
                        return expression
                    # NOTE(review): literals whose range is None are dropped from the result;
                    # presumably None means the truncation can never produce that literal.
                    # Confirm against `_datetrunc_range`.
                    drange = _datetrunc_range(date, unit, self.dialect)
                    if drange:
                        ranges.append(drange)

                if not ranges:
                    return expression

                ranges = merge_ranges(ranges)
                target_type = extract_type(*rs)

                return exp.or_(
                    *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges],
                    copy=False,
                )

        return expression
1437
1438    @annotate_types_on_change
1439    def sort_comparison(self, expression: exp.Expr) -> exp.Expr:
1440        if expression.__class__ in self.COMPLEMENT_COMPARISONS:
1441            l, r = expression.this, expression.expression
1442            l_column = isinstance(l, exp.Column)
1443            r_column = isinstance(r, exp.Column)
1444            l_const = _is_constant(l)
1445            r_const = _is_constant(r)
1446
1447            if (
1448                (l_column and not r_column)
1449                or (r_const and not l_const)
1450                or isinstance(r, exp.SubqueryPredicate)
1451            ):
1452                return expression
1453            if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1454                return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1455                    this=r, expression=l
1456                )
1457        return expression
1458
    def _flat_simplify(
        self,
        expression: exp.Expr,
        simplifier: t.Callable[[exp.Expr, exp.Expr, exp.Expr], exp.Expr | None],
        root: bool = True,
    ) -> exp.Expr:
        """
        Pairwise-simplify the operands of a chain of same-class binary nodes.

        The chain is flattened (without unnesting parentheses) and `simplifier(expression, a, b)`
        is tried on operand pairs. When it returns a new node (truthy and not `expression`
        itself), `b` is dropped and the result takes `a`'s place at the front of the queue, so
        it can be combined further. If any pair was merged, the remaining operands are folded
        back into a left-deep chain of `expression`'s class. Otherwise `expression` is returned.

        Args:
            expression: the root of the binary chain.
            simplifier: pairwise callback returning a replacement node, or None / `expression`
                when the pair cannot be simplified.
            root: when False, the rewrite only runs if `expression`'s parent is of a different
                class (i.e. it is the top of its chain).
        """
        if root or not expression.same_parent:
            operands = []
            queue = deque(expression.flatten(unnest=False))
            size = len(queue)

            while queue:
                a = queue.popleft()

                for b in queue:
                    result = simplifier(expression, a, b)

                    if result and result is not expression:
                        # The deque is mutated here, so the iteration over it must stop now
                        queue.remove(b)
                        queue.appendleft(result)
                        break
                else:
                    # `a` could not be combined with any remaining operand
                    operands.append(a)

            # Fewer operands than we started with means at least one pair was merged
            if len(operands) < size:
                return functools.reduce(
                    lambda a, b: expression.__class__(this=a, expression=b), operands
                )
        return expression
1488
1489
def gen(expression: exp.Expr, comments: bool = False) -> str:
    """Render a cheap pseudo-SQL string for `expression`, meant as a sort and dedup key.

    Sorting and deduplicating operands happens often during optimization, and the real SQL
    generator is comparatively expensive, so a fresh lightweight `Gen` instance is used here.

    Args:
        expression: the expression to render.
        comments: whether the expression's comments are part of the output.

    Returns:
        The pseudo-SQL string.
    """
    renderer = Gen()
    return renderer.gen(expression, comments=comments)
1501
1502
class Gen:
    """
    Minimal stack-based pseudo-SQL generator used by `gen` to build sortable and uniq strings.

    `gen` does not recurse; it pops work items off `self.stack`:
    - expression nodes are expanded by a `<node.key>_sql` handler, or by a generic fallback;
    - lists are expanded into comma-separated items;
    - anything else that is not None is emitted with `str()`.
    Because the stack is LIFO, handlers push their pieces in reverse output order.
    """

    def __init__(self):
        # Pending work items: expression nodes, lists of items, or literal text fragments
        self.stack = []
        # Emitted text fragments, joined by `gen`
        self.sqls = []

    def gen(self, expression: exp.Expr, comments: bool = False) -> str:
        """
        Render `expression` as a pseudo-SQL string.

        Args:
            expression: the expression to render.
            comments: whether to include each node's comments as `/*...*/`.

        Returns:
            The pseudo-SQL string. It is meant for ordering/deduplication, not execution.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expr):
                # Pushed before the handler runs, so it is emitted after the node's own text
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Fallback: upper-cased node key, followed by its set args if it has any
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Push the non-None items, each with a "," separator that is emitted in front
                # of it, then drop the separator in front of the first item. Only drop it if
                # something was pushed: popping after an all-None list would discard an
                # unrelated pending item, such as a function's closing ")".
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render `NAME(args)`; unquoted names are upper-cased, quoted ones keep their case."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dot-joined parts; the final pop drops the "." that would precede the first part
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE: mixed-case " Like " (unlike " ILIKE "). The output is only used as a sort and
        # dedup key, and changing the text would change the resulting operand order.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        # Dot-joined parts; the final pop drops the "." that would precede the first part
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Render `this <op> expression` (pushed in reverse)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Render `<op>this` (pushed in reverse)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Generic function rendering: `SQL_NAME(<all arg values, comma-separated>)`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expr, arg_index: int = 0) -> bool:
        """
        Push the node's non-None args, starting at position `arg_index` of `arg_types`, as a
        list of `[":name", value]` pairs.

        Returns:
            Whether anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
logger = <Logger sqlglot (WARNING)>
FINAL = 'final'
def simplify(
    expression: exp.Expr,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Simplify a sqlglot AST using a `Simplifier` configured for `dialect`.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: the expression to simplify.
        constant_propagation: enable the constant propagation rule.
        coalesce_simplification: enable the COALESCE simplification rule. It tries to remove
            COALESCE calls, which can help some analyses but may make the query more verbose.

    Returns:
        sqlglot.Expr: the simplified expression.
    """
    simplifier = Simplifier(dialect=dialect)
    return simplifier.simplify(
        expression,
        constant_propagation=constant_propagation,
        coalesce_simplification=coalesce_simplification,
    )

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression: expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
  • coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:

sqlglot.Expr: simplified expression

class UnsupportedUnit(Exception):
    """Raised by the date helpers (`interval`, `date_floor`) for a unit they do not support."""

Raised by the date helpers (`interval`, `date_floor`) when given a unit they do not support.

def catch(*exceptions):
    """Decorator that turns a simplification function into a no-op when it fails.

    If the decorated function raises any of `exceptions`, its first positional argument
    (`expression`) is returned unchanged. Any other exception propagates.

    Args:
        *exceptions: exception types to swallow.
    """

    def decorator(func):
        # Preserve the wrapped function's name/docstring for introspection and debugging.
        @wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that turns a simplification function into a no-op on failure: if any of `exceptions` is raised, the input expression is returned unchanged.

def annotate_types_on_change(func):
    """Decorate a Simplifier rewrite method so that replacement nodes get type annotations.

    When the method returns a node that differs from its input (and the simplifier has
    `annotate_new_expressions` enabled), the result is re-annotated and then given the input's
    type. A None result is passed through untouched.
    """

    @wraps(func)
    def wrapper(self, expression: exp.Expr, *args, **kwargs) -> exp.Expr | None:
        result: exp.Expr | None = func(self, expression, *args, **kwargs)

        if result is not None and self.annotate_new_expressions and expression != result:
            self._annotator.clear()
            # Annotate the whole replacement so freshly built children are typed as well.
            result = self._annotator.annotate(expression=result, annotate_scope=False)
            # Keep the original type: a rewrite must not change the resulting schema.
            result.type = expression.type

        return result

    return wrapper
def flatten(expression: exp.Expr) -> exp.Expr:
    """
    Inline directly nested connectors of the same kind:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_parens(expression: exp.Expr, dialect: DialectType) -> exp.Expr:
    """Return the inner expression of a redundant `Paren`, otherwise `expression` itself.

    Parentheses are kept around subqueries, under subquery predicates and brackets, for
    struct field access in dialects that require `(x).field`, and wherever removing them
    could change how the parent groups its operands.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent
    in_predicate = isinstance(parent, exp.Predicate)

    if isinstance(inner, exp.Select) or isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    # Some dialects only accept struct field access on a parenthesized operand.
    if (
        Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
        and isinstance(parent, exp.Dot)
        and isinstance(parent.right, (exp.Identifier, exp.Star))
    ):
        return expression

    if isinstance(inner, exp.Predicate):
        parent_requires_grouping = (
            in_predicate
            or isinstance(parent, exp.Neg)
            or (isinstance(parent, exp.Binary) and not isinstance(parent, exp.Connector))
        )
        if not parent_requires_grouping:
            return inner

    parent_allows_unwrap = not isinstance(parent, (exp.Condition, exp.Binary)) or isinstance(
        parent, exp.Paren
    )
    inner_is_simple = not isinstance(inner, exp.Binary) and not (
        isinstance(inner, (exp.Not, exp.Is)) and in_predicate
    )
    redundant_by_precedence = (
        (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )

    if parent_allows_unwrap or inner_is_simple or redundant_by_precedence:
        return inner
    return expression
def propagate_constants(expression, root=True):
    """
    Propagate constants within a conjunction that is in DNF.

    For example, `SELECT * FROM t WHERE a = b AND b = 5` becomes
    `SELECT * FROM t WHERE a = 5 AND b = 5`.

    Only `column = literal` equalities are used as sources. `IF(...)` subtrees are not scanned
    for sources, and a column compared with `IS NULL` is never replaced.

    Reference: https://www.sqlite.org/optoverview.html
    """
    eligible = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not eligible:
        return expression

    # column -> (id of the defining column node, literal it equals)
    constant_mapping = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        column, value = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(column, exp.Column) and isinstance(value, exp.Literal):
            constant_mapping[column] = (id(column), value)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        source_id, constant = constant_mapping.get(column) or (None, None)
        # Skip unknown columns and the defining occurrence itself.
        if source_id is None or id(column) == source_id:
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue
        column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

`SELECT * FROM t WHERE a = b AND b = 5` becomes `SELECT * FROM t WHERE a = 5 AND b = 5`.

Reference: https://www.sqlite.org/optoverview.html

def always_true(expression: object) -> bool:
    """Truthy for a TRUE boolean literal or a non-zero numeric literal."""
    if isinstance(expression, exp.Boolean) and expression.this:
        return expression.this
    return isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression)
def always_false(expression: object) -> bool:
    """True for FALSE, NULL, or a numeric zero literal."""
    return any(check(expression) for check in (is_false, is_null, is_zero))
def is_zero(expression: object) -> bool:
    """True if `expression` is a literal whose Python value equals 0."""
    if not isinstance(expression, exp.Literal):
        return False
    return expression.to_py() == 0
def is_complement(a: object, b: object) -> bool:
    """True if `b` is exactly `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: object) -> bool:
    """True only for an exact `exp.Boolean` node holding a falsy value (subclasses excluded)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: object) -> bool:
    """True only for an exact `exp.Null` node (subclasses are not matched)."""
    null_type = exp.Null
    return type(a) is null_type
class SupportsComparison(t.Protocol):
    """Structural type for any value that supports the ordering operators <, <=, > and >=."""

    def __lt__(self, other: t.Any) -> bool:
        ...

    def __le__(self, other: t.Any) -> bool:
        ...

    def __gt__(self, other: t.Any) -> bool:
        ...

    def __ge__(self, other: t.Any) -> bool:
        ...

Protocol for expressions or values that can be compared using <, <=, >, >=.

SupportsComparison(*args, **kwargs)
1431def _no_init_or_replace_init(self, *args, **kwargs):
1432    cls = type(self)
1433
1434    if cls._is_protocol:
1435        raise TypeError('Protocols cannot be instantiated')
1436
1437    # Already using a custom `__init__`. No need to calculate correct
1438    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1439    if cls.__init__ is not _no_init_or_replace_init:
1440        return
1441
1442    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1443    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1444    # searches for a proper new `__init__` in the MRO. The new `__init__`
1445    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1446    # instantiation of the protocol subclass will thus use the new
1447    # `__init__` and no longer call `_no_init_or_replace_init`.
1448    for base in cls.__mro__:
1449        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1450        if init is not _no_init_or_replace_init:
1451            cls.__init__ = init
1452            break
1453    else:
1454        # should not happen
1455        cls.__init__ = object.__init__
1456
1457    cls.__init__(self, *args, **kwargs)
def eval_boolean(
    expression: object, a: SupportsComparison, b: SupportsComparison
) -> exp.Boolean | None:
    """Evaluate comparison node `expression` on the Python values `a` and `b`.

    Returns a TRUE/FALSE literal, or None if `expression` is not a supported comparison.
    `IS` is evaluated like `=`.
    """
    evaluators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for kinds, compare in evaluators:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: datetime | date | str) -> date | None:
    """Coerce `value` to a `date`.

    Datetimes lose their time part, dates pass through, and strings are parsed with
    `datetime.fromisoformat`. An unparsable string yields None.
    """
    if isinstance(value, date):
        # datetime is a subclass of date, so its time part has to be dropped explicitly.
        return value.date() if isinstance(value, datetime) else value
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(
    value: datetime | date | str,
) -> datetime | None:
    """Coerce `value` to a `datetime`.

    Datetimes pass through, dates become midnight of that day, and strings are parsed with
    `datetime.fromisoformat`. An unparsable string yields None.
    """
    if isinstance(value, date):
        if isinstance(value, datetime):
            return value
        return datetime(value.year, value.month, value.day)
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        parsed = None
    return parsed
def cast_value(value: datetime | date | str, to: exp.DataType) -> datetime | date | None:
    """Convert `value` as if it were cast to the temporal type `to`.

    Args:
        value: a date/datetime, or a string for `datetime.fromisoformat` to parse.
        to: the target data type.

    Returns:
        A `date` when `to` is DATE, a `datetime` for any other temporal type, and None for an
        empty value, a non-temporal target type, or a string that cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DType.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        # Non-DATE temporal targets produce a datetime, hence the widened return annotation.
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expr) -> datetime | date | None:
    """Statically evaluate a date-producing expression, if possible.

    Supported shapes are `exp.Cast` and format-less `exp.TsOrDsToDate` nodes whose operand is a
    string/number literal or, recursively, another supported node.

    Returns:
        A `date` when the target type is DATE, a `datetime` for other temporal types, or None
        if `cast` is not a supported shape or its value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DType.DATE.into_expr()
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression: exp.Expr) -> relativedelta | None:
    """Turn an INTERVAL node into a `relativedelta`, or None if that is not possible.

    None is returned for unsupported units, a non-integer amount (ValueError), or when
    `dateutil` is not installed.
    """
    try:
        amount = int(expression.this.to_py())
        unit = expression.text("unit").lower()
        delta = interval(unit, amount)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        delta = None
    return delta
def extract_type(*expressions: exp.Expr):
    """Return the first truthy type among `expressions`.

    A Cast contributes its target type (`to`); any other node contributes `.type`. If none is
    truthy, the last candidate examined is returned (None for no arguments).
    """
    found = None
    for expression in expressions:
        found = expression.to if isinstance(expression, exp.Cast) else expression.type
        if found:
            return found
    return found
def date_literal(date: object, target_type=None) -> exp.Expr:
    """Build `CAST('<date>' AS <type>)`.

    `target_type` is used when it is a temporal type. Otherwise DATETIME is chosen for
    `datetime` values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime):
        cast_to = exp.DType.DATETIME
    else:
        cast_to = exp.DType.DATE
    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1) -> relativedelta:
    """Return a `relativedelta` spanning `n` of `unit` (a quarter counts as 3 months).

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week, day, hour,
            minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, multiplier)
    steps = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, multiplier in steps:
        if unit == name:
            return relativedelta(**{keyword: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: date, unit: str, dialect: Dialect) -> date:
    """Truncate `d` down to the start of the `unit` period that contains it.

    Args:
        d: the date to truncate.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`. Weeks start on the weekday whose `date.weekday()`
            equals WEEK_OFFSET modulo 7 (0 = Monday, so -1 = Sunday).

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Step back 0-6 days to the week start. Without the modulo, a negative offset would
        # move a date that already is the week start back a full week (e.g. a Sunday with
        # WEEK_OFFSET=-1), and a positive offset could move dates forward.
        return d - timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: date, unit: str, dialect: Dialect) -> date:
    """Round `d` up to a `unit` boundary: `d` itself if already on one, else the next boundary."""
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
def boolean_literal(condition: bool | exp.Predicate) -> exp.Boolean:
    """Map the truthiness of `condition` to a TRUE or FALSE literal node."""
    if condition:
        return exp.true()
    return exp.false()
class Simplifier:
 490class Simplifier:
 491    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
 492        self.dialect = Dialect.get_or_raise(dialect)
 493        self.annotate_new_expressions = annotate_new_expressions
 494
 495        self._annotator: TypeAnnotator = TypeAnnotator(
 496            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
 497        )
 498
 499    # Value ranges for byte-sized signed/unsigned integers
 500    TINYINT_MIN: t.ClassVar = -128
 501    TINYINT_MAX: t.ClassVar = 127
 502    UTINYINT_MIN: t.ClassVar = 0
 503    UTINYINT_MAX: t.ClassVar = 255
 504
 505    COMPLEMENT_COMPARISONS: t.ClassVar = {
 506        exp.LT: exp.GTE,
 507        exp.GT: exp.LTE,
 508        exp.LTE: exp.GT,
 509        exp.GTE: exp.LT,
 510        exp.EQ: exp.NEQ,
 511        exp.NEQ: exp.EQ,
 512    }
 513
 514    COMPLEMENT_SUBQUERY_PREDICATES: t.ClassVar = {
 515        exp.All: exp.Any,
 516        exp.Any: exp.All,
 517    }
 518
 519    LT_LTE: t.ClassVar = (exp.LT, exp.LTE)
 520    GT_GTE: t.ClassVar = (exp.GT, exp.GTE)
 521
 522    COMPARISONS: t.ClassVar = (
 523        *LT_LTE,
 524        *GT_GTE,
 525        exp.EQ,
 526        exp.NEQ,
 527        exp.Is,
 528    )
 529
 530    INVERSE_COMPARISONS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
 531        exp.LT: exp.GT,
 532        exp.GT: exp.LT,
 533        exp.LTE: exp.GTE,
 534        exp.GTE: exp.LTE,
 535    }
 536
 537    NONDETERMINISTIC: t.ClassVar = (exp.Rand, exp.Randn)
 538    AND_OR: t.ClassVar = (exp.And, exp.Or)
 539
 540    INVERSE_DATE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
 541        exp.DateAdd: exp.Sub,
 542        exp.DateSub: exp.Add,
 543        exp.DatetimeAdd: exp.Sub,
 544        exp.DatetimeSub: exp.Add,
 545    }
 546
 547    INVERSE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
 548        **INVERSE_DATE_OPS,
 549        exp.Add: exp.Sub,
 550        exp.Sub: exp.Add,
 551    }
 552
 553    NULL_OK: t.ClassVar = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 554
 555    CONCATS: t.ClassVar = (exp.Concat, exp.DPipe)
 556
 557    DATETRUNC_BINARY_COMPARISONS: t.ClassVar[dict[type[exp.Expr], DateTruncBinaryTransform]] = {
 558        exp.LT: lambda l, dt, u, d, t: (
 559            l
 560            < date_literal(
 561                dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t
 562            )
 563        ),
 564        exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 565        exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 566        exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 567        exp.EQ: _datetrunc_eq,
 568        exp.NEQ: _datetrunc_neq,
 569    }
 570
 571    DATETRUNC_COMPARISONS: t.ClassVar = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 572    DATETRUNCS: t.ClassVar = (exp.DateTrunc, exp.TimestampTrunc)
 573
 574    SAFE_CONNECTOR_ELIMINATION_RESULT: t.ClassVar = (exp.Connector, exp.Boolean)
 575
 576    # CROSS joins result in an empty table if the right table is empty.
 577    # So we can only simplify certain types of joins to CROSS.
 578    # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 579    JOINS: t.ClassVar = {
 580        ("", ""),
 581        ("", "INNER"),
 582        ("RIGHT", ""),
 583        ("RIGHT", "OUTER"),
 584    }
 585
 586    def simplify(
 587        self,
 588        expression: exp.Expr,
 589        constant_propagation: bool = False,
 590        coalesce_simplification: bool = False,
 591    ) -> exp.Expr:
 592        wheres: list[exp.Where] = []
 593        joins: list[exp.Join] = []
 594
 595        for node in expression.walk(
 596            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
 597        ):
 598            if node.meta.get(FINAL):
 599                continue
 600
 601            # group by expressions cannot be simplified, for example
 602            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 603            # the projection must exactly match the group by key
 604            group = node.args.get("group")
 605
 606            if group and hasattr(node, "selects"):
 607                groups = set(group.expressions)
 608                group.meta[FINAL] = True
 609
 610                for s in node.selects:
 611                    for n in s.walk():
 612                        if n in groups:
 613                            s.meta[FINAL] = True
 614                            break
 615
 616                having = node.args.get("having")
 617
 618                if having:
 619                    for n in having.walk():
 620                        if n in groups:
 621                            having.meta[FINAL] = True
 622                            break
 623
 624            if isinstance(node, exp.Condition):
 625                simplified = while_changing(
 626                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
 627                )
 628
 629                if node is expression:
 630                    expression = simplified
 631            elif isinstance(node, exp.Where):
 632                wheres.append(node)
 633            elif isinstance(node, exp.Join):
 634                # snowflake match_conditions have very strict ordering rules
 635                if match := node.args.get("match_condition"):
 636                    match.meta[FINAL] = True
 637
 638                joins.append(node)
 639
 640        for where in wheres:
 641            if always_true(where.this):
 642                where.pop()
 643        for join in joins:
 644            if (
 645                always_true(join.args.get("on"))
 646                and not join.args.get("using")
 647                and not join.args.get("method")
 648                and (join.side, join.kind) in self.JOINS
 649            ):
 650                join.args["on"].pop()
 651                join.set("side", None)
 652                join.set("kind", "CROSS")
 653
 654        return expression
 655
 656    def _simplify(
 657        self, expression: exp.Expr, constant_propagation: bool, coalesce_simplification: bool
 658    ):
 659        pre_transformation_stack = [expression]
 660        post_transformation_stack = []
 661
 662        while pre_transformation_stack:
 663            original = pre_transformation_stack.pop()
 664            node = original
 665
 666            if not isinstance(node, SIMPLIFIABLE):
 667                if isinstance(node, exp.Query):
 668                    self.simplify(node, constant_propagation, coalesce_simplification)
 669                continue
 670
 671            parent = node.parent
 672            root = node is expression
 673
 674            node = self.rewrite_between(node)
 675            node = self.uniq_sort(node, root)
 676            node = self.absorb_and_eliminate(node, root)
 677            node = self.simplify_concat(node)
 678            node = self.simplify_conditionals(node)
 679
 680            if constant_propagation:
 681                node = propagate_constants(node, root)
 682
 683            if node is not original:
 684                original.replace(node)
 685
 686            for n in node.iter_expressions(reverse=True):
 687                if n.meta.get(FINAL):
 688                    raise
 689            pre_transformation_stack.extend(
 690                n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
 691            )
 692            post_transformation_stack.append((node, parent))
 693
 694        while post_transformation_stack:
 695            original, parent = post_transformation_stack.pop()
 696            root = original is expression
 697
 698            # Resets parent, arg_key, index pointers– this is needed because some of the
 699            # previous transformations mutate the AST, leading to an inconsistent state
 700            for k, v in tuple(original.args.items()):
 701                original.set(k, v)
 702
 703            # Post-order transformations
 704            node = self.simplify_not(original)
 705            node = flatten(node)
 706            node = self.simplify_connectors(node, root)
 707            node = self.remove_complements(node, root)
 708
 709            if coalesce_simplification:
 710                node = self.simplify_coalesce(node)
 711            node.parent = parent
 712
 713            node = self.simplify_literals(node, root)
 714            node = self.simplify_equality(node)
 715            node = simplify_parens(node, dialect=self.dialect)
 716            node = self.simplify_datetrunc(node)
 717            node = self.sort_comparison(node)
 718            node = self.simplify_startswith(node)
 719
 720            if node is not original:
 721                original.replace(node)
 722
 723        return node
 724
 725    @annotate_types_on_change
 726    def rewrite_between(self, expression: exp.Expr) -> exp.Expr:
 727        """Rewrite x between y and z to x >= y AND x <= z.
 728
 729        This is done because comparison simplification is only done on lt/lte/gt/gte.
 730        """
 731        if isinstance(expression, exp.Between):
 732            negate = isinstance(expression.parent, exp.Not)
 733
 734            expression = exp.and_(
 735                exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 736                exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 737                copy=False,
 738            )
 739
 740            if negate:
 741                expression = exp.paren(expression, copy=False)
 742
 743        return expression
 744
 745    @annotate_types_on_change
 746    def simplify_not(self, expression: exp.Expr) -> exp.Expr:
 747        """
 748        Demorgan's Law
 749        NOT (x OR y) -> NOT x AND NOT y
 750        NOT (x AND y) -> NOT x OR NOT y
 751        """
 752        if isinstance(expression, exp.Not):
 753            this = expression.this
 754            if is_null(this):
 755                return exp.and_(exp.null(), exp.true(), copy=False)
 756            if this.__class__ in self.COMPLEMENT_COMPARISONS:
 757                right = this.expression
 758                complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
 759                    right.__class__
 760                )
 761                if complement_subquery_predicate:
 762                    right = complement_subquery_predicate(this=right.this)
 763
 764                return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
 765            if isinstance(this, exp.Paren):
 766                condition = this.unnest()
 767                if isinstance(condition, exp.And):
 768                    return exp.paren(
 769                        exp.or_(
 770                            exp.not_(condition.left, copy=False),
 771                            exp.not_(condition.right, copy=False),
 772                            copy=False,
 773                        ),
 774                        copy=False,
 775                    )
 776                if isinstance(condition, exp.Or):
 777                    return exp.paren(
 778                        exp.and_(
 779                            exp.not_(condition.left, copy=False),
 780                            exp.not_(condition.right, copy=False),
 781                            copy=False,
 782                        ),
 783                        copy=False,
 784                    )
 785                if is_null(condition):
 786                    return exp.and_(exp.null(), exp.true(), copy=False)
 787            if always_true(this):
 788                return exp.false()
 789            if is_false(this):
 790                return exp.true()
 791            if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
 792                inner = this.this
 793                if inner.is_type(exp.DType.BOOLEAN):
 794                    # double negation
 795                    # NOT NOT x -> x, if x is BOOLEAN type
 796                    return inner
 797        return expression
 798
 799    @annotate_types_on_change
 800    def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
 801        def _simplify_connectors(expression: exp.Expr, left: exp.Expr, right: exp.Expr):
 802            if isinstance(expression, exp.And):
 803                if is_false(left) or is_false(right):
 804                    return exp.false()
 805                if is_zero(left) or is_zero(right):
 806                    return exp.false()
 807                if (
 808                    (is_null(left) and is_null(right))
 809                    or (is_null(left) and always_true(right))
 810                    or (always_true(left) and is_null(right))
 811                ):
 812                    return exp.null()
 813                if always_true(left) and always_true(right):
 814                    return exp.true()
 815                if always_true(left):
 816                    return right
 817                if always_true(right):
 818                    return left
 819                return self._simplify_comparison(expression, left, right)
 820            elif isinstance(expression, exp.Or):
 821                if always_true(left) or always_true(right):
 822                    return exp.true()
 823                if (
 824                    (is_null(left) and is_null(right))
 825                    or (is_null(left) and always_false(right))
 826                    or (always_false(left) and is_null(right))
 827                ):
 828                    return exp.null()
 829                if is_false(left):
 830                    return right
 831                if is_false(right):
 832                    return left
 833                return self._simplify_comparison(expression, left, right, or_=True)
 834
 835        if isinstance(expression, exp.Connector):
 836            original_parent = expression.parent
 837            expression = self._flat_simplify(expression, _simplify_connectors, root)
 838
 839            # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
 840            # to ensure that the resulting type is boolean. We know this is true only for connectors,
 841            # boolean values and columns that are essentially operands to a connector:
 842            #
 843            # A AND (((B)))
 844            #          ~ this is safe to keep because it will eventually be part of another connector
 845            if not isinstance(
 846                expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
 847            ) and not expression.is_type(exp.DType.BOOLEAN):
 848                while True:
 849                    if isinstance(original_parent, exp.Connector):
 850                        break
 851                    if not isinstance(original_parent, exp.Paren):
 852                        expression = expression.and_(exp.true(), copy=False)
 853                        break
 854
 855                    original_parent = original_parent.parent
 856
 857        return expression
 858
    @annotate_types_on_change
    def _simplify_comparison(
        self, expression: exp.Expr, left: exp.Expr, right: exp.Expr, or_: bool = False
    ) -> exp.Expr | None:
        """
        Merge two comparisons that share a non-constant, deterministic operand.

        `left` and `right` are two operands of an AND (or of an OR when `or_` is True).
        When both are comparisons in `self.COMPARISONS` and they share an operand that is
        not a constant and contains no NONDETERMINISTIC call, the two remaining operands are
        turned into Python numbers, strings or dates and compared, e.g.:

            x < 3 AND x < 5   -> x < 3
            x < 3 OR x < 5    -> x < 5
            x < 3 AND x >= 3  -> FALSE
            x = 5 AND x > 1   -> x = 5

        Args:
            expression: the connector being simplified (returned as-is in one edge case).
            left: first comparison operand of the connector.
            right: second comparison operand of the connector.
            or_: True when the connector is an OR, False for AND.

        Returns:
            The surviving comparison, `exp.false()`, `expression` itself when both operands
            of a comparison are the shared one, or None when no rule applies.
        """
        if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
            # NOTE(review): assumes every comparison node has exactly two args (this, expression)
            ll, lr = left.args.values()
            rl, rr = right.args.values()

            largs = {ll, lr}
            rargs = {rl, rr}

            # Operands that appear in both comparisons; only deterministic non-constants count
            matching = largs & rargs
            columns = {
                m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
            }

            if matching and columns:
                try:
                    # The "other" operand of each comparison, i.e. the value the shared
                    # operand is being compared against
                    l = first(largs - columns)
                    r = first(rargs - columns)
                except StopIteration:
                    # A comparison uses the shared operand on both sides (e.g. x = x)
                    return expression

                if l.is_number and r.is_number:
                    l = l.to_py()
                    r = r.to_py()
                elif l.is_string and r.is_string:
                    l = l.name
                    r = r.name
                else:
                    l = extract_date(l)
                    if not l:
                        return None
                    r = extract_date(r)
                    if not r:
                        return None
                    # python won't compare date and datetime, but many engines will upcast
                    l, r = cast_as_datetime(l), cast_as_datetime(r)

                # Try both orders so each rule only needs to be written for one of them.
                # NOTE(review): the LT/GT reasoning assumes the shared operand is the left side
                # of both comparisons (e.g. after sort_comparison) — confirm callers ensure this.
                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                    # Same direction: AND keeps the tighter bound, OR keeps the looser one
                    if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
                        return left if (av > bv if or_ else av <= bv) else right
                    if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
                        return left if (av < bv if or_ else av >= bv) else right

                    # we can't ever shortcut to true because the column could be null
                    if not or_:
                        # Disjoint ranges: the AND can never hold
                        if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
                            if av <= bv:
                                return exp.false()
                        elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
                            if av >= bv:
                                return exp.false()
                        elif isinstance(a, exp.EQ):
                            # An equality either contradicts the other comparison (FALSE) or
                            # already implies it (keep just the equality)
                            if isinstance(b, exp.LT):
                                return exp.false() if av >= bv else a
                            if isinstance(b, exp.LTE):
                                return exp.false() if av > bv else a
                            if isinstance(b, exp.GT):
                                return exp.false() if av <= bv else a
                            if isinstance(b, exp.GTE):
                                return exp.false() if av < bv else a
                            if isinstance(b, exp.NEQ):
                                return exp.false() if av == bv else a
        return None
 924
 925    @annotate_types_on_change
 926    def remove_complements(self, expression: object, root: bool = True) -> object:
 927        """
 928        Removing complements.
 929
 930        A AND NOT A -> FALSE (only for non-NULL A)
 931        A OR NOT A -> TRUE (only for non-NULL A)
 932        """
 933        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
 934            ops = set(expression.flatten())
 935            for op in ops:
 936                if isinstance(op, exp.Not) and op.this in ops:
 937                    if expression.meta.get("nonnull") is True:
 938                        return exp.false() if isinstance(expression, exp.And) else exp.true()
 939
 940        return expression
 941
 942    @annotate_types_on_change
 943    def uniq_sort(self, expression: object, root: bool = True) -> object:
 944        """
 945        Uniq and sort a connector.
 946
 947        C AND A AND B AND B -> A AND B AND C
 948        """
 949        if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 950            flattened = tuple(expression.flatten())
 951
 952            if isinstance(expression, exp.Xor):
 953                result_func = exp.xor
 954                # Do not deduplicate XOR as A XOR A != A if A == True
 955                deduped = None
 956                arr = tuple((gen(e), e) for e in flattened)
 957            else:
 958                result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 959                deduped = {gen(e): e for e in flattened}
 960                arr = tuple(deduped.items())
 961
 962            # check if the operands are already sorted, if not sort them
 963            # A AND C AND B -> A AND B AND C
 964            for i, (sql, e) in enumerate(arr[1:]):
 965                if sql < arr[i][0]:
 966                    expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 967                    break
 968            else:
 969                # we didn't have to sort but maybe we need to dedup
 970                if deduped and len(deduped) < len(flattened):
 971                    unique_operand = flattened[0]
 972                    if len(deduped) == 1:
 973                        expression = unique_operand.and_(exp.true(), copy=False)
 974                    else:
 975                        expression = result_func(*deduped.values(), copy=False)
 976
 977        return expression
 978
    @annotate_types_on_change
    def absorb_and_eliminate(self, expression: object, root: bool = True) -> object:
        """
        absorption:
            A AND (A OR B) -> A
            A OR (A AND B) -> A
            A AND (NOT A OR B) -> A AND B
            A OR (NOT A AND B) -> A OR B
        elimination:
            (A AND B) OR (A AND NOT B) -> A
            (A OR B) AND (A OR NOT B) -> A

        The rewrites are applied in place. Operands are replaced with TRUE/FALSE, or with
        the shared operand for elimination. The resulting TRUE/FALSE/duplicate operands are
        presumably cleaned up by other passes. `expression` itself is always returned.

        Args:
            expression: the node to process; only AND/OR nodes are touched.
            root: when False, the node is only processed if `expression.same_parent` is false
                (presumably: it is not a link inside a chain handled from its parent).
        """
        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
            # The operator of the nested operands: OR inside an AND, AND inside an OR
            kind = exp.Or if isinstance(expression, exp.And) else exp.And

            ops = tuple(expression.flatten())

            # Initialize lookup tables:
            # Set of all operands, used to find complements for absorption.
            op_set = set()
            # Sub-operands, used to find subsets for absorption.
            subops = defaultdict(list)
            # Pairs of complements, used for elimination.
            pairs = defaultdict(list)

            # Populate the lookup tables
            for op in ops:
                op_set.add(op)

                if not isinstance(op, kind):
                    # In cases like: A OR (A AND B)
                    # Subop will be: ^
                    subops[op].append({op})
                    continue

                # In cases like: (A AND B) OR (A AND B AND C)
                # Subops will be: ^     ^
                subset = set(op.flatten())
                for i in subset:
                    subops[i].append(subset)

                # Only the two direct operands are considered here. For (A AND NOT B) this
                # records key {A, B} -> (op, A): an op over exactly {A, B} can later be
                # merged with this one into A.
                a, b = op.unnest_operands()
                if isinstance(a, exp.Not):
                    pairs[frozenset((a.this, b))].append((op, b))
                if isinstance(b, exp.Not):
                    pairs[frozenset((a, b.this))].append((op, a))

            for op in ops:
                if not isinstance(op, kind):
                    continue

                a, b = op.unnest_operands()

                # Absorb
                # A OR (NOT A AND B): NOT A becomes the neutral element of `kind`
                # (TRUE for AND, FALSE for OR), leaving A OR (TRUE AND B)
                if isinstance(a, exp.Not) and a.this in op_set:
                    a.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                if isinstance(b, exp.Not) and b.this in op_set:
                    b.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                # A OR (A AND B): some operand set is a strict subset of this op's operands,
                # so this op is redundant and becomes the neutral element of the outer
                # connector (FALSE inside an OR, TRUE inside an AND)
                superset = set(op.flatten())
                if any(any(subset < superset for subset in subops[i]) for i in superset):
                    op.replace(exp.false() if kind == exp.And else exp.true())
                    continue

                # Eliminate
                # (A AND B) OR (A AND NOT B): both ops are replaced by the shared operand A
                for other, complement in pairs[frozenset((a, b))]:
                    op.replace(complement)
                    other.replace(complement)

        return expression
1050
1051    @annotate_types_on_change
1052    @catch(ModuleNotFoundError, UnsupportedUnit)
1053    def simplify_equality(self, expression: exp.Expr) -> exp.Expr:
1054        """
1055        Use the subtraction and addition properties of equality to simplify expressions:
1056
1057            x + 1 = 3 becomes x = 2
1058
1059        There are two binary operations in the above expression: + and =
1060        Here's how we reference all the operands in the code below:
1061
1062            l     r
1063            x + 1 = 3
1064            a   b
1065        """
1066        if isinstance(expression, self.COMPARISONS):
1067            l, r = expression.left, expression.right
1068
1069            if l.__class__ not in self.INVERSE_OPS:
1070                return expression
1071
1072            if r.is_number:
1073                a_predicate = _is_number
1074                b_predicate = _is_number
1075            elif _is_date_literal(r):
1076                a_predicate = _is_date_literal
1077                b_predicate = _is_interval
1078            else:
1079                return expression
1080
1081            if l.__class__ in self.INVERSE_DATE_OPS:
1082                l = t.cast(exp.IntervalOp, l)
1083                a: exp.Expr = l.this
1084                b: exp.Expr = l.interval()
1085            else:
1086                l = t.cast(exp.Binary, l)
1087                a, b = l.left, l.right
1088
1089            if not a_predicate(a) and b_predicate(b):
1090                pass
1091            elif not a_predicate(b) and b_predicate(a):
1092                a, b = b, a
1093            else:
1094                return expression
1095
1096            return expression.__class__(
1097                this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
1098            )
1099        return expression
1100
1101    def _is_inverse_date_op(self, expression: exp.Expr) -> TypeIs[exp.IntervalOp]:
1102        return type(expression) in self.INVERSE_DATE_OPS
1103
1104    @annotate_types_on_change
1105    def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
1106        if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
1107            return self._flat_simplify(expression, self._simplify_binary, root)
1108
1109        if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
1110            return expression.this.this
1111
1112        if self._is_inverse_date_op(expression):
1113            return (
1114                self._simplify_binary(expression, expression.this, expression.interval())
1115                or expression
1116            )
1117
1118        return expression
1119
1120    def _simplify_integer_cast(self, expr: t.Any) -> t.Any:
1121        if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
1122            this = self._simplify_integer_cast(expr.this)
1123        else:
1124            this = expr.this
1125
1126        if isinstance(expr, exp.Cast) and this.is_int:
1127            num = this.to_py()
1128
1129            # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
1130            # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
1131            # engine-dependent
1132            if (
1133                self.TINYINT_MIN <= num <= self.TINYINT_MAX
1134                and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
1135            ) or (
1136                self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
1137                and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
1138            ):
1139                return this
1140
1141        return expr
1142
1143    def _simplify_binary(self, expression, a, b):
1144        if isinstance(expression, self.COMPARISONS):
1145            a = self._simplify_integer_cast(a)
1146            b = self._simplify_integer_cast(b)
1147
1148        if isinstance(expression, exp.Is):
1149            if isinstance(b, exp.Not):
1150                c = b.this
1151                not_ = True
1152            else:
1153                c = b
1154                not_ = False
1155
1156            if is_null(c):
1157                if isinstance(a, exp.Literal):
1158                    return exp.true() if not_ else exp.false()
1159                if is_null(a):
1160                    return exp.false() if not_ else exp.true()
1161        elif isinstance(expression, self.NULL_OK):
1162            return None
1163        elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
1164            return exp.null()
1165
1166        if a.is_number and b.is_number:
1167            num_a = a.to_py()
1168            num_b = b.to_py()
1169
1170            if isinstance(expression, exp.Add):
1171                return exp.Literal.number(num_a + num_b)
1172            if isinstance(expression, exp.Mul):
1173                return exp.Literal.number(num_a * num_b)
1174
1175            # We only simplify Sub, Div if a and b have the same parent because they're not associative
1176            if isinstance(expression, exp.Sub):
1177                return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
1178            if isinstance(expression, exp.Div):
1179                # engines have differing int div behavior so intdiv is not safe
1180                if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
1181                    return None
1182                return exp.Literal.number(num_a / num_b)
1183
1184            boolean = eval_boolean(expression, num_a, num_b)
1185
1186            if boolean:
1187                return boolean
1188        elif a.is_string and b.is_string:
1189            boolean = eval_boolean(expression, a.this, b.this)
1190
1191            if boolean:
1192                return boolean
1193        elif _is_date_literal(a) and isinstance(b, exp.Interval):
1194            date, b = extract_date(a), extract_interval(b)
1195            if date and b:
1196                if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
1197                    return date_literal(date + b, extract_type(a))
1198                if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
1199                    return date_literal(date - b, extract_type(a))
1200        elif isinstance(a, exp.Interval) and _is_date_literal(b):
1201            a, date = extract_interval(a), extract_date(b)
1202            # you cannot subtract a date from an interval
1203            if a and b and isinstance(expression, exp.Add):
1204                return date_literal(a + date, extract_type(b))
1205        elif _is_date_literal(a) and _is_date_literal(b):
1206            if isinstance(expression, exp.Predicate):
1207                a, b = extract_date(a), extract_date(b)
1208                boolean = eval_boolean(expression, a, b)
1209                if boolean:
1210                    return boolean
1211
1212        return None
1213
    @annotate_types_on_change
    def simplify_coalesce(self, expression: exp.Expr) -> exp.Expr:
        """
        Remove trivial COALESCE calls and expand COALESCE-vs-constant comparisons.

        1. COALESCE(x), or a COALESCE whose first argument satisfies `_is_nonnull_constant`,
           is replaced by its first argument. This is skipped inside a Hint, where COALESCE
           is a Spark partitioning hint.
        2. Unless the dialect sets COALESCE_COMPARISON_NON_STANDARD, a comparison between a
           COALESCE and a constant is expanded, e.g.

               COALESCE(x, 1) = 2
               -> (x IS NOT NULL AND COALESCE(x) = 2) OR (x IS NULL AND 1 = 2)

           The result is meant to be simplified further by later passes.

        Note: step 2 mutates the original COALESCE node in place, truncating its argument
        list at the first constant argument, before the comparison is copied.

        Returns:
            The replacement expression, or `expression` unchanged.
        """
        # COALESCE(x) -> x
        if (
            isinstance(expression, exp.Coalesce)
            and (not expression.expressions or _is_nonnull_constant(expression.this))
            # COALESCE is also used as a Spark partitioning hint
            and not isinstance(expression.parent, exp.Hint)
        ):
            return expression.this

        if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
            return expression

        if not isinstance(expression, self.COMPARISONS):
            return expression

        if isinstance(expression.left, exp.Coalesce):
            coalesce = expression.left
            other = expression.right
        elif isinstance(expression.right, exp.Coalesce):
            coalesce = expression.right
            other = expression.left
        else:
            return expression

        # This transformation is valid for non-constants,
        # but it really only does anything if they are both constants.
        if not _is_constant(other):
            return expression

        # Find the first constant arg
        # (only `expressions` is scanned, i.e. the arguments after the first one)
        for arg_index, arg in enumerate(coalesce.expressions):
            if _is_constant(arg):
                break
        else:
            return expression

        # In-place: keep only the arguments before the first constant one
        coalesce.set("expressions", coalesce.expressions[:arg_index])

        # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
        # since we already remove COALESCE at the top of this function.
        this: exp.Expr = coalesce if coalesce.expressions else coalesce.this

        # This expression is more complex than when we started, but it will get simplified further
        return exp.paren(
            exp.or_(
                exp.and_(
                    this.is_(exp.null()).not_(copy=False),
                    expression.copy(),
                    copy=False,
                ),
                exp.and_(
                    this.is_(exp.null()),
                    type(expression)(this=arg.copy(), expression=other.copy()),
                    copy=False,
                ),
                copy=False,
            ),
            copy=False,
        )
1275
1276    @annotate_types_on_change
1277    def simplify_concat(self, expression):
1278        """Reduces all groups that contain string literals by concatenating them."""
1279        if not isinstance(expression, self.CONCATS) or (
1280            # We can't reduce a CONCAT_WS call if we don't statically know the separator
1281            isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
1282        ):
1283            return expression
1284
1285        if isinstance(expression, exp.ConcatWs):
1286            sep_expr, *expressions = expression.expressions
1287            sep = sep_expr.name
1288            concat_type = exp.ConcatWs
1289            args = {}
1290        else:
1291            expressions = expression.expressions
1292            sep = ""
1293            concat_type = exp.Concat
1294            args = {
1295                "safe": expression.args.get("safe"),
1296                "coalesce": expression.args.get("coalesce"),
1297            }
1298
1299        new_args = []
1300        for is_string_group, group in itertools.groupby(
1301            expressions or expression.flatten(), lambda e: e.is_string
1302        ):
1303            if is_string_group:
1304                new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
1305            else:
1306                new_args.extend(group)
1307
1308        if len(new_args) == 1 and new_args[0].is_string:
1309            return new_args[0]
1310
1311        if concat_type is exp.ConcatWs:
1312            new_args = [sep_expr] + new_args
1313        elif isinstance(expression, exp.DPipe):
1314            return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)
1315
1316        return concat_type(expressions=new_args, **args)
1317
1318    @annotate_types_on_change
1319    def simplify_conditionals(self, expression):
1320        """Simplifies expressions like IF, CASE if their condition is statically known."""
1321        if isinstance(expression, exp.Case):
1322            this = expression.this
1323            for case in expression.args["ifs"]:
1324                cond = case.this
1325                if this:
1326                    # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
1327                    cond = cond.replace(this.pop().eq(cond))
1328
1329                if always_true(cond):
1330                    return case.args["true"]
1331
1332                if always_false(cond):
1333                    case.pop()
1334                    if not expression.args["ifs"]:
1335                        return expression.args.get("default") or exp.null()
1336        elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
1337            if always_true(expression.this):
1338                return expression.args["true"]
1339            if always_false(expression.this):
1340                return expression.args.get("false") or exp.null()
1341
1342        return expression
1343
1344    @annotate_types_on_change
1345    def simplify_startswith(self, expression: exp.Expr) -> exp.Expr:
1346        """
1347        Reduces a prefix check to either TRUE or FALSE if both the string and the
1348        prefix are statically known.
1349
1350        Example:
1351            >>> from sqlglot import parse_one
1352            >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
1353            'TRUE'
1354        """
1355        if (
1356            isinstance(expression, exp.StartsWith)
1357            and expression.this.is_string
1358            and expression.expression.is_string
1359        ):
1360            return exp.convert(expression.name.startswith(expression.expression.name))
1361
1362        return expression
1363
1364    def _is_datetrunc_predicate(self, left: exp.Expr, right: exp.Expr) -> bool:
1365        return isinstance(left, self.DATETRUNCS) and _is_date_literal(right)
1366
    @annotate_types_on_change
    @catch(ModuleNotFoundError, UnsupportedUnit)
    def simplify_datetrunc(self, expression: exp.Expr) -> exp.Expr:
        """
        Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

        Three shapes are handled:
        - a DATETRUNCS call on a date literal is folded into the truncated literal
          (via `date_floor`);
        - `<DateTrunc/TimestampTrunc> <cmp> <date literal>`, for comparisons in
          DATETRUNC_COMPARISONS, is delegated to DATETRUNC_BINARY_COMPARISONS[<cmp>];
        - `<DateTrunc/TimestampTrunc> IN (<date literals>)`:
          - each literal is mapped to a range with `_datetrunc_range`, and literals that
            yield no range are dropped;
          - the ranges are merged with `merge_ranges`;
          - the IN becomes an OR of `_datetrunc_eq_expression` checks.

        Anything else is returned unchanged. The method is wrapped by
        `catch(ModuleNotFoundError, UnsupportedUnit)`, which is defined elsewhere in the module.
        """
        comparison = expression.__class__

        if isinstance(expression, self.DATETRUNCS):
            # Truncation of a literal: compute the truncated date directly
            this = expression.this
            trunc_type = extract_type(this)
            date = extract_date(this)
            if date and expression.unit:
                return date_literal(
                    date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type
                )
        elif comparison not in self.DATETRUNC_COMPARISONS:
            return expression

        if isinstance(expression, exp.Binary):
            l, r = expression.left, expression.right

            if not self._is_datetrunc_predicate(l, r) or not isinstance(
                l, (exp.DateTrunc, exp.TimestampTrunc)
            ):
                return expression

            trunc_arg = l.this
            unit = l.args["unit"].name.lower()
            date = extract_date(r)

            if not date:
                return expression

            # Per-comparison rewrite; a falsy result means it could not be rewritten
            return (
                self.DATETRUNC_BINARY_COMPARISONS[comparison](
                    trunc_arg, date, unit, self.dialect, extract_type(r)
                )
                or expression
            )

        if isinstance(expression, exp.In):
            l = expression.this
            rs = expression.expressions

            if (
                rs
                and all(self._is_datetrunc_predicate(l, r) for r in rs)
                and isinstance(l, (exp.DateTrunc, exp.TimestampTrunc))
            ):
                unit = l.args["unit"].name.lower()

                ranges = []
                for r in rs:
                    date = extract_date(r)
                    if not date:
                        return expression
                    drange = _datetrunc_range(date, unit, self.dialect)
                    if drange:
                        ranges.append(drange)

                if not ranges:
                    return expression

                ranges = merge_ranges(ranges)
                target_type = extract_type(*rs)

                return exp.or_(
                    *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges],
                    copy=False,
                )

        return expression
1438
1439    @annotate_types_on_change
1440    def sort_comparison(self, expression: exp.Expr) -> exp.Expr:
1441        if expression.__class__ in self.COMPLEMENT_COMPARISONS:
1442            l, r = expression.this, expression.expression
1443            l_column = isinstance(l, exp.Column)
1444            r_column = isinstance(r, exp.Column)
1445            l_const = _is_constant(l)
1446            r_const = _is_constant(r)
1447
1448            if (
1449                (l_column and not r_column)
1450                or (r_const and not l_const)
1451                or isinstance(r, exp.SubqueryPredicate)
1452            ):
1453                return expression
1454            if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1455                return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1456                    this=r, expression=l
1457                )
1458        return expression
1459
    def _flat_simplify(
        self,
        expression: exp.Expr,
        simplifier: t.Callable[[exp.Expr, exp.Expr, exp.Expr], exp.Expr | None],
        root: bool = True,
    ) -> exp.Expr:
        """
        Apply a pairwise `simplifier` across a flattened chain of one binary operator.

        The operands of `expression` (e.g. the terms of a + b + c) are queued. Each operand
        is tried against every operand still in the queue. When `simplifier` produces a new
        node, the partner is removed and the result is pushed to the front, so it can
        combine further. Operands that combine with nothing are kept.

        Args:
            expression: the binary node heading the chain.
            simplifier: called as simplifier(expression, a, b); it returns a replacement node,
                or a falsy value / `expression` itself when `a` and `b` don't combine.
            root: when False, the chain is only processed if `expression.same_parent` is false.

        Returns:
            A rebuilt left-deep chain of `expression`'s class if any operands were combined,
            otherwise `expression` unchanged.
        """
        if root or not expression.same_parent:
            operands = []
            queue = deque(expression.flatten(unnest=False))
            size = len(queue)

            while queue:
                a = queue.popleft()

                for b in queue:
                    result = simplifier(expression, a, b)

                    if result and result is not expression:
                        # NOTE(review): deque.remove matches by equality, not identity; if a
                        # structurally equal operand precedes `b`, that one is removed instead.
                        # Mutating the deque is safe only because we break right after.
                        queue.remove(b)
                        queue.appendleft(result)
                        break
                else:
                    # `a` combined with nothing: it survives as-is
                    operands.append(a)

            # Fewer operands than we started with means at least one pair was combined
            if len(operands) < size:
                return functools.reduce(
                    lambda a, b: expression.__class__(this=a, expression=b), operands
                )
        return expression
Simplifier( dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None, annotate_new_expressions: bool = True)
491    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
492        self.dialect = Dialect.get_or_raise(dialect)
493        self.annotate_new_expressions = annotate_new_expressions
494
495        self._annotator: TypeAnnotator = TypeAnnotator(
496            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
497        )
dialect
annotate_new_expressions
TINYINT_MIN: ClassVar = -128
TINYINT_MAX: ClassVar = 127
UTINYINT_MIN: ClassVar = 0
UTINYINT_MAX: ClassVar = 255
COMPLEMENT_SUBQUERY_PREDICATES: ClassVar = {<class 'sqlglot.expressions.core.All'>: <class 'sqlglot.expressions.core.Any'>, <class 'sqlglot.expressions.core.Any'>: <class 'sqlglot.expressions.core.All'>}
LT_LTE: ClassVar = (<class 'sqlglot.expressions.core.LT'>, <class 'sqlglot.expressions.core.LTE'>)
GT_GTE: ClassVar = (<class 'sqlglot.expressions.core.GT'>, <class 'sqlglot.expressions.core.GTE'>)
NONDETERMINISTIC: ClassVar = (<class 'sqlglot.expressions.functions.Rand'>, <class 'sqlglot.expressions.functions.Randn'>)
AND_OR: ClassVar = (<class 'sqlglot.expressions.core.And'>, <class 'sqlglot.expressions.core.Or'>)
CONCATS: ClassVar = (<class 'sqlglot.expressions.string.Concat'>, <class 'sqlglot.expressions.core.DPipe'>)
DATETRUNC_BINARY_COMPARISONS: ClassVar[dict[type[sqlglot.expressions.core.Expr], Callable[[sqlglot.expressions.core.Expr, datetime.date, str, sqlglot.dialects.Dialect, sqlglot.expressions.datatypes.DataType], Optional[sqlglot.expressions.core.Expr]]]] = {<class 'sqlglot.expressions.core.LT'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.core.GT'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.core.LTE'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.core.GTE'>: <function Simplifier.<lambda>>, <class 'sqlglot.expressions.core.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.core.NEQ'>: <function _datetrunc_neq>}
SAFE_CONNECTOR_ELIMINATION_RESULT: ClassVar = (<class 'sqlglot.expressions.core.Connector'>, <class 'sqlglot.expressions.core.Boolean'>)
JOINS: ClassVar = {('', 'INNER'), ('RIGHT', 'OUTER'), ('RIGHT', ''), ('', '')}
def simplify( self, expression: sqlglot.expressions.core.Expr, constant_propagation: bool = False, coalesce_simplification: bool = False) -> sqlglot.expressions.core.Expr:
    def simplify(
        self,
        expression: exp.Expr,
        constant_propagation: bool = False,
        coalesce_simplification: bool = False,
    ) -> exp.Expr:
        """
        Simplify every condition in `expression` and clean up trivial WHERE / JOIN clauses.

        The tree is walked without descending into conditions (each one is simplified as a
        whole) or into nodes whose meta is marked FINAL. During the walk:
        - GROUP BY expressions are marked FINAL, together with any projection or HAVING
          clause that contains a grouped expression, so they keep matching the GROUP BY;
        - join match_conditions are marked FINAL;
        - each condition is rewritten with `self._simplify` until it stops changing.

        Afterwards:
        - WHERE clauses whose condition is always true are removed;
        - a join whose ON is always true is turned into a CROSS join when it has no USING
          and no method, and its (side, kind) is in JOINS.

        Args:
            expression: the tree to simplify; it is modified in place.
            constant_propagation: forwarded to `self._simplify`.
            coalesce_simplification: forwarded to `self._simplify`.

        Returns:
            The simplified tree. If `expression` itself is a condition, this is the
            simplified replacement node.
        """
        wheres: list[exp.Where] = []
        joins: list[exp.Join] = []

        # Conditions and FINAL nodes are pruned: we don't walk into them
        for node in expression.walk(
            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
        ):
            if node.meta.get(FINAL):
                continue

            # group by expressions cannot be simplified, for example
            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
            # the projection must exactly match the group by key
            group = node.args.get("group")

            if group and hasattr(node, "selects"):
                groups = set(group.expressions)
                group.meta[FINAL] = True

                for s in node.selects:
                    for n in s.walk():
                        if n in groups:
                            s.meta[FINAL] = True
                            break

                having = node.args.get("having")

                if having:
                    for n in having.walk():
                        if n in groups:
                            having.meta[FINAL] = True
                            break

            if isinstance(node, exp.Condition):
                # Repeat `_simplify` until a fixed point is reached
                simplified = while_changing(
                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
                )

                # Only the root needs re-binding; nested nodes are replaced within the tree
                if node is expression:
                    expression = simplified
            elif isinstance(node, exp.Where):
                wheres.append(node)
            elif isinstance(node, exp.Join):
                # snowflake match_conditions have very strict ordering rules
                if match := node.args.get("match_condition"):
                    match.meta[FINAL] = True

                joins.append(node)

        # Clean-up happens after the walk so the tree is not restructured while iterating it
        for where in wheres:
            if always_true(where.this):
                where.pop()
        for join in joins:
            if (
                always_true(join.args.get("on"))
                and not join.args.get("using")
                and not join.args.get("method")
                and (join.side, join.kind) in self.JOINS
            ):
                join.args["on"].pop()
                join.set("side", None)
                join.set("kind", "CROSS")

        return expression
@annotate_types_on_change
def rewrite_between(self, expression: exp.Expr) -> exp.Expr:
    """Expand ``x BETWEEN low AND high`` into ``x >= low AND x <= high``.

    Comparison simplification only understands LT/LTE/GT/GTE, so BETWEEN is
    rewritten into that shape first. When the BETWEEN is the direct child of a
    NOT, the resulting conjunction is wrapped in parentheses so the negation
    still covers the whole range check.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    subject = expression.this

    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if under_not else rewritten

Rewrite `x BETWEEN y AND z` to `x >= y AND x <= z`.

This is done because comparison simplification only handles LT/LTE/GT/GTE.

@annotate_types_on_change
def simplify_not(self, expression: exp.Expr) -> exp.Expr:
    """
    Push a NOT inward or fold it away where possible (De Morgan's laws):

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    It also complements comparisons via ``COMPLEMENT_COMPARISONS``, folds NOT
    over TRUE/FALSE/NULL operands and, when the dialect allows it, drops a
    double negation of a BOOLEAN-typed operand.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    # NOT NULL is NULL; "NULL AND TRUE" keeps the replacement boolean-shaped.
    if is_null(operand):
        return exp.and_(exp.null(), exp.true(), copy=False)

    if operand.__class__ in self.COMPLEMENT_COMPARISONS:
        rhs = operand.expression
        flipped_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            rhs = flipped_predicate(this=rhs.this)
        complement = self.COMPLEMENT_COMPARISONS[operand.__class__]
        return complement(this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: swap the connector and negate both sides.
            swapped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                swapped(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                ),
                copy=False,
            )
        if is_null(inner):
            return exp.and_(exp.null(), exp.true(), copy=False)

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()

    # NOT NOT x -> x, only for BOOLEAN x and only if the dialect permits it.
    if isinstance(operand, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
        doubly_negated = operand.this
        if doubly_negated.is_type(exp.DType.BOOLEAN):
            return doubly_negated

    return expression

De Morgan's laws:

NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y

@annotate_types_on_change
def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
    """Fold constant operands out of AND / OR connectors.

    Pairwise rules, applied through ``_flat_simplify``:

    AND: FALSE or zero on either side -> FALSE; NULL combined with NULL or
         TRUE -> NULL; TRUE operands are dropped; otherwise defer to
         ``_simplify_comparison``.
    OR:  TRUE on either side -> TRUE; NULL combined with NULL or an
         always-false operand -> NULL; FALSE operands are dropped; otherwise
         defer to ``_simplify_comparison`` with ``or_=True``.

    For any other connector kind the pairwise helper returns None. If the
    connector collapses to something that is not known to be boolean, and it
    is not (through parentheses) an operand of another connector, ``AND TRUE``
    is appended to keep the result boolean.
    """

    def _fold_pair(connector: exp.Expr, left: exp.Expr, right: exp.Expr):
        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
                return exp.false()
            # NULL AND NULL, NULL AND TRUE, TRUE AND NULL -> NULL
            if is_null(left) and (is_null(right) or always_true(right)):
                return exp.null()
            if always_true(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return self._simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            # NULL OR NULL, NULL OR <false>, <false> OR NULL -> NULL
            if is_null(left) and (is_null(right) or always_false(right)):
                return exp.null()
            if always_false(left) and is_null(right):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return self._simplify_comparison(connector, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression

    parent = expression.parent
    expression = self._flat_simplify(expression, _fold_pair, root)

    # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
    # to ensure that the resulting type is boolean. We know this is true only for connectors,
    # boolean values and columns that are essentially operands to a connector:
    #
    # A AND (((B)))
    #          ~ this is safe to keep because it will eventually be part of another connector
    if not isinstance(
        expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
    ) and not expression.is_type(exp.DType.BOOLEAN):
        # Look through any wrapping parentheses to the first real ancestor.
        while isinstance(parent, exp.Paren):
            parent = parent.parent
        if not isinstance(parent, exp.Connector):
            expression = expression.and_(exp.true(), copy=False)

    return expression
@annotate_types_on_change
def remove_complements(self, expression: object, root: bool = True) -> object:
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE (only for non-NULL A)
    A OR NOT A -> TRUE (only for non-NULL A)

    The rewrite only fires when the connector's ``meta["nonnull"]`` is exactly
    ``True``; otherwise the expression is returned unchanged.
    """
    if not (isinstance(expression, self.AND_OR) and (root or not expression.same_parent)):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(op, exp.Not) and op.this in operands for op in operands
    )

    if has_complement and expression.meta.get("nonnull") is True:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Remove complements.

A AND NOT A -> FALSE (only for non-NULL A)
A OR NOT A -> TRUE (only for non-NULL A)

@annotate_types_on_change
def uniq_sort(self, expression: object, root: bool = True) -> object:
    """
    Deduplicate and sort the operands of a connector by their ``gen`` key.

    C AND A AND B AND B -> A AND B AND C

    XOR operands are sorted but never deduplicated, since A XOR A is not A.
    If the operands are already in order, the connector is only rebuilt when
    deduplication removed something; a single surviving operand is returned
    as ``operand AND TRUE``.
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(op), op) for op in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(op): op for op in operands}
        keyed = tuple(unique.items())

    in_order = all(keyed[i][0] <= keyed[i + 1][0] for i in range(len(keyed) - 1))

    if not in_order:
        # A AND C AND B -> A AND B AND C
        return build(*(op for _, op in sorted(keyed)), copy=False)

    # Already sorted; rebuild only if duplicates were dropped.
    if unique and len(unique) < len(operands):
        if len(unique) == 1:
            return operands[0].and_(exp.true(), copy=False)
        return build(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

@annotate_types_on_change
def absorb_and_eliminate(self, expression: object, root: bool = True) -> object:
    """
    Apply the absorption and elimination laws to an AND / OR connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Nested operands are rewritten in place with ``replace``; the (possibly
    mutated) connector itself is returned.
    """
    if not (isinstance(expression, self.AND_OR) and (root or not expression.same_parent)):
        return expression

    # The opposite connector kind, i.e. the kind expected one level down.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And
    # Replacement for a negated operand whose complement is a top-level operand.
    make_neutral = exp.true if inner_kind == exp.And else exp.false
    # Replacement for a nested connector absorbed by a smaller operand set.
    make_absorbed = exp.false if inner_kind == exp.And else exp.true

    operands = tuple(expression.flatten())

    seen = set()  # every top-level operand
    containing = defaultdict(list)  # operand -> operand-sets that include it
    complement_pairs = defaultdict(list)  # {x, y} -> [(nested op, surviving side)]

    for op in operands:
        seen.add(op)

        if not isinstance(op, inner_kind):
            # A plain operand forms a one-element set, e.g. A in A OR (A AND B).
            containing[op].append({op})
            continue

        # A nested connector contributes its flattened members,
        # e.g. (A AND B) OR (A AND B AND C).
        members = set(op.flatten())
        for member in members:
            containing[member].append(members)

        left, right = op.unnest_operands()
        if isinstance(left, exp.Not):
            complement_pairs[frozenset((left.this, right))].append((op, right))
        if isinstance(right, exp.Not):
            complement_pairs[frozenset((left, right.this))].append((op, left))

    for op in operands:
        if not isinstance(op, inner_kind):
            continue

        left, right = op.unnest_operands()

        # Absorb a negated operand whose complement appears at the top level.
        if isinstance(left, exp.Not) and left.this in seen:
            left.replace(make_neutral())
            continue
        if isinstance(right, exp.Not) and right.this in seen:
            right.replace(make_neutral())
            continue

        # Absorb the whole nested connector if a strict subset of it exists.
        members = set(op.flatten())
        if any(any(smaller < members for smaller in containing[m]) for m in members):
            op.replace(make_absorbed())
            continue

        # Eliminate complementary pairs.
        for other, survivor in complement_pairs[frozenset((left, right))]:
            op.replace(survivor)
            other.replace(survivor)

    return expression

Absorption:

A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B

Elimination:

(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A

def simplify_equality(expression, *args, **kwargs):
def wrapped(expression, *args, **kwargs):
    # Run the decorated rule; if it raises one of the tolerated exception
    # types, fall back to the unchanged input expression.
    # NOTE(review): ``func`` and ``exceptions`` are closure variables of the
    # enclosing decorator, which is not shown in this listing.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        result = expression
    return result
Use the subtraction and addition properties of equality to simplify expressions:

    x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b
@annotate_types_on_change
def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
    """Fold literal arithmetic and related constant patterns.

    - Non-connector binary expressions are flattened and folded with
      ``_simplify_binary``.
    - A double negation ``-(-x)`` becomes ``x``.
    - An inverse date operation is folded through ``_simplify_binary`` using
      its interval; the original expression is kept if that yields nothing.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return self._flat_simplify(expression, self._simplify_binary, root)

    if isinstance(expression, exp.Neg):
        negated = expression.this
        if isinstance(negated, exp.Neg):
            return negated.this

    if self._is_inverse_date_op(expression):
        folded = self._simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expr) -> exp.Expr:
    """Remove or expand COALESCE calls.

    1. A COALESCE with no extra arguments, or whose first argument is a
       non-null constant, collapses to its first argument. This does not apply
       inside a Hint, where COALESCE is a Spark partitioning hint.
    2. For a comparison between a COALESCE and a constant, the COALESCE is cut
       just before its first constant argument. The comparison is then expanded
       into ``(head IS NOT NULL AND <comparison>) OR (head IS NULL AND <const arg> <op> <other>)``
       so later rules can fold the constant branch.

    Step 2 is skipped for dialects that flag ``COALESCE_COMPARISON_NON_STANDARD``.
    """
    if isinstance(expression, exp.Coalesce):
        collapsible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if collapsible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
        return expression
    if not isinstance(expression, self.COMPARISONS):
        return expression

    left, right = expression.left, expression.right
    if isinstance(left, exp.Coalesce):
        coalesce, other = left, right
    elif isinstance(right, exp.Coalesce):
        coalesce, other = right, left
    else:
        return expression

    # The rewrite is valid for non-constants too, but only pays off when the
    # other side is a constant.
    if not _is_constant(other):
        return expression

    const_index = next(
        (i for i, arg in enumerate(coalesce.expressions) if _is_constant(arg)),
        None,
    )
    if const_index is None:
        return expression
    first_constant = coalesce.expressions[const_index]

    coalesce.set("expressions", coalesce.expressions[:const_index])

    # Skip a simplify iteration: a COALESCE left with no extra arguments is
    # replaced by its first argument right away.
    head: exp.Expr = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    not_null_branch = exp.and_(
        head.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        head.is_(exp.null()),
        type(expression)(this=first_constant.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False), copy=False)
@annotate_types_on_change
def simplify_concat(self, expression):
    """Merge runs of adjacent string literals inside a concatenation.

    Applies to the expression types in ``self.CONCATS``. For CONCAT_WS the
    literals in a run are joined with the separator, so the call is only
    rewritten when that separator is a string literal. If everything folds
    into a single literal, that literal is returned directly. A DPipe result
    is rebuilt as a left-nested chain of DPipe nodes.
    """
    if not isinstance(expression, self.CONCATS):
        return expression
    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string:
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)
    if is_concat_ws:
        separator_node, *operands = expression.expressions
        separator = separator_node.name
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_literal_run, run in runs:
        if is_literal_run:
            merged.append(exp.Literal.string(separator.join(lit.name for lit in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[separator_node, *merged], **extra_args)
    if isinstance(expression, exp.DPipe):
        return reduce(lambda acc, nxt: exp.DPipe(this=acc, expression=nxt), merged)
    return exp.Concat(expressions=merged, **extra_args)

Reduces all groups that contain string literals by concatenating them.

@annotate_types_on_change
def simplify_conditionals(self, expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
      * ``CASE x WHEN v ...`` branches are first rewritten to ``CASE WHEN x = v ...``;
      * the first branch whose condition is always true short-circuits to its value;
      * branches whose condition is always false are dropped; if none remain,
        the ELSE value (or NULL) is returned.

    IF (only when not a branch of a CASE): an always-true condition selects the
    true value; an always-false one selects the false value (or NULL).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot. ``case.pop()`` removes the branch from
        # ``expression.args["ifs"]`` in place (see the emptiness check below).
        # Mutating the list while iterating it would skip the branch that follows
        # a dropped one. That is unsafe: for
        # CASE WHEN FALSE THEN 1 WHEN x THEN 2 WHEN TRUE THEN 3 END the skipped
        # "WHEN x" would be ignored and the CASE would wrongly collapse to 3.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expr) -> exp.Expr:
    """
    Fold a STARTSWITH check to TRUE or FALSE when both the string and the
    prefix are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack, prefix = expression.this, expression.expression
    if haystack.is_string and prefix.is_string:
        return exp.convert(expression.name.startswith(prefix.name))

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def simplify_datetrunc(expression, *args, **kwargs):
def wrapped(expression, *args, **kwargs):
    # Best-effort wrapper: return the rule's result, or the untouched input
    # if the rule raises one of the tolerated exception types.
    # NOTE(review): ``func`` and ``exceptions`` come from the enclosing
    # decorator, which is not part of this listing.
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    return outcome

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

@annotate_types_on_change
def sort_comparison(self, expression: exp.Expr) -> exp.Expr:
    """Normalize the operand order of a comparison.

    Columns go on the left and constants on the right. When neither rule
    decides, operands are ordered by their ``gen`` key. Swapping flips the
    operator through ``INVERSE_COMPARISONS``, falling back to the same class.
    Comparisons whose right side is a subquery predicate are left alone.
    """
    if expression.__class__ not in self.COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_col = isinstance(lhs, exp.Column)
    rhs_is_col = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_col and not rhs_is_col)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    should_swap = (
        (rhs_is_col and not lhs_is_col)
        or (lhs_is_const and not rhs_is_const)
        or gen(lhs) > gen(rhs)
    )
    if not should_swap:
        return expression

    flipped = self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return flipped(this=rhs, expression=lhs)
def gen(expression: exp.Expr, comments: bool = False) -> str:
    """Render an expression as a cheap pseudo-SQL string for sorting and deduplication.

    The real SQL generator is expensive, and optimization needs many stable,
    comparable keys, so a minimal generator (``Gen``) is used instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

Arguments:
  • expression: the expression to convert into a SQL string.
  • comments: whether to include the expression's comments.
class Gen:
    """Cheap pseudo-SQL generator used to build sortable, dedupable keys.

    Instead of recursing, ``gen`` drives an explicit LIFO work stack. Each
    ``<key>_sql`` handler pushes the pieces of a node in *reverse* output order
    (plain strings, child expressions, or lists of children). The main loop
    pops them one at a time: expressions and lists are expanded, and anything
    else that is not None is appended to ``sqls`` via ``str``.
    """

    def __init__(self):
        # Pending work items; popped from the end, so handlers push in reverse.
        self.stack = []
        # Emitted fragments, joined into the final string by ``gen``.
        self.sqls = []

    def gen(self, expression: exp.Expr, comments: bool = False) -> str:
        """Return the pseudo-SQL key for ``expression``.

        Args:
            expression: the expression to render.
            comments: whether to include each node's comments as ``/*...*/``.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expr):
                if comments and node.comments:
                    # Pushed before the node's pieces, so it is emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: upper-cased node key, then its args (if any).
                    # ``_args`` runs first, so the key lands on top and is emitted first.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Comma-separated items: push each non-None item followed by a
                # separator, then drop the separator left on top so no leading
                # comma is emitted. Only pop if something was pushed. A non-empty
                # list of all-None items (e.g. a Func whose args are all None)
                # would otherwise pop an unrelated pending piece, such as the
                # function's closing ")".
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emits "<this> AS <alias>".
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        # Unquoted names are upper-cased; quoted identifiers keep their case.
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dotted parts: push "part." pairs in reverse, then drop the "." on top.
        # Guarded like the list case so empty ``parts`` cannot pop an unrelated piece.
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Type name first, then any remaining args after the first.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emits "<this> IN (<args after the first>)".
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-cased like every other keyword emitted here (e.g. " ILIKE ").
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Emits "(<this>)", then the alias, then the args after the second.
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Dotted parts, then the alias, then the args after the fourth.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Guarded so empty ``parts`` cannot pop the alias or another piece.
        if parts:
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Emits "<this><op><expression>".
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        # Emits "<op><this>".
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Emits "NAME(<all arg values>)"; None values are skipped by the list branch.
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expr, arg_index: int = 0) -> bool:
        """Push ``[":key", value]`` pairs for the non-None args from ``arg_index`` on.

        Returns:
            True if anything was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.core.Expr, comments: bool = False) -> str:
1509    def gen(self, expression: exp.Expr, comments: bool = False) -> str:
1510        self.stack = [expression]
1511        self.sqls.clear()
1512
1513        while self.stack:
1514            node = self.stack.pop()
1515
1516            if isinstance(node, exp.Expr):
1517                if comments and node.comments:
1518                    self.stack.append(f" /*{','.join(node.comments)}*/")
1519
1520                exp_handler_name = f"{node.key}_sql"
1521
1522                if hasattr(self, exp_handler_name):
1523                    getattr(self, exp_handler_name)(node)
1524                elif isinstance(node, exp.Func):
1525                    self._function(node)
1526                else:
1527                    key = node.key.upper()
1528                    self.stack.append(f"{key} " if self._args(node) else key)
1529            elif type(node) is list:
1530                for n in reversed(node):
1531                    if n is not None:
1532                        self.stack.extend((n, ","))
1533                if node:
1534                    self.stack.pop()
1535            else:
1536                if node is not None:
1537                    self.sqls.append(str(node))
1538
1539        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.core.Add) -> None:
1541    def add_sql(self, e: exp.Add) -> None:
1542        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.core.Alias) -> None:
1544    def alias_sql(self, e: exp.Alias) -> None:
1545        self.stack.extend(
1546            (
1547                e.args.get("alias"),
1548                " AS ",
1549                e.args.get("this"),
1550            )
1551        )
def and_sql(self, e: sqlglot.expressions.core.And) -> None:
1553    def and_sql(self, e: exp.And) -> None:
1554        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.core.Anonymous) -> None:
1556    def anonymous_sql(self, e: exp.Anonymous) -> None:
1557        this = e.this
1558        if isinstance(this, str):
1559            name = this.upper()
1560        elif isinstance(this, exp.Identifier):
1561            name = this.this
1562            name = f'"{name}"' if this.quoted else name.upper()
1563        else:
1564            raise ValueError(
1565                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1566            )
1567
1568        self.stack.extend(
1569            (
1570                ")",
1571                e.expressions,
1572                "(",
1573                name,
1574            )
1575        )
def between_sql(self, e: sqlglot.expressions.core.Between) -> None:
1577    def between_sql(self, e: exp.Between) -> None:
1578        self.stack.extend(
1579            (
1580                e.args.get("high"),
1581                " AND ",
1582                e.args.get("low"),
1583                " BETWEEN ",
1584                e.this,
1585            )
1586        )
def boolean_sql(self, e: sqlglot.expressions.core.Boolean) -> None:
1588    def boolean_sql(self, e: exp.Boolean) -> None:
1589        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.core.Bracket) -> None:
1591    def bracket_sql(self, e: exp.Bracket) -> None:
1592        self.stack.extend(
1593            (
1594                "]",
1595                e.expressions,
1596                "[",
1597                e.this,
1598            )
1599        )
def column_sql(self, e: sqlglot.expressions.core.Column) -> None:
1601    def column_sql(self, e: exp.Column) -> None:
1602        for p in reversed(e.parts):
1603            self.stack.extend((p, "."))
1604        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.datatypes.DataType) -> None:
1606    def datatype_sql(self, e: exp.DataType) -> None:
1607        self._args(e, 1)
1608        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.core.Div) -> None:
1610    def div_sql(self, e: exp.Div) -> None:
1611        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.core.Dot) -> None:
1613    def dot_sql(self, e: exp.Dot) -> None:
1614        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.core.EQ) -> None:
1616    def eq_sql(self, e: exp.EQ) -> None:
1617        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.query.From) -> None:
1619    def from_sql(self, e: exp.From) -> None:
1620        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.core.GT) -> None:
1622    def gt_sql(self, e: exp.GT) -> None:
1623        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.core.GTE) -> None:
1625    def gte_sql(self, e: exp.GTE) -> None:
1626        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.core.Identifier) -> None:
1628    def identifier_sql(self, e: exp.Identifier) -> None:
1629        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.core.ILike) -> None:
1631    def ilike_sql(self, e: exp.ILike) -> None:
1632        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.core.In) -> None:
1634    def in_sql(self, e: exp.In) -> None:
1635        self.stack.append(")")
1636        self._args(e, 1)
1637        self.stack.extend(
1638            (
1639                "(",
1640                " IN ",
1641                e.this,
1642            )
1643        )
def intdiv_sql(self, e: sqlglot.expressions.core.IntDiv) -> None:
1645    def intdiv_sql(self, e: exp.IntDiv) -> None:
1646        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.core.Is) -> None:
1648    def is_sql(self, e: exp.Is) -> None:
1649        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.core.Like) -> None:
1651    def like_sql(self, e: exp.Like) -> None:
1652        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.core.Literal) -> None:
1654    def literal_sql(self, e: exp.Literal) -> None:
1655        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.core.LT) -> None:
1657    def lt_sql(self, e: exp.LT) -> None:
1658        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.core.LTE) -> None:
1660    def lte_sql(self, e: exp.LTE) -> None:
1661        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.core.Mod) -> None:
1663    def mod_sql(self, e: exp.Mod) -> None:
1664        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.core.Mul) -> None:
1666    def mul_sql(self, e: exp.Mul) -> None:
1667        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.core.Neg) -> None:
1669    def neg_sql(self, e: exp.Neg) -> None:
1670        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.core.NEQ) -> None:
1672    def neq_sql(self, e: exp.NEQ) -> None:
1673        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.core.Not) -> None:
1675    def not_sql(self, e: exp.Not) -> None:
1676        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.core.Null) -> None:
1678    def null_sql(self, e: exp.Null) -> None:
1679        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.core.Or) -> None:
1681    def or_sql(self, e: exp.Or) -> None:
1682        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.core.Paren) -> None:
1684    def paren_sql(self, e: exp.Paren) -> None:
1685        self.stack.extend(
1686            (
1687                ")",
1688                e.this,
1689                "(",
1690            )
1691        )
def sub_sql(self, e: sqlglot.expressions.core.Sub) -> None:
1693    def sub_sql(self, e: exp.Sub) -> None:
1694        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.query.Subquery) -> None:
1696    def subquery_sql(self, e: exp.Subquery) -> None:
1697        self._args(e, 2)
1698        alias = e.args.get("alias")
1699        if alias:
1700            self.stack.append(alias)
1701        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.query.Table) -> None:
1703    def table_sql(self, e: exp.Table) -> None:
1704        self._args(e, 4)
1705        alias = e.args.get("alias")
1706        if alias:
1707            self.stack.append(alias)
1708        for p in reversed(e.parts):
1709            self.stack.extend((p, "."))
1710        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.query.TableAlias) -> None:
1712    def tablealias_sql(self, e: exp.TableAlias) -> None:
1713        columns = e.columns
1714
1715        if columns:
1716            self.stack.extend((")", columns, "("))
1717
1718        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.core.Var) -> None:
1720    def var_sql(self, e: exp.Var) -> None:
1721        self.stack.append(e.this)