sqlglot.optimizer.simplify
1from __future__ import annotations 2 3from datetime import date, datetime, timedelta 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce, wraps 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.annotate_types import TypeAnnotator 15from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 16from sqlglot.schema import ensure_schema 17 18 19if t.TYPE_CHECKING: 20 from dateutil.relativedelta import relativedelta 21 from sqlglot.dialects.dialect import DialectType 22 from typing_extensions import TypeIs 23 24 DateRange = tuple[date, date] 25 DateTruncBinaryTransform = t.Callable[ 26 [exp.Expr, date, str, Dialect, exp.DataType], t.Optional[exp.Expr] 27 ] 28 29logger = logging.getLogger("sqlglot") 30 31 32# Final means that an expression should not be simplified 33FINAL = "final" 34 35SIMPLIFIABLE = ( 36 exp.Binary, 37 exp.Func, 38 exp.Lambda, 39 exp.Predicate, 40 exp.Unary, 41) 42 43 44def simplify( 45 expression: exp.Expr, 46 constant_propagation: bool = False, 47 coalesce_simplification: bool = False, 48 dialect: DialectType = None, 49): 50 """ 51 Rewrite sqlglot AST to simplify expressions. 52 53 Example: 54 >>> import sqlglot 55 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 56 >>> simplify(expression).sql() 57 'TRUE' 58 59 Args: 60 expression: expression to simplify 61 constant_propagation: whether the constant propagation rule should be used 62 coalesce_simplification: whether the simplify coalesce rule should be used. 63 This rule tries to remove coalesce functions, which can be useful in certain analyses but 64 can leave the query more verbose. 
65 Returns: 66 sqlglot.Expr: simplified expression 67 """ 68 return Simplifier(dialect=dialect).simplify( 69 expression, 70 constant_propagation=constant_propagation, 71 coalesce_simplification=coalesce_simplification, 72 ) 73 74 75class UnsupportedUnit(Exception): 76 pass 77 78 79def catch(*exceptions): 80 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 81 82 def decorator(func): 83 def wrapped(expression, *args, **kwargs): 84 try: 85 return func(expression, *args, **kwargs) 86 except exceptions: 87 return expression 88 89 return wrapped 90 91 return decorator 92 93 94def annotate_types_on_change(func): 95 @wraps(func) 96 def _func(self, expression: exp.Expr, *args, **kwargs) -> exp.Expr | None: 97 new_expression: exp.Expr | None = func(self, expression, *args, **kwargs) 98 99 if new_expression is None: 100 return new_expression 101 102 if self.annotate_new_expressions and expression != new_expression: 103 self._annotator.clear() 104 105 # We annotate this to ensure new children nodes are also annotated 106 new_expression = self._annotator.annotate( 107 expression=new_expression, 108 annotate_scope=False, 109 ) 110 111 # Whatever expression the original expression is transformed into needs to preserve 112 # the original type, otherwise the simplification could result in a different schema 113 new_expression.type = expression.type 114 115 return new_expression 116 117 return _func 118 119 120def flatten(expression: exp.Expr) -> exp.Expr: 121 """ 122 A AND (B AND C) -> A AND B AND C 123 A OR (B OR C) -> A OR B OR C 124 """ 125 if isinstance(expression, exp.Connector): 126 for node in expression.args.values(): 127 child = node.unnest() 128 if isinstance(child, expression.__class__): 129 node.replace(child) 130 return expression 131 132 133def simplify_parens(expression: exp.Expr, dialect: DialectType) -> exp.Expr: 134 if not isinstance(expression, exp.Paren): 135 return expression 136 137 this = expression.this 138 parent = 
expression.parent 139 parent_is_predicate = isinstance(parent, exp.Predicate) 140 141 if isinstance(this, exp.Select): 142 return expression 143 144 if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)): 145 return expression 146 147 if ( 148 Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS 149 and isinstance(parent, exp.Dot) 150 and (isinstance(parent.right, (exp.Identifier, exp.Star))) 151 ): 152 return expression 153 154 if isinstance(this, exp.Predicate) and ( 155 not ( 156 parent_is_predicate 157 or isinstance(parent, exp.Neg) 158 or (isinstance(parent, exp.Binary) and not isinstance(parent, exp.Connector)) 159 ) 160 ): 161 return this 162 163 if ( 164 not isinstance(parent, (exp.Condition, exp.Binary)) 165 or isinstance(parent, exp.Paren) 166 or ( 167 not isinstance(this, exp.Binary) 168 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 169 ) 170 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 171 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 172 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 173 ): 174 return this 175 176 return expression 177 178 179def propagate_constants(expression, root=True): 180 """ 181 Propagate constants for conjunctions in DNF: 182 183 SELECT * FROM t WHERE a = b AND b = 5 becomes 184 SELECT * FROM t WHERE a = 5 AND b = 5 185 186 Reference: https://www.sqlite.org/optoverview.html 187 """ 188 189 if ( 190 isinstance(expression, exp.And) 191 and (root or not expression.same_parent) 192 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 193 ): 194 constant_mapping = {} 195 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 196 if isinstance(expr, exp.EQ): 197 l, r = expr.left, expr.right 198 199 # TODO: create a helper that can be used to detect nested literal expressions such 200 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 201 if isinstance(l, 
exp.Column) and isinstance(r, exp.Literal): 202 constant_mapping[l] = (id(l), r) 203 204 if constant_mapping: 205 for column in find_all_in_scope(expression, exp.Column): 206 parent = column.parent 207 column_id, constant = constant_mapping.get(column) or (None, None) 208 if ( 209 column_id is not None 210 and id(column) != column_id 211 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 212 ): 213 column.replace(constant.copy()) 214 215 return expression 216 217 218def _is_number(expression: exp.Expr) -> bool: 219 return expression.is_number 220 221 222def _is_interval(expression: exp.Expr) -> bool: 223 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 224 225 226def _is_nonnull_constant(expression: exp.Expr) -> bool: 227 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 228 229 230def _is_constant(expression: exp.Expr) -> bool: 231 expr = expression.this if isinstance(expression, exp.Neg) else expression 232 return isinstance(expr, exp.CONSTANTS) or _is_date_literal(expr) 233 234 235def _datetrunc_range(date: date, unit: str, dialect: Dialect) -> DateRange | None: 236 """ 237 Get the date range for a DATE_TRUNC equality comparison: 238 239 Example: 240 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 241 Returns: 242 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 243 """ 244 floor = date_floor(date, unit, dialect) 245 246 if date != floor: 247 # This will always be False, except for NULL values. 
248 return None 249 250 return floor, floor + interval(unit) 251 252 253def _datetrunc_eq_expression( 254 left: exp.Expr, drange: DateRange, target_type: exp.DataType | None 255) -> exp.Expr: 256 """Get the logical expression for a date range""" 257 return exp.and_( 258 left >= date_literal(drange[0], target_type), 259 left < date_literal(drange[1], target_type), 260 copy=False, 261 ) 262 263 264def _datetrunc_eq( 265 left: exp.Expr, 266 date: date, 267 unit: str, 268 dialect: Dialect, 269 target_type: exp.DataType | None, 270) -> exp.Expr | None: 271 drange = _datetrunc_range(date, unit, dialect) 272 if not drange: 273 return None 274 275 return _datetrunc_eq_expression(left, drange, target_type) 276 277 278def _datetrunc_neq( 279 left: exp.Expr, 280 date: date, 281 unit: str, 282 dialect: Dialect, 283 target_type: exp.DataType | None, 284) -> exp.Expr | None: 285 drange = _datetrunc_range(date, unit, dialect) 286 if not drange: 287 return None 288 289 return exp.and_( 290 left < date_literal(drange[0], target_type), 291 left >= date_literal(drange[1], target_type), 292 copy=False, 293 ) 294 295 296def always_true(expression: object) -> bool: 297 return (isinstance(expression, exp.Boolean) and expression.this) or ( 298 isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression) 299 ) 300 301 302def always_false(expression: object) -> bool: 303 return is_false(expression) or is_null(expression) or is_zero(expression) 304 305 306def is_zero(expression: object) -> bool: 307 return isinstance(expression, exp.Literal) and expression.to_py() == 0 308 309 310def is_complement(a: object, b: object) -> bool: 311 return isinstance(b, exp.Not) and b.this == a 312 313 314def is_false(a: object) -> bool: 315 return type(a) is exp.Boolean and not a.this 316 317 318def is_null(a: object) -> bool: 319 return type(a) is exp.Null 320 321 322class SupportsComparison(t.Protocol): 323 """Protocol for expressions or values that can be compared using <, <=, 
>, >=.""" 324 325 def __lt__(self, other: t.Any) -> bool: ... 326 def __le__(self, other: t.Any) -> bool: ... 327 def __gt__(self, other: t.Any) -> bool: ... 328 def __ge__(self, other: t.Any) -> bool: ... 329 330 331def eval_boolean( 332 expression: object, a: SupportsComparison, b: SupportsComparison 333) -> exp.Boolean | None: 334 if isinstance(expression, (exp.EQ, exp.Is)): 335 return boolean_literal(a == b) 336 if isinstance(expression, exp.NEQ): 337 return boolean_literal(a != b) 338 if isinstance(expression, exp.GT): 339 return boolean_literal(a > b) 340 if isinstance(expression, exp.GTE): 341 return boolean_literal(a >= b) 342 if isinstance(expression, exp.LT): 343 return boolean_literal(a < b) 344 if isinstance(expression, exp.LTE): 345 return boolean_literal(a <= b) 346 return None 347 348 349def cast_as_date(value: datetime | date | str) -> date | None: 350 if isinstance(value, datetime): 351 return value.date() 352 if isinstance(value, date): 353 return value 354 try: 355 return datetime.fromisoformat(value).date() 356 except ValueError: 357 return None 358 359 360def cast_as_datetime( 361 value: datetime | date | str, 362) -> datetime | None: 363 if isinstance(value, datetime): 364 return value 365 if isinstance(value, date): 366 return datetime(year=value.year, month=value.month, day=value.day) 367 try: 368 return datetime.fromisoformat(value) 369 except ValueError: 370 return None 371 372 373def cast_value(value: datetime | date | str, to: exp.DataType) -> date | date | None: 374 if not value: 375 return None 376 if to.is_type(exp.DType.DATE): 377 return cast_as_date(value) 378 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 379 return cast_as_datetime(value) 380 return None 381 382 383def extract_date(cast: exp.Expr) -> date | date | None: 384 if isinstance(cast, exp.Cast): 385 to = cast.to 386 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 387 to = exp.DType.DATE.into_expr() 388 else: 389 return None 390 391 if 
isinstance(cast.this, exp.Literal): 392 value: t.Any = cast.this.name 393 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 394 value = extract_date(cast.this) 395 else: 396 return None 397 return cast_value(value, to) 398 399 400def _is_date_literal(expression: exp.Expr) -> bool: 401 return extract_date(expression) is not None 402 403 404def extract_interval(expression: exp.Expr) -> relativedelta | None: 405 try: 406 n = int(expression.this.to_py()) 407 unit = expression.text("unit").lower() 408 return interval(unit, n) 409 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 410 return None 411 412 413def extract_type(*expressions: exp.Expr): 414 target_type = None 415 for expression in expressions: 416 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 417 if target_type: 418 break 419 420 return target_type 421 422 423def date_literal(date: object, target_type=None) -> exp.Expr: 424 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 425 target_type = exp.DType.DATETIME if isinstance(date, datetime) else exp.DType.DATE 426 427 return exp.cast(exp.Literal.string(date), target_type) 428 429 430def interval(unit: str, n: int = 1) -> relativedelta: 431 from dateutil.relativedelta import relativedelta 432 433 if unit == "year": 434 return relativedelta(years=1 * n) 435 if unit == "quarter": 436 return relativedelta(months=3 * n) 437 if unit == "month": 438 return relativedelta(months=1 * n) 439 if unit == "week": 440 return relativedelta(weeks=1 * n) 441 if unit == "day": 442 return relativedelta(days=1 * n) 443 if unit == "hour": 444 return relativedelta(hours=1 * n) 445 if unit == "minute": 446 return relativedelta(minutes=1 * n) 447 if unit == "second": 448 return relativedelta(seconds=1 * n) 449 450 raise UnsupportedUnit(f"Unsupported unit: {unit}") 451 452 453def date_floor(d: date, unit: str, dialect: Dialect) -> date: 454 if unit == "year": 455 return d.replace(month=1, day=1) 456 if 
unit == "quarter": 457 if d.month <= 3: 458 return d.replace(month=1, day=1) 459 elif d.month <= 6: 460 return d.replace(month=4, day=1) 461 elif d.month <= 9: 462 return d.replace(month=7, day=1) 463 else: 464 return d.replace(month=10, day=1) 465 if unit == "month": 466 return d.replace(month=d.month, day=1) 467 if unit == "week": 468 # Assuming week starts on Monday (0) and ends on Sunday (6) 469 return d - timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 470 if unit == "day": 471 return d 472 473 raise UnsupportedUnit(f"Unsupported unit: {unit}") 474 475 476def date_ceil(d: date, unit: str, dialect: Dialect) -> date: 477 floor = date_floor(d, unit, dialect) 478 479 if floor == d: 480 return d 481 482 return floor + interval(unit) 483 484 485def boolean_literal(condition: bool | exp.Predicate) -> exp.Boolean: 486 return exp.true() if condition else exp.false() 487 488 489class Simplifier: 490 def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True): 491 self.dialect = Dialect.get_or_raise(dialect) 492 self.annotate_new_expressions = annotate_new_expressions 493 494 self._annotator: TypeAnnotator = TypeAnnotator( 495 schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False 496 ) 497 498 # Value ranges for byte-sized signed/unsigned integers 499 TINYINT_MIN: t.ClassVar = -128 500 TINYINT_MAX: t.ClassVar = 127 501 UTINYINT_MIN: t.ClassVar = 0 502 UTINYINT_MAX: t.ClassVar = 255 503 504 COMPLEMENT_COMPARISONS: t.ClassVar = { 505 exp.LT: exp.GTE, 506 exp.GT: exp.LTE, 507 exp.LTE: exp.GT, 508 exp.GTE: exp.LT, 509 exp.EQ: exp.NEQ, 510 exp.NEQ: exp.EQ, 511 } 512 513 COMPLEMENT_SUBQUERY_PREDICATES: t.ClassVar = { 514 exp.All: exp.Any, 515 exp.Any: exp.All, 516 } 517 518 LT_LTE: t.ClassVar = (exp.LT, exp.LTE) 519 GT_GTE: t.ClassVar = (exp.GT, exp.GTE) 520 521 COMPARISONS: t.ClassVar = ( 522 *LT_LTE, 523 *GT_GTE, 524 exp.EQ, 525 exp.NEQ, 526 exp.Is, 527 ) 528 529 INVERSE_COMPARISONS: t.ClassVar[dict[type[exp.Expr], 
type[exp.Expr]]] = { 530 exp.LT: exp.GT, 531 exp.GT: exp.LT, 532 exp.LTE: exp.GTE, 533 exp.GTE: exp.LTE, 534 } 535 536 NONDETERMINISTIC: t.ClassVar = (exp.Rand, exp.Randn) 537 AND_OR: t.ClassVar = (exp.And, exp.Or) 538 539 INVERSE_DATE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = { 540 exp.DateAdd: exp.Sub, 541 exp.DateSub: exp.Add, 542 exp.DatetimeAdd: exp.Sub, 543 exp.DatetimeSub: exp.Add, 544 } 545 546 INVERSE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = { 547 **INVERSE_DATE_OPS, 548 exp.Add: exp.Sub, 549 exp.Sub: exp.Add, 550 } 551 552 NULL_OK: t.ClassVar = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 553 554 CONCATS: t.ClassVar = (exp.Concat, exp.DPipe) 555 556 DATETRUNC_BINARY_COMPARISONS: t.ClassVar[dict[type[exp.Expr], DateTruncBinaryTransform]] = { 557 exp.LT: lambda l, dt, u, d, t: ( 558 l 559 < date_literal( 560 dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t 561 ) 562 ), 563 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 564 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 565 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 566 exp.EQ: _datetrunc_eq, 567 exp.NEQ: _datetrunc_neq, 568 } 569 570 DATETRUNC_COMPARISONS: t.ClassVar = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 571 DATETRUNCS: t.ClassVar = (exp.DateTrunc, exp.TimestampTrunc) 572 573 SAFE_CONNECTOR_ELIMINATION_RESULT: t.ClassVar = (exp.Connector, exp.Boolean) 574 575 # CROSS joins result in an empty table if the right table is empty. 576 # So we can only simplify certain types of joins to CROSS. 
577 # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 578 JOINS: t.ClassVar = { 579 ("", ""), 580 ("", "INNER"), 581 ("RIGHT", ""), 582 ("RIGHT", "OUTER"), 583 } 584 585 def simplify( 586 self, 587 expression: exp.Expr, 588 constant_propagation: bool = False, 589 coalesce_simplification: bool = False, 590 ) -> exp.Expr: 591 wheres: list[exp.Where] = [] 592 joins: list[exp.Join] = [] 593 594 for node in expression.walk( 595 prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL)) 596 ): 597 if node.meta.get(FINAL): 598 continue 599 600 # group by expressions cannot be simplified, for example 601 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 602 # the projection must exactly match the group by key 603 group = node.args.get("group") 604 605 if group and hasattr(node, "selects"): 606 groups = set(group.expressions) 607 group.meta[FINAL] = True 608 609 for s in node.selects: 610 for n in s.walk(): 611 if n in groups: 612 s.meta[FINAL] = True 613 break 614 615 having = node.args.get("having") 616 617 if having: 618 for n in having.walk(): 619 if n in groups: 620 having.meta[FINAL] = True 621 break 622 623 if isinstance(node, exp.Condition): 624 simplified = while_changing( 625 node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification) 626 ) 627 628 if node is expression: 629 expression = simplified 630 elif isinstance(node, exp.Where): 631 wheres.append(node) 632 elif isinstance(node, exp.Join): 633 # snowflake match_conditions have very strict ordering rules 634 if match := node.args.get("match_condition"): 635 match.meta[FINAL] = True 636 637 joins.append(node) 638 639 for where in wheres: 640 if always_true(where.this): 641 where.pop() 642 for join in joins: 643 if ( 644 always_true(join.args.get("on")) 645 and not join.args.get("using") 646 and not join.args.get("method") 647 and (join.side, join.kind) in self.JOINS 648 ): 649 join.args["on"].pop() 650 join.set("side", None) 651 join.set("kind", "CROSS") 652 653 return 
expression 654 655 def _simplify( 656 self, expression: exp.Expr, constant_propagation: bool, coalesce_simplification: bool 657 ): 658 pre_transformation_stack = [expression] 659 post_transformation_stack = [] 660 661 while pre_transformation_stack: 662 original = pre_transformation_stack.pop() 663 node = original 664 665 if not isinstance(node, SIMPLIFIABLE): 666 if isinstance(node, exp.Query): 667 self.simplify(node, constant_propagation, coalesce_simplification) 668 continue 669 670 parent = node.parent 671 root = node is expression 672 673 node = self.rewrite_between(node) 674 node = self.uniq_sort(node, root) 675 node = self.absorb_and_eliminate(node, root) 676 node = self.simplify_concat(node) 677 node = self.simplify_conditionals(node) 678 679 if constant_propagation: 680 node = propagate_constants(node, root) 681 682 if node is not original: 683 original.replace(node) 684 685 for n in node.iter_expressions(reverse=True): 686 if n.meta.get(FINAL): 687 raise 688 pre_transformation_stack.extend( 689 n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 690 ) 691 post_transformation_stack.append((node, parent)) 692 693 while post_transformation_stack: 694 original, parent = post_transformation_stack.pop() 695 root = original is expression 696 697 # Resets parent, arg_key, index pointers– this is needed because some of the 698 # previous transformations mutate the AST, leading to an inconsistent state 699 for k, v in tuple(original.args.items()): 700 original.set(k, v) 701 702 # Post-order transformations 703 node = self.simplify_not(original) 704 node = flatten(node) 705 node = self.simplify_connectors(node, root) 706 node = self.remove_complements(node, root) 707 708 if coalesce_simplification: 709 node = self.simplify_coalesce(node) 710 node.parent = parent 711 712 node = self.simplify_literals(node, root) 713 node = self.simplify_equality(node) 714 node = simplify_parens(node, dialect=self.dialect) 715 node = self.simplify_datetrunc(node) 
716 node = self.sort_comparison(node) 717 node = self.simplify_startswith(node) 718 719 if node is not original: 720 original.replace(node) 721 722 return node 723 724 @annotate_types_on_change 725 def rewrite_between(self, expression: exp.Expr) -> exp.Expr: 726 """Rewrite x between y and z to x >= y AND x <= z. 727 728 This is done because comparison simplification is only done on lt/lte/gt/gte. 729 """ 730 if isinstance(expression, exp.Between): 731 negate = isinstance(expression.parent, exp.Not) 732 733 expression = exp.and_( 734 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 735 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 736 copy=False, 737 ) 738 739 if negate: 740 expression = exp.paren(expression, copy=False) 741 742 return expression 743 744 @annotate_types_on_change 745 def simplify_not(self, expression: exp.Expr) -> exp.Expr: 746 """ 747 Demorgan's Law 748 NOT (x OR y) -> NOT x AND NOT y 749 NOT (x AND y) -> NOT x OR NOT y 750 """ 751 if isinstance(expression, exp.Not): 752 this = expression.this 753 if is_null(this): 754 return exp.and_(exp.null(), exp.true(), copy=False) 755 if this.__class__ in self.COMPLEMENT_COMPARISONS: 756 right = this.expression 757 complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get( 758 right.__class__ 759 ) 760 if complement_subquery_predicate: 761 right = complement_subquery_predicate(this=right.this) 762 763 return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 764 if isinstance(this, exp.Paren): 765 condition = this.unnest() 766 if isinstance(condition, exp.And): 767 return exp.paren( 768 exp.or_( 769 exp.not_(condition.left, copy=False), 770 exp.not_(condition.right, copy=False), 771 copy=False, 772 ), 773 copy=False, 774 ) 775 if isinstance(condition, exp.Or): 776 return exp.paren( 777 exp.and_( 778 exp.not_(condition.left, copy=False), 779 exp.not_(condition.right, copy=False), 780 copy=False, 781 ), 782 copy=False, 
783 ) 784 if is_null(condition): 785 return exp.and_(exp.null(), exp.true(), copy=False) 786 if always_true(this): 787 return exp.false() 788 if is_false(this): 789 return exp.true() 790 if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION: 791 inner = this.this 792 if inner.is_type(exp.DType.BOOLEAN): 793 # double negation 794 # NOT NOT x -> x, if x is BOOLEAN type 795 return inner 796 return expression 797 798 @annotate_types_on_change 799 def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr: 800 def _simplify_connectors(expression: exp.Expr, left: exp.Expr, right: exp.Expr): 801 if isinstance(expression, exp.And): 802 if is_false(left) or is_false(right): 803 return exp.false() 804 if is_zero(left) or is_zero(right): 805 return exp.false() 806 if ( 807 (is_null(left) and is_null(right)) 808 or (is_null(left) and always_true(right)) 809 or (always_true(left) and is_null(right)) 810 ): 811 return exp.null() 812 if always_true(left) and always_true(right): 813 return exp.true() 814 if always_true(left): 815 return right 816 if always_true(right): 817 return left 818 return self._simplify_comparison(expression, left, right) 819 elif isinstance(expression, exp.Or): 820 if always_true(left) or always_true(right): 821 return exp.true() 822 if ( 823 (is_null(left) and is_null(right)) 824 or (is_null(left) and always_false(right)) 825 or (always_false(left) and is_null(right)) 826 ): 827 return exp.null() 828 if is_false(left): 829 return right 830 if is_false(right): 831 return left 832 return self._simplify_comparison(expression, left, right, or_=True) 833 834 if isinstance(expression, exp.Connector): 835 original_parent = expression.parent 836 expression = self._flat_simplify(expression, _simplify_connectors, root) 837 838 # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need 839 # to ensure that the resulting type is boolean. 
We know this is true only for connectors, 840 # boolean values and columns that are essentially operands to a connector: 841 # 842 # A AND (((B))) 843 # ~ this is safe to keep because it will eventually be part of another connector 844 if not isinstance( 845 expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT 846 ) and not expression.is_type(exp.DType.BOOLEAN): 847 while True: 848 if isinstance(original_parent, exp.Connector): 849 break 850 if not isinstance(original_parent, exp.Paren): 851 expression = expression.and_(exp.true(), copy=False) 852 break 853 854 original_parent = original_parent.parent 855 856 return expression 857 858 @annotate_types_on_change 859 def _simplify_comparison( 860 self, expression: exp.Expr, left: exp.Expr, right: exp.Expr, or_: bool = False 861 ) -> exp.Expr | None: 862 if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS): 863 ll, lr = left.args.values() 864 rl, rr = right.args.values() 865 866 largs = {ll, lr} 867 rargs = {rl, rr} 868 869 matching = largs & rargs 870 columns = { 871 m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC) 872 } 873 874 if matching and columns: 875 try: 876 l = first(largs - columns) 877 r = first(rargs - columns) 878 except StopIteration: 879 return expression 880 881 if l.is_number and r.is_number: 882 l = l.to_py() 883 r = r.to_py() 884 elif l.is_string and r.is_string: 885 l = l.name 886 r = r.name 887 else: 888 l = extract_date(l) 889 if not l: 890 return None 891 r = extract_date(r) 892 if not r: 893 return None 894 # python won't compare date and datetime, but many engines will upcast 895 l, r = cast_as_datetime(l), cast_as_datetime(r) 896 897 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 898 if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE): 899 return left if (av > bv if or_ else av <= bv) else right 900 if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE): 901 return left if (av < bv if or_ else av 
>= bv) else right 902 903 # we can't ever shortcut to true because the column could be null 904 if not or_: 905 if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE): 906 if av <= bv: 907 return exp.false() 908 elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE): 909 if av >= bv: 910 return exp.false() 911 elif isinstance(a, exp.EQ): 912 if isinstance(b, exp.LT): 913 return exp.false() if av >= bv else a 914 if isinstance(b, exp.LTE): 915 return exp.false() if av > bv else a 916 if isinstance(b, exp.GT): 917 return exp.false() if av <= bv else a 918 if isinstance(b, exp.GTE): 919 return exp.false() if av < bv else a 920 if isinstance(b, exp.NEQ): 921 return exp.false() if av == bv else a 922 return None 923 924 @annotate_types_on_change 925 def remove_complements(self, expression: object, root: bool = True) -> object: 926 """ 927 Removing complements. 928 929 A AND NOT A -> FALSE (only for non-NULL A) 930 A OR NOT A -> TRUE (only for non-NULL A) 931 """ 932 if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): 933 ops = set(expression.flatten()) 934 for op in ops: 935 if isinstance(op, exp.Not) and op.this in ops: 936 if expression.meta.get("nonnull") is True: 937 return exp.false() if isinstance(expression, exp.And) else exp.true() 938 939 return expression 940 941 @annotate_types_on_change 942 def uniq_sort(self, expression: object, root: bool = True) -> object: 943 """ 944 Uniq and sort a connector. 
945 946 C AND A AND B AND B -> A AND B AND C 947 """ 948 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 949 flattened = tuple(expression.flatten()) 950 951 if isinstance(expression, exp.Xor): 952 result_func = exp.xor 953 # Do not deduplicate XOR as A XOR A != A if A == True 954 deduped = None 955 arr = tuple((gen(e), e) for e in flattened) 956 else: 957 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 958 deduped = {gen(e): e for e in flattened} 959 arr = tuple(deduped.items()) 960 961 # check if the operands are already sorted, if not sort them 962 # A AND C AND B -> A AND B AND C 963 for i, (sql, e) in enumerate(arr[1:]): 964 if sql < arr[i][0]: 965 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 966 break 967 else: 968 # we didn't have to sort but maybe we need to dedup 969 if deduped and len(deduped) < len(flattened): 970 unique_operand = flattened[0] 971 if len(deduped) == 1: 972 expression = unique_operand.and_(exp.true(), copy=False) 973 else: 974 expression = result_func(*deduped.values(), copy=False) 975 976 return expression 977 978 @annotate_types_on_change 979 def absorb_and_eliminate(self, expression: object, root: bool = True) -> object: 980 """ 981 absorption: 982 A AND (A OR B) -> A 983 A OR (A AND B) -> A 984 A AND (NOT A OR B) -> A AND B 985 A OR (NOT A AND B) -> A OR B 986 elimination: 987 (A AND B) OR (A AND NOT B) -> A 988 (A OR B) AND (A OR NOT B) -> A 989 """ 990 if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): 991 kind = exp.Or if isinstance(expression, exp.And) else exp.And 992 993 ops = tuple(expression.flatten()) 994 995 # Initialize lookup tables: 996 # Set of all operands, used to find complements for absorption. 997 op_set = set() 998 # Sub-operands, used to find subsets for absorption. 999 subops = defaultdict(list) 1000 # Pairs of complements, used for elimination. 
1001 pairs = defaultdict(list) 1002 1003 # Populate the lookup tables 1004 for op in ops: 1005 op_set.add(op) 1006 1007 if not isinstance(op, kind): 1008 # In cases like: A OR (A AND B) 1009 # Subop will be: ^ 1010 subops[op].append({op}) 1011 continue 1012 1013 # In cases like: (A AND B) OR (A AND B AND C) 1014 # Subops will be: ^ ^ 1015 subset = set(op.flatten()) 1016 for i in subset: 1017 subops[i].append(subset) 1018 1019 a, b = op.unnest_operands() 1020 if isinstance(a, exp.Not): 1021 pairs[frozenset((a.this, b))].append((op, b)) 1022 if isinstance(b, exp.Not): 1023 pairs[frozenset((a, b.this))].append((op, a)) 1024 1025 for op in ops: 1026 if not isinstance(op, kind): 1027 continue 1028 1029 a, b = op.unnest_operands() 1030 1031 # Absorb 1032 if isinstance(a, exp.Not) and a.this in op_set: 1033 a.replace(exp.true() if kind == exp.And else exp.false()) 1034 continue 1035 if isinstance(b, exp.Not) and b.this in op_set: 1036 b.replace(exp.true() if kind == exp.And else exp.false()) 1037 continue 1038 superset = set(op.flatten()) 1039 if any(any(subset < superset for subset in subops[i]) for i in superset): 1040 op.replace(exp.false() if kind == exp.And else exp.true()) 1041 continue 1042 1043 # Eliminate 1044 for other, complement in pairs[frozenset((a, b))]: 1045 op.replace(complement) 1046 other.replace(complement) 1047 1048 return expression 1049 1050 @annotate_types_on_change 1051 @catch(ModuleNotFoundError, UnsupportedUnit) 1052 def simplify_equality(self, expression: exp.Expr) -> exp.Expr: 1053 """ 1054 Use the subtraction and addition properties of equality to simplify expressions: 1055 1056 x + 1 = 3 becomes x = 2 1057 1058 There are two binary operations in the above expression: + and = 1059 Here's how we reference all the operands in the code below: 1060 1061 l r 1062 x + 1 = 3 1063 a b 1064 """ 1065 if isinstance(expression, self.COMPARISONS): 1066 l, r = expression.left, expression.right 1067 1068 if l.__class__ not in self.INVERSE_OPS: 1069 return 
expression 1070 1071 if r.is_number: 1072 a_predicate = _is_number 1073 b_predicate = _is_number 1074 elif _is_date_literal(r): 1075 a_predicate = _is_date_literal 1076 b_predicate = _is_interval 1077 else: 1078 return expression 1079 1080 if l.__class__ in self.INVERSE_DATE_OPS: 1081 l = t.cast(exp.IntervalOp, l) 1082 a: exp.Expr = l.this 1083 b: exp.Expr = l.interval() 1084 else: 1085 l = t.cast(exp.Binary, l) 1086 a, b = l.left, l.right 1087 1088 if not a_predicate(a) and b_predicate(b): 1089 pass 1090 elif not a_predicate(b) and b_predicate(a): 1091 a, b = b, a 1092 else: 1093 return expression 1094 1095 return expression.__class__( 1096 this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b) 1097 ) 1098 return expression 1099 1100 def _is_inverse_date_op(self, expression: exp.Expr) -> TypeIs[exp.IntervalOp]: 1101 return type(expression) in self.INVERSE_DATE_OPS 1102 1103 @annotate_types_on_change 1104 def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr: 1105 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 1106 return self._flat_simplify(expression, self._simplify_binary, root) 1107 1108 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 1109 return expression.this.this 1110 1111 if self._is_inverse_date_op(expression): 1112 return ( 1113 self._simplify_binary(expression, expression.this, expression.interval()) 1114 or expression 1115 ) 1116 1117 return expression 1118 1119 def _simplify_integer_cast(self, expr: t.Any) -> t.Any: 1120 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 1121 this = self._simplify_integer_cast(expr.this) 1122 else: 1123 this = expr.this 1124 1125 if isinstance(expr, exp.Cast) and this.is_int: 1126 num = this.to_py() 1127 1128 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 1129 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 1130 # engine-dependent 1131 if ( 1132 self.TINYINT_MIN <= num <= self.TINYINT_MAX 1133 and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 1134 ) or ( 1135 self.UTINYINT_MIN <= num <= self.UTINYINT_MAX 1136 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 1137 ): 1138 return this 1139 1140 return expr 1141 1142 def _simplify_binary(self, expression, a, b): 1143 if isinstance(expression, self.COMPARISONS): 1144 a = self._simplify_integer_cast(a) 1145 b = self._simplify_integer_cast(b) 1146 1147 if isinstance(expression, exp.Is): 1148 if isinstance(b, exp.Not): 1149 c = b.this 1150 not_ = True 1151 else: 1152 c = b 1153 not_ = False 1154 1155 if is_null(c): 1156 if isinstance(a, exp.Literal): 1157 return exp.true() if not_ else exp.false() 1158 if is_null(a): 1159 return exp.false() if not_ else exp.true() 1160 elif isinstance(expression, self.NULL_OK): 1161 return None 1162 elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If): 1163 return exp.null() 1164 1165 if a.is_number and b.is_number: 1166 num_a = a.to_py() 1167 num_b = b.to_py() 1168 1169 if isinstance(expression, exp.Add): 1170 return exp.Literal.number(num_a + num_b) 1171 if isinstance(expression, exp.Mul): 1172 return exp.Literal.number(num_a * num_b) 1173 1174 # We only simplify Sub, Div if a and b have the same parent because they're not associative 1175 if isinstance(expression, exp.Sub): 1176 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 1177 if isinstance(expression, exp.Div): 1178 # engines have differing int div behavior so intdiv is not safe 1179 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 1180 return None 1181 return exp.Literal.number(num_a / num_b) 1182 1183 boolean = eval_boolean(expression, num_a, num_b) 1184 1185 if boolean: 1186 return boolean 1187 elif a.is_string and b.is_string: 
1188 boolean = eval_boolean(expression, a.this, b.this) 1189 1190 if boolean: 1191 return boolean 1192 elif _is_date_literal(a) and isinstance(b, exp.Interval): 1193 date, b = extract_date(a), extract_interval(b) 1194 if date and b: 1195 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 1196 return date_literal(date + b, extract_type(a)) 1197 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 1198 return date_literal(date - b, extract_type(a)) 1199 elif isinstance(a, exp.Interval) and _is_date_literal(b): 1200 a, date = extract_interval(a), extract_date(b) 1201 # you cannot subtract a date from an interval 1202 if a and b and isinstance(expression, exp.Add): 1203 return date_literal(a + date, extract_type(b)) 1204 elif _is_date_literal(a) and _is_date_literal(b): 1205 if isinstance(expression, exp.Predicate): 1206 a, b = extract_date(a), extract_date(b) 1207 boolean = eval_boolean(expression, a, b) 1208 if boolean: 1209 return boolean 1210 1211 return None 1212 1213 @annotate_types_on_change 1214 def simplify_coalesce(self, expression: exp.Expr) -> exp.Expr: 1215 # COALESCE(x) -> x 1216 if ( 1217 isinstance(expression, exp.Coalesce) 1218 and (not expression.expressions or _is_nonnull_constant(expression.this)) 1219 # COALESCE is also used as a Spark partitioning hint 1220 and not isinstance(expression.parent, exp.Hint) 1221 ): 1222 return expression.this 1223 1224 if self.dialect.COALESCE_COMPARISON_NON_STANDARD: 1225 return expression 1226 1227 if not isinstance(expression, self.COMPARISONS): 1228 return expression 1229 1230 if isinstance(expression.left, exp.Coalesce): 1231 coalesce = expression.left 1232 other = expression.right 1233 elif isinstance(expression.right, exp.Coalesce): 1234 coalesce = expression.right 1235 other = expression.left 1236 else: 1237 return expression 1238 1239 # This transformation is valid for non-constants, 1240 # but it really only does anything if they are both constants. 
1241 if not _is_constant(other): 1242 return expression 1243 1244 # Find the first constant arg 1245 for arg_index, arg in enumerate(coalesce.expressions): 1246 if _is_constant(arg): 1247 break 1248 else: 1249 return expression 1250 1251 coalesce.set("expressions", coalesce.expressions[:arg_index]) 1252 1253 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 1254 # since we already remove COALESCE at the top of this function. 1255 this: exp.Expr = coalesce if coalesce.expressions else coalesce.this 1256 1257 # This expression is more complex than when we started, but it will get simplified further 1258 return exp.paren( 1259 exp.or_( 1260 exp.and_( 1261 this.is_(exp.null()).not_(copy=False), 1262 expression.copy(), 1263 copy=False, 1264 ), 1265 exp.and_( 1266 this.is_(exp.null()), 1267 type(expression)(this=arg.copy(), expression=other.copy()), 1268 copy=False, 1269 ), 1270 copy=False, 1271 ), 1272 copy=False, 1273 ) 1274 1275 @annotate_types_on_change 1276 def simplify_concat(self, expression): 1277 """Reduces all groups that contain string literals by concatenating them.""" 1278 if not isinstance(expression, self.CONCATS) or ( 1279 # We can't reduce a CONCAT_WS call if we don't statically know the separator 1280 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 1281 ): 1282 return expression 1283 1284 if isinstance(expression, exp.ConcatWs): 1285 sep_expr, *expressions = expression.expressions 1286 sep = sep_expr.name 1287 concat_type = exp.ConcatWs 1288 args = {} 1289 else: 1290 expressions = expression.expressions 1291 sep = "" 1292 concat_type = exp.Concat 1293 args = { 1294 "safe": expression.args.get("safe"), 1295 "coalesce": expression.args.get("coalesce"), 1296 } 1297 1298 new_args = [] 1299 for is_string_group, group in itertools.groupby( 1300 expressions or expression.flatten(), lambda e: e.is_string 1301 ): 1302 if is_string_group: 1303 
new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 1304 else: 1305 new_args.extend(group) 1306 1307 if len(new_args) == 1 and new_args[0].is_string: 1308 return new_args[0] 1309 1310 if concat_type is exp.ConcatWs: 1311 new_args = [sep_expr] + new_args 1312 elif isinstance(expression, exp.DPipe): 1313 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 1314 1315 return concat_type(expressions=new_args, **args) 1316 1317 @annotate_types_on_change 1318 def simplify_conditionals(self, expression): 1319 """Simplifies expressions like IF, CASE if their condition is statically known.""" 1320 if isinstance(expression, exp.Case): 1321 this = expression.this 1322 for case in expression.args["ifs"]: 1323 cond = case.this 1324 if this: 1325 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 1326 cond = cond.replace(this.pop().eq(cond)) 1327 1328 if always_true(cond): 1329 return case.args["true"] 1330 1331 if always_false(cond): 1332 case.pop() 1333 if not expression.args["ifs"]: 1334 return expression.args.get("default") or exp.null() 1335 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 1336 if always_true(expression.this): 1337 return expression.args["true"] 1338 if always_false(expression.this): 1339 return expression.args.get("false") or exp.null() 1340 1341 return expression 1342 1343 @annotate_types_on_change 1344 def simplify_startswith(self, expression: exp.Expr) -> exp.Expr: 1345 """ 1346 Reduces a prefix check to either TRUE or FALSE if both the string and the 1347 prefix are statically known. 
1348 1349 Example: 1350 >>> from sqlglot import parse_one 1351 >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 1352 'TRUE' 1353 """ 1354 if ( 1355 isinstance(expression, exp.StartsWith) 1356 and expression.this.is_string 1357 and expression.expression.is_string 1358 ): 1359 return exp.convert(expression.name.startswith(expression.expression.name)) 1360 1361 return expression 1362 1363 def _is_datetrunc_predicate(self, left: exp.Expr, right: exp.Expr) -> bool: 1364 return isinstance(left, self.DATETRUNCS) and _is_date_literal(right) 1365 1366 @annotate_types_on_change 1367 @catch(ModuleNotFoundError, UnsupportedUnit) 1368 def simplify_datetrunc(self, expression: exp.Expr) -> exp.Expr: 1369 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1370 comparison = expression.__class__ 1371 1372 if isinstance(expression, self.DATETRUNCS): 1373 this = expression.this 1374 trunc_type = extract_type(this) 1375 date = extract_date(this) 1376 if date and expression.unit: 1377 return date_literal( 1378 date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type 1379 ) 1380 elif comparison not in self.DATETRUNC_COMPARISONS: 1381 return expression 1382 1383 if isinstance(expression, exp.Binary): 1384 l, r = expression.left, expression.right 1385 1386 if not self._is_datetrunc_predicate(l, r) or not isinstance( 1387 l, (exp.DateTrunc, exp.TimestampTrunc) 1388 ): 1389 return expression 1390 1391 trunc_arg = l.this 1392 unit = l.args["unit"].name.lower() 1393 date = extract_date(r) 1394 1395 if not date: 1396 return expression 1397 1398 return ( 1399 self.DATETRUNC_BINARY_COMPARISONS[comparison]( 1400 trunc_arg, date, unit, self.dialect, extract_type(r) 1401 ) 1402 or expression 1403 ) 1404 1405 if isinstance(expression, exp.In): 1406 l = expression.this 1407 rs = expression.expressions 1408 1409 if ( 1410 rs 1411 and all(self._is_datetrunc_predicate(l, r) for r in rs) 1412 and isinstance(l, 
(exp.DateTrunc, exp.TimestampTrunc)) 1413 ): 1414 unit = l.args["unit"].name.lower() 1415 1416 ranges = [] 1417 for r in rs: 1418 date = extract_date(r) 1419 if not date: 1420 return expression 1421 drange = _datetrunc_range(date, unit, self.dialect) 1422 if drange: 1423 ranges.append(drange) 1424 1425 if not ranges: 1426 return expression 1427 1428 ranges = merge_ranges(ranges) 1429 target_type = extract_type(*rs) 1430 1431 return exp.or_( 1432 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], 1433 copy=False, 1434 ) 1435 1436 return expression 1437 1438 @annotate_types_on_change 1439 def sort_comparison(self, expression: exp.Expr) -> exp.Expr: 1440 if expression.__class__ in self.COMPLEMENT_COMPARISONS: 1441 l, r = expression.this, expression.expression 1442 l_column = isinstance(l, exp.Column) 1443 r_column = isinstance(r, exp.Column) 1444 l_const = _is_constant(l) 1445 r_const = _is_constant(r) 1446 1447 if ( 1448 (l_column and not r_column) 1449 or (r_const and not l_const) 1450 or isinstance(r, exp.SubqueryPredicate) 1451 ): 1452 return expression 1453 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1454 return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1455 this=r, expression=l 1456 ) 1457 return expression 1458 1459 def _flat_simplify( 1460 self, 1461 expression: exp.Expr, 1462 simplifier: t.Callable[[exp.Expr, exp.Expr, exp.Expr], exp.Expr | None], 1463 root: bool = True, 1464 ) -> exp.Expr: 1465 if root or not expression.same_parent: 1466 operands = [] 1467 queue = deque(expression.flatten(unnest=False)) 1468 size = len(queue) 1469 1470 while queue: 1471 a = queue.popleft() 1472 1473 for b in queue: 1474 result = simplifier(expression, a, b) 1475 1476 if result and result is not expression: 1477 queue.remove(b) 1478 queue.appendleft(result) 1479 break 1480 else: 1481 operands.append(a) 1482 1483 if len(operands) < size: 1484 return functools.reduce( 1485 lambda a, b: 
expression.__class__(this=a, expression=b), operands 1486 ) 1487 return expression 1488 1489 1490def gen(expression: exp.Expr, comments: bool = False) -> str: 1491 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1492 1493 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1494 generator is expensive so we have a bare minimum sql generator here. 1495 1496 Args: 1497 expression: the expression to convert into a SQL string. 1498 comments: whether to include the expression's comments. 1499 """ 1500 return Gen().gen(expression, comments=comments) 1501 1502 1503class Gen: 1504 def __init__(self): 1505 self.stack = [] 1506 self.sqls = [] 1507 1508 def gen(self, expression: exp.Expr, comments: bool = False) -> str: 1509 self.stack = [expression] 1510 self.sqls.clear() 1511 1512 while self.stack: 1513 node = self.stack.pop() 1514 1515 if isinstance(node, exp.Expr): 1516 if comments and node.comments: 1517 self.stack.append(f" /*{','.join(node.comments)}*/") 1518 1519 exp_handler_name = f"{node.key}_sql" 1520 1521 if hasattr(self, exp_handler_name): 1522 getattr(self, exp_handler_name)(node) 1523 elif isinstance(node, exp.Func): 1524 self._function(node) 1525 else: 1526 key = node.key.upper() 1527 self.stack.append(f"{key} " if self._args(node) else key) 1528 elif type(node) is list: 1529 for n in reversed(node): 1530 if n is not None: 1531 self.stack.extend((n, ",")) 1532 if node: 1533 self.stack.pop() 1534 else: 1535 if node is not None: 1536 self.sqls.append(str(node)) 1537 1538 return "".join(self.sqls) 1539 1540 def add_sql(self, e: exp.Add) -> None: 1541 self._binary(e, " + ") 1542 1543 def alias_sql(self, e: exp.Alias) -> None: 1544 self.stack.extend( 1545 ( 1546 e.args.get("alias"), 1547 " AS ", 1548 e.args.get("this"), 1549 ) 1550 ) 1551 1552 def and_sql(self, e: exp.And) -> None: 1553 self._binary(e, " AND ") 1554 1555 def anonymous_sql(self, e: exp.Anonymous) -> None: 1556 this = e.this 1557 if 
isinstance(this, str): 1558 name = this.upper() 1559 elif isinstance(this, exp.Identifier): 1560 name = this.this 1561 name = f'"{name}"' if this.quoted else name.upper() 1562 else: 1563 raise ValueError( 1564 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1565 ) 1566 1567 self.stack.extend( 1568 ( 1569 ")", 1570 e.expressions, 1571 "(", 1572 name, 1573 ) 1574 ) 1575 1576 def between_sql(self, e: exp.Between) -> None: 1577 self.stack.extend( 1578 ( 1579 e.args.get("high"), 1580 " AND ", 1581 e.args.get("low"), 1582 " BETWEEN ", 1583 e.this, 1584 ) 1585 ) 1586 1587 def boolean_sql(self, e: exp.Boolean) -> None: 1588 self.stack.append("TRUE" if e.this else "FALSE") 1589 1590 def bracket_sql(self, e: exp.Bracket) -> None: 1591 self.stack.extend( 1592 ( 1593 "]", 1594 e.expressions, 1595 "[", 1596 e.this, 1597 ) 1598 ) 1599 1600 def column_sql(self, e: exp.Column) -> None: 1601 for p in reversed(e.parts): 1602 self.stack.extend((p, ".")) 1603 self.stack.pop() 1604 1605 def datatype_sql(self, e: exp.DataType) -> None: 1606 self._args(e, 1) 1607 self.stack.append(f"{e.this.name} ") 1608 1609 def div_sql(self, e: exp.Div) -> None: 1610 self._binary(e, " / ") 1611 1612 def dot_sql(self, e: exp.Dot) -> None: 1613 self._binary(e, ".") 1614 1615 def eq_sql(self, e: exp.EQ) -> None: 1616 self._binary(e, " = ") 1617 1618 def from_sql(self, e: exp.From) -> None: 1619 self.stack.extend((e.this, "FROM ")) 1620 1621 def gt_sql(self, e: exp.GT) -> None: 1622 self._binary(e, " > ") 1623 1624 def gte_sql(self, e: exp.GTE) -> None: 1625 self._binary(e, " >= ") 1626 1627 def identifier_sql(self, e: exp.Identifier) -> None: 1628 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1629 1630 def ilike_sql(self, e: exp.ILike) -> None: 1631 self._binary(e, " ILIKE ") 1632 1633 def in_sql(self, e: exp.In) -> None: 1634 self.stack.append(")") 1635 self._args(e, 1) 1636 self.stack.extend( 1637 ( 1638 "(", 1639 " IN ", 1640 e.this, 1641 ) 1642 ) 1643 1644 
def intdiv_sql(self, e: exp.IntDiv) -> None: 1645 self._binary(e, " DIV ") 1646 1647 def is_sql(self, e: exp.Is) -> None: 1648 self._binary(e, " IS ") 1649 1650 def like_sql(self, e: exp.Like) -> None: 1651 self._binary(e, " Like ") 1652 1653 def literal_sql(self, e: exp.Literal) -> None: 1654 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1655 1656 def lt_sql(self, e: exp.LT) -> None: 1657 self._binary(e, " < ") 1658 1659 def lte_sql(self, e: exp.LTE) -> None: 1660 self._binary(e, " <= ") 1661 1662 def mod_sql(self, e: exp.Mod) -> None: 1663 self._binary(e, " % ") 1664 1665 def mul_sql(self, e: exp.Mul) -> None: 1666 self._binary(e, " * ") 1667 1668 def neg_sql(self, e: exp.Neg) -> None: 1669 self._unary(e, "-") 1670 1671 def neq_sql(self, e: exp.NEQ) -> None: 1672 self._binary(e, " <> ") 1673 1674 def not_sql(self, e: exp.Not) -> None: 1675 self._unary(e, "NOT ") 1676 1677 def null_sql(self, e: exp.Null) -> None: 1678 self.stack.append("NULL") 1679 1680 def or_sql(self, e: exp.Or) -> None: 1681 self._binary(e, " OR ") 1682 1683 def paren_sql(self, e: exp.Paren) -> None: 1684 self.stack.extend( 1685 ( 1686 ")", 1687 e.this, 1688 "(", 1689 ) 1690 ) 1691 1692 def sub_sql(self, e: exp.Sub) -> None: 1693 self._binary(e, " - ") 1694 1695 def subquery_sql(self, e: exp.Subquery) -> None: 1696 self._args(e, 2) 1697 alias = e.args.get("alias") 1698 if alias: 1699 self.stack.append(alias) 1700 self.stack.extend((")", e.this, "(")) 1701 1702 def table_sql(self, e: exp.Table) -> None: 1703 self._args(e, 4) 1704 alias = e.args.get("alias") 1705 if alias: 1706 self.stack.append(alias) 1707 for p in reversed(e.parts): 1708 self.stack.extend((p, ".")) 1709 self.stack.pop() 1710 1711 def tablealias_sql(self, e: exp.TableAlias) -> None: 1712 columns = e.columns 1713 1714 if columns: 1715 self.stack.extend((")", columns, "(")) 1716 1717 self.stack.extend((e.this, " AS ")) 1718 1719 def var_sql(self, e: exp.Var) -> None: 1720 self.stack.append(e.this) 1721 1722 def 
_binary(self, e: exp.Binary, op: str) -> None: 1723 self.stack.extend((e.expression, op, e.this)) 1724 1725 def _unary(self, e: exp.Unary, op: str) -> None: 1726 self.stack.extend((e.this, op)) 1727 1728 def _function(self, e: exp.Func) -> None: 1729 self.stack.extend( 1730 ( 1731 ")", 1732 list(e.args.values()), 1733 "(", 1734 e.sql_name(), 1735 ) 1736 ) 1737 1738 def _args(self, node: exp.Expr, arg_index: int = 0) -> bool: 1739 kvs = [] 1740 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1741 1742 for k in arg_types: 1743 v = node.args.get(k) 1744 1745 if v is not None: 1746 kvs.append([f":{k}", v]) 1747 if kvs: 1748 self.stack.append(kvs) 1749 return True 1750 return False
def simplify(
    expression: exp.Expr,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Simplify a sqlglot AST by rewriting its expressions.

    This is a convenience wrapper: it builds a `Simplifier` for `dialect` and
    delegates to its `simplify` method.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        coalesce_simplification: whether the simplify coalesce rule should be used.
            This rule tries to remove coalesce functions, which can be useful in certain
            analyses but can leave the query more verbose.
        dialect: forwarded unchanged to the `Simplifier` constructor.

    Returns:
        sqlglot.Expr: simplified expression
    """
    simplifier = Simplifier(dialect=dialect)
    return simplifier.simplify(
        expression,
        constant_propagation=constant_propagation,
        coalesce_simplification=coalesce_simplification,
    )
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:
sqlglot.Expr: simplified expression
Common base class for all non-exit exceptions.
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated callable is invoked as usual. If it raises one of `exceptions`, the
    error is swallowed and the wrapper returns its first positional argument unchanged.
    Any other exception propagates.

    Args:
        *exceptions: exception classes that should be swallowed.

    Returns:
        A decorator that wraps a function taking the expression as first positional argument.
    """

    def decorator(func):
        # Keep the wrapped function's metadata (__name__, __doc__, __wrapped__) so that
        # introspection and decorators stacked on top of this one see the real function.
        @wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # NOTE(review): the fallback is always the first positional argument; when
                # this decorates a method that argument is `self` - confirm this is intended.
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions are raised
def annotate_types_on_change(func):
    """Method decorator that re-annotates types when a simplification changed the expression.

    The wrapped method is called first. If it returns `None`, that is passed through.
    Otherwise, when `self.annotate_new_expressions` is set and the result differs from the
    input expression, the result is annotated with `self._annotator` and then given the
    input expression's type.
    """

    @wraps(func)
    def _func(self, expression: exp.Expr, *args, **kwargs) -> exp.Expr | None:
        result = func(self, expression, *args, **kwargs)

        if result is not None and self.annotate_new_expressions and expression != result:
            annotator = self._annotator
            annotator.clear()

            # Annotate the result so that any newly created child nodes are typed as well
            result = annotator.annotate(expression=result, annotate_scope=False)

            # The replacement must keep the type of the expression it replaces,
            # otherwise the simplification could result in a different schema
            result.type = expression.type

        return result

    return _func
def flatten(expression: exp.Expr) -> exp.Expr:
    """
    Remove wrappers around operands that are the same kind of connector as their parent:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Anything that is not an `exp.Connector` is returned untouched. The rewrite happens
    in place and the same `expression` object is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_parens(expression: exp.Expr, dialect: DialectType) -> exp.Expr:
    """
    Drop a pair of parentheses when the checks below consider it unnecessary.

    Returns the wrapped expression when the parentheses are removed, otherwise the
    original `expression` (including when it is not an `exp.Paren` at all).
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent
    outer_is_predicate = isinstance(outer, exp.Predicate)

    # Parenthesized SELECTs, and parens directly under a subquery predicate or bracket, are kept
    if isinstance(inner, exp.Select) or isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    # Dialects that require parenthesized struct access keep `(x).field` / `(x).*`
    if (
        Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
        and isinstance(outer, exp.Dot)
        and isinstance(outer.right, (exp.Identifier, exp.Star))
    ):
        return expression

    # A predicate can be unwrapped unless its parent is a predicate, a negation
    # or a binary operator other than a connector
    if isinstance(inner, exp.Predicate):
        outer_is_plain_binary = isinstance(outer, exp.Binary) and not isinstance(
            outer, exp.Connector
        )
        if (
            not outer_is_predicate
            and not isinstance(outer, exp.Neg)
            and not outer_is_plain_binary
        ):
            return inner

    if not isinstance(outer, (exp.Condition, exp.Binary)) or isinstance(outer, exp.Paren):
        return inner

    if not isinstance(inner, exp.Binary) and not (
        isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate
    ):
        return inner

    # (a + b) + c
    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner

    # (a * b) * c, (a * b) + c, (a * b) - c
    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only applies to an `exp.And` that is either the root or does not share its parent's
    type, and that is already normalized to DNF. The expression is modified in place
    and returned.

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # Maps a column to (identity of the column node in `column = literal`, the literal)
    constant_mapping = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        left, right = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(left, exp.Column) and isinstance(right, exp.Literal):
            constant_mapping[left] = (id(left), right)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        column_id, constant = constant_mapping.get(column) or (None, None)

        # Skip unknown columns and the very node that defines the constant
        if column_id is None or id(column) == column_id:
            continue

        # Leave `column IS NULL` alone
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
class SupportsComparison(t.Protocol):
    """Protocol for expressions or values that can be compared using <, <=, >, >=.

    This is a structural type: any object that defines the four ordering methods below
    matches it. It is used to annotate the operands accepted by `eval_boolean`.
    """

    # The bodies are intentionally empty; only the method signatures matter for typing.
    def __lt__(self, other: t.Any) -> bool: ...
    def __le__(self, other: t.Any) -> bool: ...
    def __gt__(self, other: t.Any) -> bool: ...
    def __ge__(self, other: t.Any) -> bool: ...
Protocol for expressions or values that can be compared using <, <=, >, >=.
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder `__init__` that swaps itself for the first real `__init__` in the MRO.

    Raises `TypeError` when the class is flagged as a protocol. Otherwise, if this
    function is still the class' `__init__`, it installs the first `__init__` found in
    the MRO that is not this placeholder and calls it with the given arguments.
    """
    klass = type(self)

    if klass._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # A custom `__init__` is already in place, so there is nothing to resolve.
    # Resolving anyway can lead to RecursionError. See bpo-45121.
    if klass.__init__ is not _no_init_or_replace_init:
        return

    # The first instantiation of a subclass lands here. Look through the MRO for an
    # `__init__` that is not this placeholder and install it on the class, so that
    # subsequent instantiations call it directly and never come back here.
    placeholder = _no_init_or_replace_init
    candidates = (base.__dict__.get('__init__', placeholder) for base in klass.__mro__)

    # The `object.__init__` fallback should not happen
    klass.__init__ = next(
        (init for init in candidates if init is not placeholder),
        object.__init__,
    )

    klass.__init__(self, *args, **kwargs)
def eval_boolean(
    expression: object, a: SupportsComparison, b: SupportsComparison
) -> exp.Boolean | None:
    """Evaluate a comparison node against two concrete operands.

    Args:
        expression: the node whose class selects the comparison operator.
        a: left operand.
        b: right operand.

    Returns:
        The result of `boolean_literal` applied to the comparison outcome, or `None`
        when `expression` is not one of the supported comparison nodes.
    """
    # Checked in order; `IS` is evaluated like `=`
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_datetime(
    value: datetime | date | str,
) -> datetime | None:
    """Convert `value` to a `datetime`.

    Args:
        value: a `datetime` (returned as is), a `date` (converted to a `datetime` with
            the same year, month and day), or an ISO-format string.

    Returns:
        The resulting `datetime`, or `None` when the string cannot be parsed.
    """
    if isinstance(value, date):
        # `datetime` is a subclass of `date`, so tell the two apart here
        if isinstance(value, datetime):
            return value
        return datetime(year=value.year, month=value.month, day=value.day)

    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
def extract_date(cast: exp.Expr) -> date | None:
    """Extract a date value from a cast-like expression wrapping a literal.

    Supported inputs are `exp.Cast` nodes and `exp.TsOrDsToDate` nodes without a
    `format` argument. The wrapped value must be a literal or another supported
    cast, which is resolved recursively.

    Args:
        cast: the expression to inspect.

    Returns:
        Whatever `cast_value` produces for the extracted value and target type, or
        `None` when `cast` does not have a supported shape.

    NOTE(review): the return annotation used to read `date | date | None`. The repeated
    member is dropped here; if `date | datetime | None` was intended, `date` already
    covers it because `datetime` is a subclass of `date`.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # No explicit format: treated as a cast to DATE
        to = exp.DType.DATE.into_expr()
    else:
        return None

    inner = cast.this
    if isinstance(inner, exp.Literal):
        value: t.Any = inner.name
    elif isinstance(inner, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: resolve the innermost value first
        value = extract_date(inner)
    else:
        return None

    return cast_value(value, to)
def interval(unit: str, n: int = 1) -> relativedelta:
    """Build a `relativedelta` spanning `n` of the given `unit`.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour", "minute", "second".
        n: how many units the delta should span.

    Returns:
        The corresponding `relativedelta`; a quarter is expressed as three months.

    Raises:
        UnsupportedUnit: if `unit` is not one of the names listed above.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, per_unit = spans[unit]
    return relativedelta(**{keyword: per_unit * n})
def date_floor(d: date, unit: str, dialect: Dialect) -> date:
    """Truncate `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to truncate. Only year, month and day are changed.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only read for "week", where `dialect.WEEK_OFFSET` shifts the first day
            of the week relative to Monday.

    Returns:
        The first day of the enclosing period; `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the names listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday. The distance to the start of the
        # week is taken modulo 7 so that it stays within 0..6 for any offset. Without the
        # wrap-around, a non-zero offset could jump a whole week back or move forward in
        # time, e.g. a Sunday with an offset of -1.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
class Simplifier:
    """
    Rule-based simplifier for sqlglot expression trees.

    An instance is bound to a dialect and owns a `TypeAnnotator` which is used by the
    `annotate_types_on_change` decorator to (re)annotate nodes produced by the rules.
    The entry point is `simplify`; the remaining public methods are the individual
    rewrite rules, each of which returns either the input node or its replacement.
    """

    def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
        """
        Args:
            dialect: the dialect used to resolve dialect-specific behavior.
            annotate_new_expressions: whether expressions produced by a rule should be
                type-annotated (read by `annotate_types_on_change`).
        """
        self.dialect = Dialect.get_or_raise(dialect)
        self.annotate_new_expressions = annotate_new_expressions

        self._annotator: TypeAnnotator = TypeAnnotator(
            schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False
        )

    # Value ranges for byte-sized signed/unsigned integers
    TINYINT_MIN: t.ClassVar = -128
    TINYINT_MAX: t.ClassVar = 127
    UTINYINT_MIN: t.ClassVar = 0
    UTINYINT_MAX: t.ClassVar = 255

    # Comparison -> its logical negation, used to push NOT into a comparison
    COMPLEMENT_COMPARISONS: t.ClassVar = {
        exp.LT: exp.GTE,
        exp.GT: exp.LTE,
        exp.LTE: exp.GT,
        exp.GTE: exp.LT,
        exp.EQ: exp.NEQ,
        exp.NEQ: exp.EQ,
    }

    # Subquery predicate swapped in when a comparison against it is negated
    COMPLEMENT_SUBQUERY_PREDICATES: t.ClassVar = {
        exp.All: exp.Any,
        exp.Any: exp.All,
    }

    LT_LTE: t.ClassVar = (exp.LT, exp.LTE)
    GT_GTE: t.ClassVar = (exp.GT, exp.GTE)

    COMPARISONS: t.ClassVar = (
        *LT_LTE,
        *GT_GTE,
        exp.EQ,
        exp.NEQ,
        exp.Is,
    )

    # Comparison -> the comparison obtained by swapping its operands
    INVERSE_COMPARISONS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
        exp.LT: exp.GT,
        exp.GT: exp.LT,
        exp.LTE: exp.GTE,
        exp.GTE: exp.LTE,
    }

    NONDETERMINISTIC: t.ClassVar = (exp.Rand, exp.Randn)
    AND_OR: t.ClassVar = (exp.And, exp.Or)

    INVERSE_DATE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
        exp.DateAdd: exp.Sub,
        exp.DateSub: exp.Add,
        exp.DatetimeAdd: exp.Sub,
        exp.DatetimeSub: exp.Add,
    }

    INVERSE_OPS: t.ClassVar[dict[type[exp.Expr], type[exp.Expr]]] = {
        **INVERSE_DATE_OPS,
        exp.Add: exp.Sub,
        exp.Sub: exp.Add,
    }

    NULL_OK: t.ClassVar = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)

    CONCATS: t.ClassVar = (exp.Concat, exp.DPipe)

    # Comparison type -> callable(truncated operand, date, unit, dialect, target type)
    # that builds the equivalent predicate without the DATE_TRUNC call
    DATETRUNC_BINARY_COMPARISONS: t.ClassVar[dict[type[exp.Expr], DateTruncBinaryTransform]] = {
        exp.LT: lambda l, dt, u, d, t: (
            l
            < date_literal(
                dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t
            )
        ),
        exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
        exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
        exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
        exp.EQ: _datetrunc_eq,
        exp.NEQ: _datetrunc_neq,
    }

    DATETRUNC_COMPARISONS: t.ClassVar = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
    DATETRUNCS: t.ClassVar = (exp.DateTrunc, exp.TimestampTrunc)

    SAFE_CONNECTOR_ELIMINATION_RESULT: t.ClassVar = (exp.Connector, exp.Boolean)

    # CROSS joins result in an empty table if the right table is empty.
    # So we can only simplify certain types of joins to CROSS.
    # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
    JOINS: t.ClassVar = {
        ("", ""),
        ("", "INNER"),
        ("RIGHT", ""),
        ("RIGHT", "OUTER"),
    }

    def simplify(
        self,
        expression: exp.Expr,
        constant_propagation: bool = False,
        coalesce_simplification: bool = False,
    ) -> exp.Expr:
        """
        Simplify every condition reachable from `expression`.

        Args:
            expression: the expression to simplify; it is mutated in place.
            constant_propagation: whether the constant propagation rule should be used.
            coalesce_simplification: whether the simplify coalesce rule should be used.

        Returns:
            The simplified expression. This is a different object than the input only
            when the input itself is a condition that was replaced.
        """
        wheres: list[exp.Where] = []
        joins: list[exp.Join] = []

        for node in expression.walk(
            prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
        ):
            if node.meta.get(FINAL):
                continue

            # group by expressions cannot be simplified, for example
            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
            # the projection must exactly match the group by key
            group = node.args.get("group")

            if group and hasattr(node, "selects"):
                groups = set(group.expressions)
                group.meta[FINAL] = True

                for s in node.selects:
                    for n in s.walk():
                        if n in groups:
                            s.meta[FINAL] = True
                            break

                having = node.args.get("having")

                if having:
                    for n in having.walk():
                        if n in groups:
                            having.meta[FINAL] = True
                            break

            if isinstance(node, exp.Condition):
                simplified = while_changing(
                    node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
                )

                if node is expression:
                    expression = simplified
            elif isinstance(node, exp.Where):
                wheres.append(node)
            elif isinstance(node, exp.Join):
                # snowflake match_conditions have very strict ordering rules
                if match := node.args.get("match_condition"):
                    match.meta[FINAL] = True

                joins.append(node)

        # Clean up clauses whose condition became trivially true
        for where in wheres:
            if always_true(where.this):
                where.pop()
        for join in joins:
            if (
                always_true(join.args.get("on"))
                and not join.args.get("using")
                and not join.args.get("method")
                and (join.side, join.kind) in self.JOINS
            ):
                join.args["on"].pop()
                join.set("side", None)
                join.set("kind", "CROSS")

        return expression

    def _simplify(
        self, expression: exp.Expr, constant_propagation: bool, coalesce_simplification: bool
    ):
        """
        Run one simplification pass over `expression` without recursion.

        A first stack-driven traversal applies the pre-order rules and records each
        visited (node, parent) pair; the recorded pairs are then popped in reverse
        order to apply the post-order rules. Returns the last node processed by the
        post-order phase.
        """
        pre_transformation_stack = [expression]
        post_transformation_stack = []

        while pre_transformation_stack:
            original = pre_transformation_stack.pop()
            node = original

            if not isinstance(node, SIMPLIFIABLE):
                if isinstance(node, exp.Query):
                    self.simplify(node, constant_propagation, coalesce_simplification)
                continue

            parent = node.parent
            root = node is expression

            node = self.rewrite_between(node)
            node = self.uniq_sort(node, root)
            node = self.absorb_and_eliminate(node, root)
            node = self.simplify_concat(node)
            node = self.simplify_conditionals(node)

            if constant_propagation:
                node = propagate_constants(node, root)

            if node is not original:
                original.replace(node)

            # Children marked FINAL must not be simplified, so they are never scheduled.
            # (A bare `raise` used to sit here; outside of an `except` block it fails
            # with RuntimeError instead of skipping the child.)
            pre_transformation_stack.extend(
                n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
            )
            post_transformation_stack.append((node, parent))

        while post_transformation_stack:
            original, parent = post_transformation_stack.pop()
            root = original is expression

            # Resets parent, arg_key, index pointers - this is needed because some of the
            # previous transformations mutate the AST, leading to an inconsistent state
            for k, v in tuple(original.args.items()):
                original.set(k, v)

            # Post-order transformations
            node = self.simplify_not(original)
            node = flatten(node)
            node = self.simplify_connectors(node, root)
            node = self.remove_complements(node, root)

            if coalesce_simplification:
                node = self.simplify_coalesce(node)
            node.parent = parent

            node = self.simplify_literals(node, root)
            node = self.simplify_equality(node)
            node = simplify_parens(node, dialect=self.dialect)
            node = self.simplify_datetrunc(node)
            node = self.sort_comparison(node)
            node = self.simplify_startswith(node)

            if node is not original:
                original.replace(node)

        return node

    @annotate_types_on_change
    def rewrite_between(self, expression: exp.Expr) -> exp.Expr:
        """Rewrite x between y and z to x >= y AND x <= z.

        This is done because comparison simplification is only done on lt/lte/gt/gte.
        """
        if isinstance(expression, exp.Between):
            negate = isinstance(expression.parent, exp.Not)

            expression = exp.and_(
                exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
                exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
                copy=False,
            )

            # Keep the conjunction grouped when it sits directly under a NOT
            if negate:
                expression = exp.paren(expression, copy=False)

        return expression

    @annotate_types_on_change
    def simplify_not(self, expression: exp.Expr) -> exp.Expr:
        """
        De Morgan's laws
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

        Also negates comparisons and boolean constants, and removes a double negation
        of a BOOLEAN-typed operand when the dialect allows it.
        """
        if isinstance(expression, exp.Not):
            this = expression.this
            if is_null(this):
                return exp.and_(exp.null(), exp.true(), copy=False)
            if this.__class__ in self.COMPLEMENT_COMPARISONS:
                right = this.expression
                complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(
                    right.__class__
                )
                if complement_subquery_predicate:
                    right = complement_subquery_predicate(this=right.this)

                return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right)
            if isinstance(this, exp.Paren):
                condition = this.unnest()
                if isinstance(condition, exp.And):
                    return exp.paren(
                        exp.or_(
                            exp.not_(condition.left, copy=False),
                            exp.not_(condition.right, copy=False),
                            copy=False,
                        ),
                        copy=False,
                    )
                if isinstance(condition, exp.Or):
                    return exp.paren(
                        exp.and_(
                            exp.not_(condition.left, copy=False),
                            exp.not_(condition.right, copy=False),
                            copy=False,
                        ),
                        copy=False,
                    )
                if is_null(condition):
                    return exp.and_(exp.null(), exp.true(), copy=False)
            if always_true(this):
                return exp.false()
            if is_false(this):
                return exp.true()
            if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
                inner = this.this
                if inner.is_type(exp.DType.BOOLEAN):
                    # double negation
                    # NOT NOT x -> x, if x is BOOLEAN type
                    return inner
        return expression

    @annotate_types_on_change
    def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
        """Fold constant operands of AND / OR and merge comparisons on a shared operand."""

        def _simplify_connectors(expression: exp.Expr, left: exp.Expr, right: exp.Expr):
            if isinstance(expression, exp.And):
                if is_false(left) or is_false(right):
                    return exp.false()
                if is_zero(left) or is_zero(right):
                    return exp.false()
                if (
                    (is_null(left) and is_null(right))
                    or (is_null(left) and always_true(right))
                    or (always_true(left) and is_null(right))
                ):
                    return exp.null()
                if always_true(left) and always_true(right):
                    return exp.true()
                if always_true(left):
                    return right
                if always_true(right):
                    return left
                return self._simplify_comparison(expression, left, right)
            elif isinstance(expression, exp.Or):
                if always_true(left) or always_true(right):
                    return exp.true()
                if (
                    (is_null(left) and is_null(right))
                    or (is_null(left) and always_false(right))
                    or (always_false(left) and is_null(right))
                ):
                    return exp.null()
                if is_false(left):
                    return right
                if is_false(right):
                    return left
                return self._simplify_comparison(expression, left, right, or_=True)

        if isinstance(expression, exp.Connector):
            original_parent = expression.parent
            expression = self._flat_simplify(expression, _simplify_connectors, root)

            # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
            # to ensure that the resulting type is boolean. We know this is true only for connectors,
            # boolean values and columns that are essentially operands to a connector:
            #
            # A AND (((B)))
            #          ~ this is safe to keep because it will eventually be part of another connector
            if not isinstance(
                expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
            ) and not expression.is_type(exp.DType.BOOLEAN):
                while True:
                    if isinstance(original_parent, exp.Connector):
                        break
                    if not isinstance(original_parent, exp.Paren):
                        expression = expression.and_(exp.true(), copy=False)
                        break

                    original_parent = original_parent.parent

        return expression

    @annotate_types_on_change
    def _simplify_comparison(
        self, expression: exp.Expr, left: exp.Expr, right: exp.Expr, or_: bool = False
    ) -> exp.Expr | None:
        """
        Merge two comparisons that share a non-constant, deterministic operand.

        Returns the surviving comparison, FALSE when the pair is contradictory under
        AND, `expression` when no distinct operand can be extracted, or None when the
        pair cannot be reduced.
        """
        if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS):
            ll, lr = left.args.values()
            rl, rr = right.args.values()

            largs = {ll, lr}
            rargs = {rl, rr}

            matching = largs & rargs
            columns = {
                m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC)
            }

            if matching and columns:
                try:
                    l = first(largs - columns)
                    r = first(rargs - columns)
                except StopIteration:
                    return expression

                if l.is_number and r.is_number:
                    l = l.to_py()
                    r = r.to_py()
                elif l.is_string and r.is_string:
                    l = l.name
                    r = r.name
                else:
                    l = extract_date(l)
                    if not l:
                        return None
                    r = extract_date(r)
                    if not r:
                        return None
                    # python won't compare date and datetime, but many engines will upcast
                    l, r = cast_as_datetime(l), cast_as_datetime(r)

                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                    if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE):
                        return left if (av > bv if or_ else av <= bv) else right
                    if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE):
                        return left if (av < bv if or_ else av >= bv) else right

                    # we can't ever shortcut to true because the column could be null
                    if not or_:
                        if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE):
                            if av <= bv:
                                return exp.false()
                        elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE):
                            if av >= bv:
                                return exp.false()
                        elif isinstance(a, exp.EQ):
                            if isinstance(b, exp.LT):
                                return exp.false() if av >= bv else a
                            if isinstance(b, exp.LTE):
                                return exp.false() if av > bv else a
                            if isinstance(b, exp.GT):
                                return exp.false() if av <= bv else a
                            if isinstance(b, exp.GTE):
                                return exp.false() if av < bv else a
                            if isinstance(b, exp.NEQ):
                                return exp.false() if av == bv else a
        return None

    @annotate_types_on_change
    def remove_complements(self, expression: object, root: bool = True) -> object:
        """
        Removing complements.

        A AND NOT A -> FALSE (only for non-NULL A)
        A OR NOT A -> TRUE (only for non-NULL A)
        """
        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
            ops = set(expression.flatten())
            for op in ops:
                if isinstance(op, exp.Not) and op.this in ops:
                    # Only applied when the connector is explicitly flagged as non-null
                    if expression.meta.get("nonnull") is True:
                        return exp.false() if isinstance(expression, exp.And) else exp.true()

        return expression

    @annotate_types_on_change
    def uniq_sort(self, expression: object, root: bool = True) -> object:
        """
        Uniq and sort a connector.

        C AND A AND B AND B -> A AND B AND C
        """
        if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
            flattened = tuple(expression.flatten())

            if isinstance(expression, exp.Xor):
                result_func = exp.xor
                # Do not deduplicate XOR as A XOR A != A if A == True
                deduped = None
                arr = tuple((gen(e), e) for e in flattened)
            else:
                result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
                deduped = {gen(e): e for e in flattened}
                arr = tuple(deduped.items())

            # check if the operands are already sorted, if not sort them
            # A AND C AND B -> A AND B AND C
            for i, (sql, e) in enumerate(arr[1:]):
                if sql < arr[i][0]:
                    expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                    break
            else:
                # we didn't have to sort but maybe we need to dedup
                if deduped and len(deduped) < len(flattened):
                    unique_operand = flattened[0]
                    if len(deduped) == 1:
                        expression = unique_operand.and_(exp.true(), copy=False)
                    else:
                        expression = result_func(*deduped.values(), copy=False)

        return expression

    @annotate_types_on_change
    def absorb_and_eliminate(self, expression: object, root: bool = True) -> object:
        """
        absorption:
            A AND (A OR B) -> A
            A OR (A AND B) -> A
            A AND (NOT A OR B) -> A AND B
            A OR (NOT A AND B) -> A OR B
        elimination:
            (A AND B) OR (A AND NOT B) -> A
            (A OR B) AND (A OR NOT B) -> A
        """
        if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
            kind = exp.Or if isinstance(expression, exp.And) else exp.And

            ops = tuple(expression.flatten())

            # Initialize lookup tables:
            # Set of all operands, used to find complements for absorption.
            op_set = set()
            # Sub-operands, used to find subsets for absorption.
            subops = defaultdict(list)
            # Pairs of complements, used for elimination.
            pairs = defaultdict(list)

            # Populate the lookup tables
            for op in ops:
                op_set.add(op)

                if not isinstance(op, kind):
                    # In cases like: A OR (A AND B)
                    # Subop will be: ^
                    subops[op].append({op})
                    continue

                # In cases like: (A AND B) OR (A AND B AND C)
                # Subops will be: ^     ^
                subset = set(op.flatten())
                for i in subset:
                    subops[i].append(subset)

                a, b = op.unnest_operands()
                if isinstance(a, exp.Not):
                    pairs[frozenset((a.this, b))].append((op, b))
                if isinstance(b, exp.Not):
                    pairs[frozenset((a, b.this))].append((op, a))

            for op in ops:
                if not isinstance(op, kind):
                    continue

                a, b = op.unnest_operands()

                # Absorb
                if isinstance(a, exp.Not) and a.this in op_set:
                    a.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                if isinstance(b, exp.Not) and b.this in op_set:
                    b.replace(exp.true() if kind == exp.And else exp.false())
                    continue
                superset = set(op.flatten())
                if any(any(subset < superset for subset in subops[i]) for i in superset):
                    op.replace(exp.false() if kind == exp.And else exp.true())
                    continue

                # Eliminate
                for other, complement in pairs[frozenset((a, b))]:
                    op.replace(complement)
                    other.replace(complement)

        return expression

    @annotate_types_on_change
    @catch(ModuleNotFoundError, UnsupportedUnit)
    def simplify_equality(self, expression: exp.Expr) -> exp.Expr:
        """
        Use the subtraction and addition properties of equality to simplify expressions:

            x + 1 = 3 becomes x = 2

        There are two binary operations in the above expression: + and =
        Here's how we reference all the operands in the code below:

              l     r
            x + 1 = 3
            a   b
        """
        if isinstance(expression, self.COMPARISONS):
            l, r = expression.left, expression.right

            if l.__class__ not in self.INVERSE_OPS:
                return expression

            if r.is_number:
                a_predicate = _is_number
                b_predicate = _is_number
            elif _is_date_literal(r):
                a_predicate = _is_date_literal
                b_predicate = _is_interval
            else:
                return expression

            if l.__class__ in self.INVERSE_DATE_OPS:
                l = t.cast(exp.IntervalOp, l)
                a: exp.Expr = l.this
                b: exp.Expr = l.interval()
            else:
                l = t.cast(exp.Binary, l)
                a, b = l.left, l.right

            # Exactly one side of `l` must be the constant that gets moved to the right
            if not a_predicate(a) and b_predicate(b):
                pass
            elif not a_predicate(b) and b_predicate(a):
                a, b = b, a
            else:
                return expression

            return expression.__class__(
                this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b)
            )
        return expression

    def _is_inverse_date_op(self, expression: exp.Expr) -> TypeIs[exp.IntervalOp]:
        """Return whether `expression` is exactly one of the `INVERSE_DATE_OPS` types."""
        return type(expression) in self.INVERSE_DATE_OPS

    @annotate_types_on_change
    def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
        """Fold binary operations, double negation and date arithmetic over literals."""
        if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
            return self._flat_simplify(expression, self._simplify_binary, root)

        if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
            return expression.this.this

        if self._is_inverse_date_op(expression):
            return (
                self._simplify_binary(expression, expression.this, expression.interval())
                or expression
            )

        return expression

    def _simplify_integer_cast(self, expr: t.Any) -> t.Any:
        """Strip casts of byte-sized integer literals to integer types; else return `expr`."""
        if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
            this = self._simplify_integer_cast(expr.this)
        else:
            this = expr.this

        if isinstance(expr, exp.Cast) and this.is_int:
            num = this.to_py()

            # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
            # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
            # engine-dependent
            if (
                self.TINYINT_MIN <= num <= self.TINYINT_MAX
                and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
            ) or (
                self.UTINYINT_MIN <= num <= self.UTINYINT_MAX
                and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
            ):
                return this

        return expr

    def _simplify_binary(self, expression, a, b):
        """
        Evaluate the binary `expression` over operands `a` and `b` when both are
        statically known.

        Returns the folded expression, or None when nothing can be folded.
        """
        if isinstance(expression, self.COMPARISONS):
            a = self._simplify_integer_cast(a)
            b = self._simplify_integer_cast(b)

        if isinstance(expression, exp.Is):
            if isinstance(b, exp.Not):
                c = b.this
                not_ = True
            else:
                c = b
                not_ = False

            if is_null(c):
                if isinstance(a, exp.Literal):
                    return exp.true() if not_ else exp.false()
                if is_null(a):
                    return exp.false() if not_ else exp.true()
        elif isinstance(expression, self.NULL_OK):
            return None
        elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If):
            return exp.null()

        if a.is_number and b.is_number:
            num_a = a.to_py()
            num_b = b.to_py()

            if isinstance(expression, exp.Add):
                return exp.Literal.number(num_a + num_b)
            if isinstance(expression, exp.Mul):
                return exp.Literal.number(num_a * num_b)

            # We only simplify Sub, Div if a and b have the same parent because they're not associative
            if isinstance(expression, exp.Sub):
                return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
            if isinstance(expression, exp.Div):
                # engines have differing int div behavior so intdiv is not safe
                if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                    return None
                return exp.Literal.number(num_a / num_b)

            boolean = eval_boolean(expression, num_a, num_b)

            if boolean:
                return boolean
        elif a.is_string and b.is_string:
            boolean = eval_boolean(expression, a.this, b.this)

            if boolean:
                return boolean
        elif _is_date_literal(a) and isinstance(b, exp.Interval):
            date, b = extract_date(a), extract_interval(b)
            if date and b:
                if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                    return date_literal(date + b, extract_type(a))
                if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                    return date_literal(date - b, extract_type(a))
        elif isinstance(a, exp.Interval) and _is_date_literal(b):
            a, date = extract_interval(a), extract_date(b)
            # you cannot subtract a date from an interval
            # Both extracted values are checked (not the literal `b`, which is always
            # truthy) so a failed date extraction never reaches the addition below.
            if a and date and isinstance(expression, exp.Add):
                return date_literal(a + date, extract_type(b))
        elif _is_date_literal(a) and _is_date_literal(b):
            if isinstance(expression, exp.Predicate):
                a, b = extract_date(a), extract_date(b)
                boolean = eval_boolean(expression, a, b)
                if boolean:
                    return boolean

        return None

    @annotate_types_on_change
    def simplify_coalesce(self, expression: exp.Expr) -> exp.Expr:
        """
        Remove trivial COALESCE calls and expand comparisons between a COALESCE and a
        constant into explicit IS NULL / IS NOT NULL branches.
        """
        # COALESCE(x) -> x
        if (
            isinstance(expression, exp.Coalesce)
            and (not expression.expressions or _is_nonnull_constant(expression.this))
            # COALESCE is also used as a Spark partitioning hint
            and not isinstance(expression.parent, exp.Hint)
        ):
            return expression.this

        if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
            return expression

        if not isinstance(expression, self.COMPARISONS):
            return expression

        if isinstance(expression.left, exp.Coalesce):
            coalesce = expression.left
            other = expression.right
        elif isinstance(expression.right, exp.Coalesce):
            coalesce = expression.right
            other = expression.left
        else:
            return expression

        # This transformation is valid for non-constants,
        # but it really only does anything if they are both constants.
        if not _is_constant(other):
            return expression

        # Find the first constant arg
        for arg_index, arg in enumerate(coalesce.expressions):
            if _is_constant(arg):
                break
        else:
            return expression

        coalesce.set("expressions", coalesce.expressions[:arg_index])

        # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
        # since we already remove COALESCE at the top of this function.
        this: exp.Expr = coalesce if coalesce.expressions else coalesce.this

        # This expression is more complex than when we started, but it will get simplified further
        return exp.paren(
            exp.or_(
                exp.and_(
                    this.is_(exp.null()).not_(copy=False),
                    expression.copy(),
                    copy=False,
                ),
                exp.and_(
                    this.is_(exp.null()),
                    type(expression)(this=arg.copy(), expression=other.copy()),
                    copy=False,
                ),
                copy=False,
            ),
            copy=False,
        )

    @annotate_types_on_change
    def simplify_concat(self, expression):
        """Reduces all groups that contain string literals by concatenating them."""
        if not isinstance(expression, self.CONCATS) or (
            # We can't reduce a CONCAT_WS call if we don't statically know the separator
            isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
        ):
            return expression

        if isinstance(expression, exp.ConcatWs):
            sep_expr, *expressions = expression.expressions
            sep = sep_expr.name
            concat_type = exp.ConcatWs
            args = {}
        else:
            expressions = expression.expressions
            sep = ""
            concat_type = exp.Concat
            args = {
                "safe": expression.args.get("safe"),
                "coalesce": expression.args.get("coalesce"),
            }

        new_args = []
        # Adjacent string literals are merged into one literal; other operands are kept
        for is_string_group, group in itertools.groupby(
            expressions or expression.flatten(), lambda e: e.is_string
        ):
            if is_string_group:
                new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
            else:
                new_args.extend(group)

        if len(new_args) == 1 and new_args[0].is_string:
            return new_args[0]

        if concat_type is exp.ConcatWs:
            new_args = [sep_expr] + new_args
        elif isinstance(expression, exp.DPipe):
            return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

        return concat_type(expressions=new_args, **args)

    @annotate_types_on_change
    def simplify_conditionals(self, expression):
        """Simplifies expressions like IF, CASE if their condition is statically known."""
        if isinstance(expression, exp.Case):
            this = expression.this
            for case in expression.args["ifs"]:
                cond = case.this
                if this:
                    # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                    cond = cond.replace(this.pop().eq(cond))

                if always_true(cond):
                    return case.args["true"]

                if always_false(cond):
                    case.pop()
                    if not expression.args["ifs"]:
                        return expression.args.get("default") or exp.null()
        elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
            if always_true(expression.this):
                return expression.args["true"]
            if always_false(expression.this):
                return expression.args.get("false") or exp.null()

        return expression

    @annotate_types_on_change
    def simplify_startswith(self, expression: exp.Expr) -> exp.Expr:
        """
        Reduces a prefix check to either TRUE or FALSE if both the string and the
        prefix are statically known.

        Example:
            >>> from sqlglot import parse_one
            >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
            'TRUE'
        """
        if (
            isinstance(expression, exp.StartsWith)
            and expression.this.is_string
            and expression.expression.is_string
        ):
            return exp.convert(expression.name.startswith(expression.expression.name))

        return expression

    def _is_datetrunc_predicate(self, left: exp.Expr, right: exp.Expr) -> bool:
        """Return whether `left` is a date truncation and `right` is a date literal."""
        return isinstance(left, self.DATETRUNCS) and _is_date_literal(right)

    @annotate_types_on_change
    @catch(ModuleNotFoundError, UnsupportedUnit)
    def simplify_datetrunc(self, expression: exp.Expr) -> exp.Expr:
        """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
        comparison = expression.__class__

        if isinstance(expression, self.DATETRUNCS):
            # Truncation of a literal date can be evaluated directly
            this = expression.this
            trunc_type = extract_type(this)
            date = extract_date(this)
            if date and expression.unit:
                return date_literal(
                    date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type
                )
        elif comparison not in self.DATETRUNC_COMPARISONS:
            return expression

        if isinstance(expression, exp.Binary):
            l, r = expression.left, expression.right

            if not self._is_datetrunc_predicate(l, r) or not isinstance(
                l, (exp.DateTrunc, exp.TimestampTrunc)
            ):
                return expression

            trunc_arg = l.this
            unit = l.args["unit"].name.lower()
            date = extract_date(r)

            if not date:
                return expression

            return (
                self.DATETRUNC_BINARY_COMPARISONS[comparison](
                    trunc_arg, date, unit, self.dialect, extract_type(r)
                )
                or expression
            )

        if isinstance(expression, exp.In):
            l = expression.this
            rs = expression.expressions

            if (
                rs
                and all(self._is_datetrunc_predicate(l, r) for r in rs)
                and isinstance(l, (exp.DateTrunc, exp.TimestampTrunc))
            ):
                unit = l.args["unit"].name.lower()

                ranges = []
                for r in rs:
                    date = extract_date(r)
                    if not date:
                        return expression
                    drange = _datetrunc_range(date, unit, self.dialect)
                    if drange:
                        ranges.append(drange)

                if not ranges:
                    return expression

                ranges = merge_ranges(ranges)
                target_type = extract_type(*rs)

                return exp.or_(
                    *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges],
                    copy=False,
                )

        return expression

    @annotate_types_on_change
    def sort_comparison(self, expression: exp.Expr) -> exp.Expr:
        """
        Put comparison operands into a canonical order, flipping the operator as needed.

        A comparison is left as is when the column is already on the left, the constant
        is already on the right, or the right side is a subquery predicate.
        """
        if expression.__class__ in self.COMPLEMENT_COMPARISONS:
            l, r = expression.this, expression.expression
            l_column = isinstance(l, exp.Column)
            r_column = isinstance(r, exp.Column)
            l_const = _is_constant(l)
            r_const = _is_constant(r)

            if (
                (l_column and not r_column)
                or (r_const and not l_const)
                or isinstance(r, exp.SubqueryPredicate)
            ):
                return expression
            if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
                return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                    this=r, expression=l
                )
        return expression

    def _flat_simplify(
        self,
        expression: exp.Expr,
        simplifier: t.Callable[[exp.Expr, exp.Expr, exp.Expr], exp.Expr | None],
        root: bool = True,
    ) -> exp.Expr:
        """
        Apply the pairwise `simplifier` across the flattened operands of `expression`.

        Whenever a pair is reduced, the result replaces the pair and is retried against
        the remaining operands. The chain is rebuilt only if at least one pair was
        reduced; otherwise `expression` is returned unchanged.
        """
        if root or not expression.same_parent:
            operands = []
            queue = deque(expression.flatten(unnest=False))
            size = len(queue)

            while queue:
                a = queue.popleft()

                for b in queue:
                    result = simplifier(expression, a, b)

                    if result and result is not expression:
                        queue.remove(b)
                        queue.appendleft(result)
                        break
                else:
                    operands.append(a)

            if len(operands) < size:
                return functools.reduce(
                    lambda a, b: expression.__class__(this=a, expression=b), operands
                )
        return expression
def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
    """
    Bind the simplifier to a dialect and create its type annotator.

    Args:
        dialect: the dialect (or dialect name) resolved through `Dialect.get_or_raise`.
        annotate_new_expressions: stored as-is on the instance.
    """
    resolved_dialect = Dialect.get_or_raise(dialect)
    # The annotator runs against a schema built from no tables, in the same dialect
    empty_schema = ensure_schema(None, dialect=resolved_dialect)

    self.dialect = resolved_dialect
    self.annotate_new_expressions = annotate_new_expressions
    self._annotator: TypeAnnotator = TypeAnnotator(schema=empty_schema, overwrite_types=False)
def simplify(
    self,
    expression: exp.Expr,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
) -> exp.Expr:
    """
    Simplify every condition reachable from `expression`, mutating the tree in place.

    Returns the simplified expression, which differs from the input object only when
    the input itself is a condition that got replaced.
    """

    def _stop_descending(candidate) -> bool:
        # Conditions are handled as a whole by `_simplify`; FINAL nodes are off limits
        return bool(isinstance(candidate, exp.Condition) or candidate.meta.get(FINAL))

    def _run_rules(target):
        return self._simplify(target, constant_propagation, coalesce_simplification)

    pending_wheres: list[exp.Where] = []
    pending_joins: list[exp.Join] = []

    for node in expression.walk(prune=_stop_descending):
        if node.meta.get(FINAL):
            continue

        # A projection or HAVING clause that contains a GROUP BY key must stay textually
        # identical to that key, e.g. select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        group = node.args.get("group")
        if group and hasattr(node, "selects"):
            group_keys = set(group.expressions)
            group.meta[FINAL] = True

            for projection in node.selects:
                if any(part in group_keys for part in projection.walk()):
                    projection.meta[FINAL] = True

            having = node.args.get("having")
            if having and any(part in group_keys for part in having.walk()):
                having.meta[FINAL] = True

        if isinstance(node, exp.Condition):
            simplified = while_changing(node, _run_rules)
            if node is expression:
                expression = simplified
        elif isinstance(node, exp.Where):
            pending_wheres.append(node)
        elif isinstance(node, exp.Join):
            # snowflake match_conditions have very strict ordering rules
            match = node.args.get("match_condition")
            if match:
                match.meta[FINAL] = True

            pending_joins.append(node)

    # Drop WHERE clauses that became trivially true
    for where in pending_wheres:
        if always_true(where.this):
            where.pop()

    # Turn eligible joins whose ON condition became trivially true into CROSS joins
    for join in pending_joins:
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in self.JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")

    return expression
@annotate_types_on_change
def rewrite_between(self, expression: exp.Expr) -> exp.Expr:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded
    into that form first. Any other expression is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Must be read before the BETWEEN node is taken apart
    inside_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    # Directly under a NOT the conjunction has to stay grouped
    return exp.paren(rewritten, copy=False) if inside_not else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
@annotate_types_on_change
def simplify_not(self, expression: exp.Expr) -> exp.Expr:
    """
    Simplify a NOT expression.

    De Morgan's laws (applied to a parenthesized operand):
        NOT (x OR y)  -> (NOT x AND NOT y)
        NOT (x AND y) -> (NOT x OR NOT y)

    Other rewrites, tried in this order:
        - NOT NULL -> NULL AND TRUE
        - NOT <comparison> -> the complementary comparison from COMPLEMENT_COMPARISONS
        - NOT <always true> -> FALSE, NOT FALSE -> TRUE
        - NOT NOT x -> x, if x is BOOLEAN and the dialect allows it
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.and_(exp.null(), exp.true(), copy=False)

    operand_cls = operand.__class__
    if operand_cls in self.COMPLEMENT_COMPARISONS:
        rhs = operand.expression
        # The right-hand side may be a subquery predicate that has its own complement
        flip_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flip_subquery_predicate:
            rhs = flip_subquery_predicate(this=rhs.this)

        return self.COMPLEMENT_COMPARISONS[operand_cls](this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # The negation of a conjunction is a disjunction, and vice versa
            dual = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                ),
                copy=False,
            )

        if is_null(inner):
            return exp.and_(exp.null(), exp.true(), copy=False)

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
        doubly_negated = operand.this
        if doubly_negated.is_type(exp.DType.BOOLEAN):
            # double negation
            # NOT NOT x -> x, if x is BOOLEAN type
            return doubly_negated

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
@annotate_types_on_change
def simplify_connectors(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
    """
    Fold AND / OR connectors whose operands are statically known.

    Operand pairs are reduced through `self._flat_simplify`. If the connector collapses into
    something that is neither a known-safe result nor BOOLEAN-typed, it is wrapped as
    `<result> AND TRUE` unless it ends up (through parentheses only) under another connector.
    """

    def _fold_and(expression, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_zero(left) or is_zero(right):
            return exp.false()
        if (is_null(left) and (is_null(right) or always_true(right))) or (
            always_true(left) and is_null(right)
        ):
            return exp.null()
        if always_true(left):
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return self._simplify_comparison(expression, left, right)

    def _fold_or(expression, left, right):
        if always_true(left) or always_true(right):
            return exp.true()
        if (is_null(left) and (is_null(right) or always_false(right))) or (
            always_false(left) and is_null(right)
        ):
            return exp.null()
        if is_false(left):
            return right
        if is_false(right):
            return left
        return self._simplify_comparison(expression, left, right, or_=True)

    def _simplify_connectors(expression: exp.Expr, left: exp.Expr, right: exp.Expr):
        if isinstance(expression, exp.And):
            return _fold_and(expression, left, right)
        if isinstance(expression, exp.Or):
            return _fold_or(expression, left, right)
        # Any other connector is left for the caller to handle
        return None

    if not isinstance(expression, exp.Connector):
        return expression

    original_parent = expression.parent
    expression = self._flat_simplify(expression, _simplify_connectors, root)

    # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need
    # to ensure that the resulting type is boolean. We know this is true only for connectors,
    # boolean values and columns that are essentially operands to a connector:
    #
    # A AND (((B)))
    #          ~ this is safe to keep because it will eventually be part of another connector
    needs_boolean_guard = not isinstance(
        expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT
    ) and not expression.is_type(exp.DType.BOOLEAN)

    if needs_boolean_guard:
        # Look through any enclosing parentheses for the nearest meaningful ancestor
        ancestor = original_parent
        while isinstance(ancestor, exp.Paren):
            ancestor = ancestor.parent

        if not isinstance(ancestor, exp.Connector):
            expression = expression.and_(exp.true(), copy=False)

    return expression
@annotate_types_on_change
def remove_complements(self, expression: object, root: bool = True) -> object:
    """
    Removing complements.

    A AND NOT A -> FALSE (only for non-NULL A)
    A OR NOT A -> TRUE (only for non-NULL A)

    The rewrite only fires when the connector's meta carries `nonnull` set to True.
    """
    if not (isinstance(expression, self.AND_OR) and (root or not expression.same_parent)):
        return expression

    operands = set(expression.flatten())
    has_complement_pair = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    # `meta` is only consulted once a complementary pair has actually been found
    if has_complement_pair and expression.meta.get("nonnull") is True:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE (only for non-NULL A) A OR NOT A -> TRUE (only for non-NULL A)
@annotate_types_on_change
def uniq_sort(self, expression: object, root: bool = True) -> object:
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by the string `gen` produces for them. AND / OR operands are
    deduplicated on that string; XOR operands are only sorted, never deduplicated.
    """
    if not (isinstance(expression, exp.Connector) and (root or not expression.same_parent)):
        return expression

    flattened = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        result_func = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        deduped = None
        arr = tuple((gen(e), e) for e in flattened)
    else:
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        deduped = {gen(e): e for e in flattened}
        arr = tuple(deduped.items())

    # check if the operands are already sorted, if not sort them
    # A AND C AND B -> A AND B AND C
    sqls = [sql for sql, _ in arr]
    already_sorted = all(prev <= cur for prev, cur in zip(sqls, sqls[1:]))

    if not already_sorted:
        # Sort on the generated SQL only. Sorting the (sql, expression) tuples directly makes
        # ties on `sql` (possible on the XOR path, which keeps duplicates) fall back to
        # comparing the expression objects themselves, which is not a meaningful ordering.
        # A key-based sort is stable, so tied operands keep their original relative order.
        ordered = sorted(arr, key=lambda pair: pair[0])
        expression = result_func(*(e for _, e in ordered), copy=False)
    elif deduped and len(deduped) < len(flattened):
        # we didn't have to sort but maybe we need to dedup
        unique_operand = flattened[0]
        if len(deduped) == 1:
            # A AND A / A OR A -> A AND TRUE, which keeps the result a connector
            expression = unique_operand.and_(exp.true(), copy=False)
        else:
            expression = result_func(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
@annotate_types_on_change
def absorb_and_eliminate(self, expression: object, root: bool = True) -> object:
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place (replaced by TRUE / FALSE or by the surviving operand);
    the connector itself is returned.
    """
    if not (isinstance(expression, self.AND_OR) and (root or not expression.same_parent)):
        return expression

    # Only operands that are connectors of the opposite kind can be absorbed or eliminated
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And

    operands = tuple(expression.flatten())

    # Lookup tables:
    # every top-level operand, used to find complements for absorption
    seen_operands = set()
    # for each sub-operand, the operand sets that contain it, used to find subsets
    containing_sets = defaultdict(list)
    # pairs of complements, used for elimination
    complement_pairs = defaultdict(list)

    for operand in operands:
        seen_operands.add(operand)

        if not isinstance(operand, inner_kind):
            # In cases like: A OR (A AND B)
            # Subop will be: ^
            containing_sets[operand].append({operand})
            continue

        # In cases like: (A AND B) OR (A AND B AND C)
        # Subops will be: ^     ^
        members = set(operand.flatten())
        for member in members:
            containing_sets[member].append(members)

        lhs, rhs = operand.unnest_operands()
        if isinstance(lhs, exp.Not):
            complement_pairs[frozenset((lhs.this, rhs))].append((operand, rhs))
        if isinstance(rhs, exp.Not):
            complement_pairs[frozenset((lhs, rhs.this))].append((operand, lhs))

    for operand in operands:
        if not isinstance(operand, inner_kind):
            continue

        lhs, rhs = operand.unnest_operands()

        # Absorb: a negated sub-operand whose positive form is a top-level operand
        # becomes the neutral element of the inner connector
        if isinstance(lhs, exp.Not) and lhs.this in seen_operands:
            lhs.replace(exp.true() if inner_kind == exp.And else exp.false())
            continue
        if isinstance(rhs, exp.Not) and rhs.this in seen_operands:
            rhs.replace(exp.true() if inner_kind == exp.And else exp.false())
            continue

        # Absorb: the operand is a strict superset of another operand's members
        members = set(operand.flatten())
        is_absorbed = any(
            any(candidate < members for candidate in containing_sets[member])
            for member in members
        )
        if is_absorbed:
            operand.replace(exp.false() if inner_kind == exp.And else exp.true())
            continue

        # Eliminate
        for other, survivor in complement_pairs[frozenset((lhs, rhs))]:
            operand.replace(survivor)
            other.replace(survivor)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
@wraps(func)
def wrapped(expression, *args, **kwargs):
    """Call the wrapped function; if it raises one of `exceptions`, return the first argument unchanged."""
    # `wraps` keeps the wrapped function's name and docstring on the wrapper, in line with
    # how `annotate_types_on_change` builds its wrapper in this module.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:
l r
x + 1 = 3
a b
@annotate_types_on_change
def simplify_literals(self, expression: exp.Expr, root: bool = True) -> exp.Expr:
    """
    Fold operations on literals.

    - Non-connector binary operations are reduced pairwise through `self._flat_simplify`.
    - A doubly negated value (-(-x)) is replaced by x.
    - Inverse date operations are folded with their interval when possible.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return self._flat_simplify(expression, self._simplify_binary, root)

    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
        return expression.this.this

    if self._is_inverse_date_op(expression):
        folded = self._simplify_binary(expression, expression.this, expression.interval())
        # Keep the original node when nothing could be folded
        return folded or expression

    return expression
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expr) -> exp.Expr:
    """
    Remove COALESCE calls where possible.

    - COALESCE(x) -> x, and COALESCE(<non-null constant>, ...) -> that constant
      (except when the COALESCE is a hint argument).
    - A comparison between a COALESCE and a constant is split on the first constant
      argument `c` of the COALESCE:

          COALESCE(x, c) OP k  ->  (NOT x IS NULL AND x OP k) OR (x IS NULL AND c OP k)

      The operand order of the original comparison is preserved in both branches.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
        return expression

    if not isinstance(expression, self.COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    this: exp.Expr = coalesce if coalesce.expressions else coalesce.this

    # Comparison used when every argument before the constant is NULL. The constant must take
    # the COALESCE's position: for `<`, `<=`, `>`, `>=` swapping the operands changes the
    # result, so `k OP COALESCE(x, c)` has to become `k OP c`, not `c OP k`.
    if coalesce_on_left:
        null_branch = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        null_branch = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                this.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                this.is_(exp.null()),
                null_branch,
                copy=False,
            ),
            copy=False,
        ),
        copy=False,
    )
@annotate_types_on_change
def simplify_concat(self, expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged into one literal (joined with the separator for
    CONCAT_WS). If everything collapses into a single string literal, that literal is returned.
    """
    if not isinstance(expression, self.CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        separator = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    # Nodes without an `expressions` list are read through `flatten()` instead
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for is_string_run, run in runs:
        if is_string_run:
            merged.append(exp.Literal.string(separator.join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[sep_expr] + merged)

    if isinstance(expression, exp.DPipe):
        # Rebuild the left-deep chain of || operators
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return exp.Concat(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
@annotate_types_on_change
def simplify_conditionals(self, expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
        - `CASE x WHEN v ...` is first rewritten to `CASE WHEN x = v ...`.
        - Branches whose condition is always false are removed; if none remain, the
          default (or NULL) is returned.
        - A branch whose condition is always true is returned only when no earlier
          branch remains undecided, since CASE picks the first matching branch.
    IF (outside of a CASE):
        - Returns the true / false branch when the condition is statically known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a branch with a condition that is not statically known has been kept
        undecided_branch_seen = False

        # Iterate over a snapshot: `case.pop()` shrinks `expression.args["ifs"]` (the emptiness
        # check below relies on that), and looping over the live list while it shrinks would
        # silently skip the branch that follows every removed one.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided_branch_seen:
                    return case.args["true"]
                # An earlier undecided branch may match first at runtime, so this branch
                # cannot be promoted to the result of the whole CASE.
                continue

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
                continue

            undecided_branch_seen = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expr) -> exp.Expr:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    subject, prefix = expression.this, expression.expression

    if subject.is_string and prefix.is_string:
        # Both sides are string literals, so the check can be evaluated right away
        has_prefix = expression.name.startswith(prefix.name)
        return exp.convert(has_prefix)

    return expression
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
@wraps(func)
def wrapped(expression, *args, **kwargs):
    """Call the wrapped function; if it raises one of `exceptions`, return the first argument unchanged."""
    # `wraps` keeps the wrapped function's name and docstring on the wrapper, in line with
    # how `annotate_types_on_change` builds its wrapper in this module.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
@annotate_types_on_change
def sort_comparison(self, expression: exp.Expr) -> exp.Expr:
    """
    Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right. When neither rule decides,
    the operands are ordered by their `gen` string. Swapping the operands also swaps the
    operator for its entry in INVERSE_COMPARISONS, if it has one.
    """
    comparison_cls = expression.__class__
    if comparison_cls not in self.COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_canonical = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_canonical:
        return expression

    should_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if should_swap:
        swapped_cls = self.INVERSE_COMPARISONS.get(comparison_cls, comparison_cls)
        return swapped_cls(this=rhs, expression=lhs)

    return expression
def gen(expression: exp.Expr, comments: bool = False) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.

    Returns:
        The pseudo SQL string produced by a fresh `Gen` instance.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
Arguments:
- expression: the expression to convert into a SQL string.
- comments: whether to include the expression's comments.
class Gen:
    """
    Bare minimum pseudo SQL generator, used to build sortable / comparable strings.

    Generation is driven by an explicit stack instead of recursion. Handlers push the pieces
    of a node in *reverse* output order, because the stack is popped from the end: the last
    item pushed is the first one emitted.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, plain strings or None
        self.stack = []
        # Output fragments, in emission order
        self.sqls = []

    def gen(self, expression: exp.Expr, comments: bool = False) -> str:
        """
        Generate the pseudo SQL string for `expression`.

        Args:
            expression: the expression to convert into a string.
            comments: whether to append each node's comments as `/*...*/`.

        Returns:
            The concatenation of all emitted fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expr):
                if comments and node.comments:
                    # Pushed before the node's own pieces, so it is emitted after them
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the node's args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would otherwise precede the first item. This must
                # only happen when something was pushed: a non-empty list holding only None
                # values pushes nothing, and popping then would discard an unrelated pending
                # item (e.g. the closing ")" of a function call).
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emitted as: <this> AS <alias>
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit NAME(args); the name is upper-cased unless it is a quoted identifier.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Emitted as: <this> BETWEEN <low> AND <high>
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Emitted as: <this>[<expressions>]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dot-separated parts; the trailing pop removes the separator before the first part
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # The type name, followed by every arg after the first declared one
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emitted as: <this> IN (<every arg after the first declared one>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case operator, unlike " ILIKE "; kept as-is because the
        # generated strings are used for ordering and deduplication.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Emitted as: (<this>)<alias><every arg after the first two declared ones>
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Emitted as: <dotted parts><alias><every arg after the first four declared ones>
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Emitted as: " AS " <this>, then (<columns>) when there are any
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `<this><op><expression>` (in reverse, see the class docstring)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op><this>` (in reverse)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(<all arg values, comma separated>)` (in reverse)."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expr, arg_index: int = 0) -> bool:
        """
        Push the node's non-None args as `:key,value` pairs.

        Args:
            node: the expression whose args are emitted.
            arg_index: number of leading declared arg types to skip.

        Returns:
            True if at least one arg was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expr, comments: bool = False) -> str:
    """
    Generate the pseudo SQL string for `expression`.

    Work items are taken from `self.stack`: expressions are expanded by their handler,
    lists are expanded into comma-separated items, and anything else is emitted as text.

    Args:
        expression: the expression to convert into a string.
        comments: whether to append each node's comments as `/*...*/`.

    Returns:
        The concatenation of all emitted fragments.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expr):
            if comments and node.comments:
                # Pushed before the node's own pieces, so it is emitted after them
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                # Generic fallback: the upper-cased key followed by the node's args
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the separator that would otherwise precede the first item. This must
            # only happen when something was pushed: a non-empty list holding only None
            # values pushes nothing, and popping then would discard an unrelated pending
            # item (e.g. the closing ")" of a function call).
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Push the pieces of an anonymous function call: NAME(<expressions>).

    The name is upper-cased, unless it comes from a quoted identifier, in which case it is
    kept verbatim inside double quotes.

    Raises:
        ValueError: if `e.this` is neither a str nor an Identifier.
    """
    target = e.this

    if isinstance(target, str):
        func_name = target.upper()
    elif isinstance(target, exp.Identifier):
        raw_name = target.this
        func_name = f'"{raw_name}"' if target.quoted else raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )

    # Pushed in reverse output order: the stack is consumed from the end
    self.stack += [")", e.expressions, "(", func_name]