sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce, wraps 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.annotate_types import TypeAnnotator 15from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 16from sqlglot.schema import ensure_schema 17 18if t.TYPE_CHECKING: 19 from sqlglot.dialects.dialect import DialectType 20 21 DateRange = t.Tuple[datetime.date, datetime.date] 22 DateTruncBinaryTransform = t.Callable[ 23 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 24 ] 25 26 27logger = logging.getLogger("sqlglot") 28 29 30# Final means that an expression should not be simplified 31FINAL = "final" 32 33SIMPLIFIABLE = ( 34 exp.Binary, 35 exp.Func, 36 exp.Lambda, 37 exp.Predicate, 38 exp.Unary, 39) 40 41 42def simplify( 43 expression: exp.Expression, 44 constant_propagation: bool = False, 45 coalesce_simplification: bool = False, 46 dialect: DialectType = None, 47): 48 """ 49 Rewrite sqlglot AST to simplify expressions. 50 51 Example: 52 >>> import sqlglot 53 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 54 >>> simplify(expression).sql() 55 'TRUE' 56 57 Args: 58 expression: expression to simplify 59 constant_propagation: whether the constant propagation rule should be used 60 coalesce_simplification: whether the simplify coalesce rule should be used. 61 This rule tries to remove coalesce functions, which can be useful in certain analyses but 62 can leave the query more verbose. 
63 Returns: 64 sqlglot.Expression: simplified expression 65 """ 66 return Simplifier(dialect=dialect).simplify( 67 expression, 68 constant_propagation=constant_propagation, 69 coalesce_simplification=coalesce_simplification, 70 ) 71 72 73class UnsupportedUnit(Exception): 74 pass 75 76 77def catch(*exceptions): 78 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 79 80 def decorator(func): 81 def wrapped(expression, *args, **kwargs): 82 try: 83 return func(expression, *args, **kwargs) 84 except exceptions: 85 return expression 86 87 return wrapped 88 89 return decorator 90 91 92def annotate_types_on_change(func): 93 @wraps(func) 94 def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]: 95 new_expression = func(self, expression, *args, **kwargs) 96 97 if new_expression is None: 98 return new_expression 99 100 if self.annotate_new_expressions and expression != new_expression: 101 self._annotator.clear() 102 103 # We annotate this to ensure new children nodes are also annotated 104 new_expression = self._annotator.annotate( 105 expression=new_expression, 106 annotate_scope=False, 107 ) 108 109 # Whatever expression the original expression is transformed into needs to preserve 110 # the original type, otherwise the simplification could result in a different schema 111 new_expression.type = expression.type 112 113 return new_expression 114 115 return _func 116 117 118def flatten(expression): 119 """ 120 A AND (B AND C) -> A AND B AND C 121 A OR (B OR C) -> A OR B OR C 122 """ 123 if isinstance(expression, exp.Connector): 124 for node in expression.args.values(): 125 child = node.unnest() 126 if isinstance(child, expression.__class__): 127 node.replace(child) 128 return expression 129 130 131def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression: 132 if not isinstance(expression, exp.Paren): 133 return expression 134 135 this = expression.this 136 parent = 
expression.parent 137 parent_is_predicate = isinstance(parent, exp.Predicate) 138 139 if isinstance(this, exp.Select): 140 return expression 141 142 if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)): 143 return expression 144 145 if ( 146 Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS 147 and isinstance(parent, exp.Dot) 148 and (isinstance(parent.right, (exp.Identifier, exp.Star))) 149 ): 150 return expression 151 152 if ( 153 not isinstance(parent, (exp.Condition, exp.Binary)) 154 or isinstance(parent, exp.Paren) 155 or ( 156 not isinstance(this, exp.Binary) 157 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 158 ) 159 or ( 160 isinstance(this, exp.Predicate) 161 and not (parent_is_predicate or isinstance(parent, exp.Neg)) 162 ) 163 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 164 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 165 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 166 ): 167 return this 168 169 return expression 170 171 172def propagate_constants(expression, root=True): 173 """ 174 Propagate constants for conjunctions in DNF: 175 176 SELECT * FROM t WHERE a = b AND b = 5 becomes 177 SELECT * FROM t WHERE a = 5 AND b = 5 178 179 Reference: https://www.sqlite.org/optoverview.html 180 """ 181 182 if ( 183 isinstance(expression, exp.And) 184 and (root or not expression.same_parent) 185 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 186 ): 187 constant_mapping = {} 188 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 189 if isinstance(expr, exp.EQ): 190 l, r = expr.left, expr.right 191 192 # TODO: create a helper that can be used to detect nested literal expressions such 193 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 194 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 195 constant_mapping[l] = (id(l), r) 196 197 if constant_mapping: 198 for 
column in find_all_in_scope(expression, exp.Column): 199 parent = column.parent 200 column_id, constant = constant_mapping.get(column) or (None, None) 201 if ( 202 column_id is not None 203 and id(column) != column_id 204 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 205 ): 206 column.replace(constant.copy()) 207 208 return expression 209 210 211def _is_number(expression: exp.Expression) -> bool: 212 return expression.is_number 213 214 215def _is_interval(expression: exp.Expression) -> bool: 216 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 217 218 219def _is_nonnull_constant(expression: exp.Expression) -> bool: 220 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 221 222 223def _is_constant(expression: exp.Expression) -> bool: 224 return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 225 226 227def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 228 """ 229 Get the date range for a DATE_TRUNC equality comparison: 230 231 Example: 232 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 233 Returns: 234 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 235 """ 236 floor = date_floor(date, unit, dialect) 237 238 if date != floor: 239 # This will always be False, except for NULL values. 
240 return None 241 242 return floor, floor + interval(unit) 243 244 245def _datetrunc_eq_expression( 246 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 247) -> exp.Expression: 248 """Get the logical expression for a date range""" 249 return exp.and_( 250 left >= date_literal(drange[0], target_type), 251 left < date_literal(drange[1], target_type), 252 copy=False, 253 ) 254 255 256def _datetrunc_eq( 257 left: exp.Expression, 258 date: datetime.date, 259 unit: str, 260 dialect: Dialect, 261 target_type: t.Optional[exp.DataType], 262) -> t.Optional[exp.Expression]: 263 drange = _datetrunc_range(date, unit, dialect) 264 if not drange: 265 return None 266 267 return _datetrunc_eq_expression(left, drange, target_type) 268 269 270def _datetrunc_neq( 271 left: exp.Expression, 272 date: datetime.date, 273 unit: str, 274 dialect: Dialect, 275 target_type: t.Optional[exp.DataType], 276) -> t.Optional[exp.Expression]: 277 drange = _datetrunc_range(date, unit, dialect) 278 if not drange: 279 return None 280 281 return exp.and_( 282 left < date_literal(drange[0], target_type), 283 left >= date_literal(drange[1], target_type), 284 copy=False, 285 ) 286 287 288def always_true(expression): 289 return (isinstance(expression, exp.Boolean) and expression.this) or ( 290 isinstance(expression, exp.Literal) and expression.is_number and not is_zero(expression) 291 ) 292 293 294def always_false(expression): 295 return is_false(expression) or is_null(expression) or is_zero(expression) 296 297 298def is_zero(expression): 299 return isinstance(expression, exp.Literal) and expression.to_py() == 0 300 301 302def is_complement(a, b): 303 return isinstance(b, exp.Not) and b.this == a 304 305 306def is_false(a: exp.Expression) -> bool: 307 return type(a) is exp.Boolean and not a.this 308 309 310def is_null(a: exp.Expression) -> bool: 311 return type(a) is exp.Null 312 313 314def eval_boolean(expression, a, b): 315 if isinstance(expression, (exp.EQ, exp.Is)): 316 
return boolean_literal(a == b) 317 if isinstance(expression, exp.NEQ): 318 return boolean_literal(a != b) 319 if isinstance(expression, exp.GT): 320 return boolean_literal(a > b) 321 if isinstance(expression, exp.GTE): 322 return boolean_literal(a >= b) 323 if isinstance(expression, exp.LT): 324 return boolean_literal(a < b) 325 if isinstance(expression, exp.LTE): 326 return boolean_literal(a <= b) 327 return None 328 329 330def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 331 if isinstance(value, datetime.datetime): 332 return value.date() 333 if isinstance(value, datetime.date): 334 return value 335 try: 336 return datetime.datetime.fromisoformat(value).date() 337 except ValueError: 338 return None 339 340 341def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 342 if isinstance(value, datetime.datetime): 343 return value 344 if isinstance(value, datetime.date): 345 return datetime.datetime(year=value.year, month=value.month, day=value.day) 346 try: 347 return datetime.datetime.fromisoformat(value) 348 except ValueError: 349 return None 350 351 352def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 353 if not value: 354 return None 355 if to.is_type(exp.DataType.Type.DATE): 356 return cast_as_date(value) 357 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 358 return cast_as_datetime(value) 359 return None 360 361 362def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 363 if isinstance(cast, exp.Cast): 364 to = cast.to 365 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 366 to = exp.DataType.build(exp.DataType.Type.DATE) 367 else: 368 return None 369 370 if isinstance(cast.this, exp.Literal): 371 value: t.Any = cast.this.name 372 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 373 value = extract_date(cast.this) 374 else: 375 return None 376 return cast_value(value, to) 377 378 379def _is_date_literal(expression: 
exp.Expression) -> bool: 380 return extract_date(expression) is not None 381 382 383def extract_interval(expression): 384 try: 385 n = int(expression.this.to_py()) 386 unit = expression.text("unit").lower() 387 return interval(unit, n) 388 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 389 return None 390 391 392def extract_type(*expressions): 393 target_type = None 394 for expression in expressions: 395 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 396 if target_type: 397 break 398 399 return target_type 400 401 402def date_literal(date, target_type=None): 403 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 404 target_type = ( 405 exp.DataType.Type.DATETIME 406 if isinstance(date, datetime.datetime) 407 else exp.DataType.Type.DATE 408 ) 409 410 return exp.cast(exp.Literal.string(date), target_type) 411 412 413def interval(unit: str, n: int = 1): 414 from dateutil.relativedelta import relativedelta 415 416 if unit == "year": 417 return relativedelta(years=1 * n) 418 if unit == "quarter": 419 return relativedelta(months=3 * n) 420 if unit == "month": 421 return relativedelta(months=1 * n) 422 if unit == "week": 423 return relativedelta(weeks=1 * n) 424 if unit == "day": 425 return relativedelta(days=1 * n) 426 if unit == "hour": 427 return relativedelta(hours=1 * n) 428 if unit == "minute": 429 return relativedelta(minutes=1 * n) 430 if unit == "second": 431 return relativedelta(seconds=1 * n) 432 433 raise UnsupportedUnit(f"Unsupported unit: {unit}") 434 435 436def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 437 if unit == "year": 438 return d.replace(month=1, day=1) 439 if unit == "quarter": 440 if d.month <= 3: 441 return d.replace(month=1, day=1) 442 elif d.month <= 6: 443 return d.replace(month=4, day=1) 444 elif d.month <= 9: 445 return d.replace(month=7, day=1) 446 else: 447 return d.replace(month=10, day=1) 448 if unit == "month": 449 return 
d.replace(month=d.month, day=1) 450 if unit == "week": 451 # Assuming week starts on Monday (0) and ends on Sunday (6) 452 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 453 if unit == "day": 454 return d 455 456 raise UnsupportedUnit(f"Unsupported unit: {unit}") 457 458 459def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 460 floor = date_floor(d, unit, dialect) 461 462 if floor == d: 463 return d 464 465 return floor + interval(unit) 466 467 468def boolean_literal(condition): 469 return exp.true() if condition else exp.false() 470 471 472class Simplifier: 473 def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True): 474 self.dialect = Dialect.get_or_raise(dialect) 475 self.annotate_new_expressions = annotate_new_expressions 476 477 self._annotator: TypeAnnotator = TypeAnnotator( 478 schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False 479 ) 480 481 # Value ranges for byte-sized signed/unsigned integers 482 TINYINT_MIN = -128 483 TINYINT_MAX = 127 484 UTINYINT_MIN = 0 485 UTINYINT_MAX = 255 486 487 COMPLEMENT_COMPARISONS = { 488 exp.LT: exp.GTE, 489 exp.GT: exp.LTE, 490 exp.LTE: exp.GT, 491 exp.GTE: exp.LT, 492 exp.EQ: exp.NEQ, 493 exp.NEQ: exp.EQ, 494 } 495 496 COMPLEMENT_SUBQUERY_PREDICATES = { 497 exp.All: exp.Any, 498 exp.Any: exp.All, 499 } 500 501 LT_LTE = (exp.LT, exp.LTE) 502 GT_GTE = (exp.GT, exp.GTE) 503 504 COMPARISONS = ( 505 *LT_LTE, 506 *GT_GTE, 507 exp.EQ, 508 exp.NEQ, 509 exp.Is, 510 ) 511 512 INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 513 exp.LT: exp.GT, 514 exp.GT: exp.LT, 515 exp.LTE: exp.GTE, 516 exp.GTE: exp.LTE, 517 } 518 519 NONDETERMINISTIC = (exp.Rand, exp.Randn) 520 AND_OR = (exp.And, exp.Or) 521 522 INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 523 exp.DateAdd: exp.Sub, 524 exp.DateSub: exp.Add, 525 exp.DatetimeAdd: exp.Sub, 526 exp.DatetimeSub: exp.Add, 527 } 528 529 
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 530 **INVERSE_DATE_OPS, 531 exp.Add: exp.Sub, 532 exp.Sub: exp.Add, 533 } 534 535 NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 536 537 CONCATS = (exp.Concat, exp.DPipe) 538 539 DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 540 exp.LT: lambda l, dt, u, d, t: l 541 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 542 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 543 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 544 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 545 exp.EQ: _datetrunc_eq, 546 exp.NEQ: _datetrunc_neq, 547 } 548 549 DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 550 DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 551 552 SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean) 553 554 # CROSS joins result in an empty table if the right table is empty. 555 # So we can only simplify certain types of joins to CROSS. 
556 # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 557 JOINS = { 558 ("", ""), 559 ("", "INNER"), 560 ("RIGHT", ""), 561 ("RIGHT", "OUTER"), 562 } 563 564 def simplify( 565 self, 566 expression: exp.Expression, 567 constant_propagation: bool = False, 568 coalesce_simplification: bool = False, 569 ): 570 wheres = [] 571 joins = [] 572 573 for node in expression.walk( 574 prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL)) 575 ): 576 if node.meta.get(FINAL): 577 continue 578 579 # group by expressions cannot be simplified, for example 580 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 581 # the projection must exactly match the group by key 582 group = node.args.get("group") 583 584 if group and hasattr(node, "selects"): 585 groups = set(group.expressions) 586 group.meta[FINAL] = True 587 588 for s in node.selects: 589 for n in s.walk(FINAL): 590 if n in groups: 591 s.meta[FINAL] = True 592 break 593 594 having = node.args.get("having") 595 596 if having: 597 for n in having.walk(): 598 if n in groups: 599 having.meta[FINAL] = True 600 break 601 602 if isinstance(node, exp.Condition): 603 simplified = while_changing( 604 node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification) 605 ) 606 607 if node is expression: 608 expression = simplified 609 elif isinstance(node, exp.Where): 610 wheres.append(node) 611 elif isinstance(node, exp.Join): 612 # snowflake match_conditions have very strict ordering rules 613 if match := node.args.get("match_condition"): 614 match.meta[FINAL] = True 615 616 joins.append(node) 617 618 for where in wheres: 619 if always_true(where.this): 620 where.pop() 621 for join in joins: 622 if ( 623 always_true(join.args.get("on")) 624 and not join.args.get("using") 625 and not join.args.get("method") 626 and (join.side, join.kind) in self.JOINS 627 ): 628 join.args["on"].pop() 629 join.set("side", None) 630 join.set("kind", "CROSS") 631 632 return expression 633 634 def _simplify( 635 self, 
expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool 636 ): 637 pre_transformation_stack = [expression] 638 post_transformation_stack = [] 639 640 while pre_transformation_stack: 641 original = pre_transformation_stack.pop() 642 node = original 643 644 if not isinstance(node, SIMPLIFIABLE): 645 if isinstance(node, exp.Query): 646 self.simplify(node, constant_propagation, coalesce_simplification) 647 continue 648 649 parent = node.parent 650 root = node is expression 651 652 node = self.rewrite_between(node) 653 node = self.uniq_sort(node, root) 654 node = self.absorb_and_eliminate(node, root) 655 node = self.simplify_concat(node) 656 node = self.simplify_conditionals(node) 657 658 if constant_propagation: 659 node = propagate_constants(node, root) 660 661 if node is not original: 662 original.replace(node) 663 664 for n in node.iter_expressions(reverse=True): 665 if n.meta.get(FINAL): 666 raise 667 pre_transformation_stack.extend( 668 n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 669 ) 670 post_transformation_stack.append((node, parent)) 671 672 while post_transformation_stack: 673 original, parent = post_transformation_stack.pop() 674 root = original is expression 675 676 # Resets parent, arg_key, index pointers– this is needed because some of the 677 # previous transformations mutate the AST, leading to an inconsistent state 678 for k, v in tuple(original.args.items()): 679 original.set(k, v) 680 681 # Post-order transformations 682 node = self.simplify_not(original) 683 node = flatten(node) 684 node = self.simplify_connectors(node, root) 685 node = self.remove_complements(node, root) 686 687 if coalesce_simplification: 688 node = self.simplify_coalesce(node) 689 node.parent = parent 690 691 node = self.simplify_literals(node, root) 692 node = self.simplify_equality(node) 693 node = simplify_parens(node, dialect=self.dialect) 694 node = self.simplify_datetrunc(node) 695 node = self.sort_comparison(node) 
696 node = self.simplify_startswith(node) 697 698 if node is not original: 699 original.replace(node) 700 701 return node 702 703 @annotate_types_on_change 704 def rewrite_between(self, expression: exp.Expression) -> exp.Expression: 705 """Rewrite x between y and z to x >= y AND x <= z. 706 707 This is done because comparison simplification is only done on lt/lte/gt/gte. 708 """ 709 if isinstance(expression, exp.Between): 710 negate = isinstance(expression.parent, exp.Not) 711 712 expression = exp.and_( 713 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 714 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 715 copy=False, 716 ) 717 718 if negate: 719 expression = exp.paren(expression, copy=False) 720 721 return expression 722 723 @annotate_types_on_change 724 def simplify_not(self, expression: exp.Expression) -> exp.Expression: 725 """ 726 Demorgan's Law 727 NOT (x OR y) -> NOT x AND NOT y 728 NOT (x AND y) -> NOT x OR NOT y 729 """ 730 if isinstance(expression, exp.Not): 731 this = expression.this 732 if is_null(this): 733 return exp.and_(exp.null(), exp.true(), copy=False) 734 if this.__class__ in self.COMPLEMENT_COMPARISONS: 735 right = this.expression 736 complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get( 737 right.__class__ 738 ) 739 if complement_subquery_predicate: 740 right = complement_subquery_predicate(this=right.this) 741 742 return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 743 if isinstance(this, exp.Paren): 744 condition = this.unnest() 745 if isinstance(condition, exp.And): 746 return exp.paren( 747 exp.or_( 748 exp.not_(condition.left, copy=False), 749 exp.not_(condition.right, copy=False), 750 copy=False, 751 ), 752 copy=False, 753 ) 754 if isinstance(condition, exp.Or): 755 return exp.paren( 756 exp.and_( 757 exp.not_(condition.left, copy=False), 758 exp.not_(condition.right, copy=False), 759 copy=False, 760 ), 761 copy=False, 762 ) 763 if 
is_null(condition): 764 return exp.and_(exp.null(), exp.true(), copy=False) 765 if always_true(this): 766 return exp.false() 767 if is_false(this): 768 return exp.true() 769 if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION: 770 inner = this.this 771 if inner.is_type(exp.DataType.Type.BOOLEAN): 772 # double negation 773 # NOT NOT x -> x, if x is BOOLEAN type 774 return inner 775 return expression 776 777 @annotate_types_on_change 778 def simplify_connectors(self, expression, root=True): 779 def _simplify_connectors(expression, left, right): 780 if isinstance(expression, exp.And): 781 if is_false(left) or is_false(right): 782 return exp.false() 783 if is_zero(left) or is_zero(right): 784 return exp.false() 785 if ( 786 (is_null(left) and is_null(right)) 787 or (is_null(left) and always_true(right)) 788 or (always_true(left) and is_null(right)) 789 ): 790 return exp.null() 791 if always_true(left) and always_true(right): 792 return exp.true() 793 if always_true(left): 794 return right 795 if always_true(right): 796 return left 797 return self._simplify_comparison(expression, left, right) 798 elif isinstance(expression, exp.Or): 799 if always_true(left) or always_true(right): 800 return exp.true() 801 if ( 802 (is_null(left) and is_null(right)) 803 or (is_null(left) and always_false(right)) 804 or (always_false(left) and is_null(right)) 805 ): 806 return exp.null() 807 if is_false(left): 808 return right 809 if is_false(right): 810 return left 811 return self._simplify_comparison(expression, left, right, or_=True) 812 813 if isinstance(expression, exp.Connector): 814 original_parent = expression.parent 815 expression = self._flat_simplify(expression, _simplify_connectors, root) 816 817 # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need 818 # to ensure that the resulting type is boolean. 
We know this is true only for connectors, 819 # boolean values and columns that are essentially operands to a connector: 820 # 821 # A AND (((B))) 822 # ~ this is safe to keep because it will eventually be part of another connector 823 if not isinstance( 824 expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT 825 ) and not expression.is_type(exp.DataType.Type.BOOLEAN): 826 while True: 827 if isinstance(original_parent, exp.Connector): 828 break 829 if not isinstance(original_parent, exp.Paren): 830 expression = expression.and_(exp.true(), copy=False) 831 break 832 833 original_parent = original_parent.parent 834 835 return expression 836 837 @annotate_types_on_change 838 def _simplify_comparison(self, expression, left, right, or_=False): 839 if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS): 840 ll, lr = left.args.values() 841 rl, rr = right.args.values() 842 843 largs = {ll, lr} 844 rargs = {rl, rr} 845 846 matching = largs & rargs 847 columns = { 848 m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC) 849 } 850 851 if matching and columns: 852 try: 853 l = first(largs - columns) 854 r = first(rargs - columns) 855 except StopIteration: 856 return expression 857 858 if l.is_number and r.is_number: 859 l = l.to_py() 860 r = r.to_py() 861 elif l.is_string and r.is_string: 862 l = l.name 863 r = r.name 864 else: 865 l = extract_date(l) 866 if not l: 867 return None 868 r = extract_date(r) 869 if not r: 870 return None 871 # python won't compare date and datetime, but many engines will upcast 872 l, r = cast_as_datetime(l), cast_as_datetime(r) 873 874 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 875 if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE): 876 return left if (av > bv if or_ else av <= bv) else right 877 if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE): 878 return left if (av < bv if or_ else av >= bv) else right 879 880 # we can't ever shortcut to true 
because the column could be null 881 if not or_: 882 if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE): 883 if av <= bv: 884 return exp.false() 885 elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE): 886 if av >= bv: 887 return exp.false() 888 elif isinstance(a, exp.EQ): 889 if isinstance(b, exp.LT): 890 return exp.false() if av >= bv else a 891 if isinstance(b, exp.LTE): 892 return exp.false() if av > bv else a 893 if isinstance(b, exp.GT): 894 return exp.false() if av <= bv else a 895 if isinstance(b, exp.GTE): 896 return exp.false() if av < bv else a 897 if isinstance(b, exp.NEQ): 898 return exp.false() if av == bv else a 899 return None 900 901 @annotate_types_on_change 902 def remove_complements(self, expression, root=True): 903 """ 904 Removing complements. 905 906 A AND NOT A -> FALSE (only for non-NULL A) 907 A OR NOT A -> TRUE (only for non-NULL A) 908 """ 909 if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): 910 ops = set(expression.flatten()) 911 for op in ops: 912 if isinstance(op, exp.Not) and op.this in ops: 913 if expression.meta.get("nonnull") is True: 914 return exp.false() if isinstance(expression, exp.And) else exp.true() 915 916 return expression 917 918 @annotate_types_on_change 919 def uniq_sort(self, expression, root=True): 920 """ 921 Uniq and sort a connector. 
922 923 C AND A AND B AND B -> A AND B AND C 924 """ 925 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 926 flattened = tuple(expression.flatten()) 927 928 if isinstance(expression, exp.Xor): 929 result_func = exp.xor 930 # Do not deduplicate XOR as A XOR A != A if A == True 931 deduped = None 932 arr = tuple((gen(e), e) for e in flattened) 933 else: 934 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 935 deduped = {gen(e): e for e in flattened} 936 arr = tuple(deduped.items()) 937 938 # check if the operands are already sorted, if not sort them 939 # A AND C AND B -> A AND B AND C 940 for i, (sql, e) in enumerate(arr[1:]): 941 if sql < arr[i][0]: 942 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 943 break 944 else: 945 # we didn't have to sort but maybe we need to dedup 946 if deduped and len(deduped) < len(flattened): 947 unique_operand = flattened[0] 948 if len(deduped) == 1: 949 expression = unique_operand.and_(exp.true(), copy=False) 950 else: 951 expression = result_func(*deduped.values(), copy=False) 952 953 return expression 954 955 @annotate_types_on_change 956 def absorb_and_eliminate(self, expression, root=True): 957 """ 958 absorption: 959 A AND (A OR B) -> A 960 A OR (A AND B) -> A 961 A AND (NOT A OR B) -> A AND B 962 A OR (NOT A AND B) -> A OR B 963 elimination: 964 (A AND B) OR (A AND NOT B) -> A 965 (A OR B) AND (A OR NOT B) -> A 966 """ 967 if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): 968 kind = exp.Or if isinstance(expression, exp.And) else exp.And 969 970 ops = tuple(expression.flatten()) 971 972 # Initialize lookup tables: 973 # Set of all operands, used to find complements for absorption. 974 op_set = set() 975 # Sub-operands, used to find subsets for absorption. 976 subops = defaultdict(list) 977 # Pairs of complements, used for elimination. 
978 pairs = defaultdict(list) 979 980 # Populate the lookup tables 981 for op in ops: 982 op_set.add(op) 983 984 if not isinstance(op, kind): 985 # In cases like: A OR (A AND B) 986 # Subop will be: ^ 987 subops[op].append({op}) 988 continue 989 990 # In cases like: (A AND B) OR (A AND B AND C) 991 # Subops will be: ^ ^ 992 subset = set(op.flatten()) 993 for i in subset: 994 subops[i].append(subset) 995 996 a, b = op.unnest_operands() 997 if isinstance(a, exp.Not): 998 pairs[frozenset((a.this, b))].append((op, b)) 999 if isinstance(b, exp.Not): 1000 pairs[frozenset((a, b.this))].append((op, a)) 1001 1002 for op in ops: 1003 if not isinstance(op, kind): 1004 continue 1005 1006 a, b = op.unnest_operands() 1007 1008 # Absorb 1009 if isinstance(a, exp.Not) and a.this in op_set: 1010 a.replace(exp.true() if kind == exp.And else exp.false()) 1011 continue 1012 if isinstance(b, exp.Not) and b.this in op_set: 1013 b.replace(exp.true() if kind == exp.And else exp.false()) 1014 continue 1015 superset = set(op.flatten()) 1016 if any(any(subset < superset for subset in subops[i]) for i in superset): 1017 op.replace(exp.false() if kind == exp.And else exp.true()) 1018 continue 1019 1020 # Eliminate 1021 for other, complement in pairs[frozenset((a, b))]: 1022 op.replace(complement) 1023 other.replace(complement) 1024 1025 return expression 1026 1027 @annotate_types_on_change 1028 @catch(ModuleNotFoundError, UnsupportedUnit) 1029 def simplify_equality(self, expression: exp.Expression) -> exp.Expression: 1030 """ 1031 Use the subtraction and addition properties of equality to simplify expressions: 1032 1033 x + 1 = 3 becomes x = 2 1034 1035 There are two binary operations in the above expression: + and = 1036 Here's how we reference all the operands in the code below: 1037 1038 l r 1039 x + 1 = 3 1040 a b 1041 """ 1042 if isinstance(expression, self.COMPARISONS): 1043 l, r = expression.left, expression.right 1044 1045 if l.__class__ not in self.INVERSE_OPS: 1046 return expression 
1047 1048 if r.is_number: 1049 a_predicate = _is_number 1050 b_predicate = _is_number 1051 elif _is_date_literal(r): 1052 a_predicate = _is_date_literal 1053 b_predicate = _is_interval 1054 else: 1055 return expression 1056 1057 if l.__class__ in self.INVERSE_DATE_OPS: 1058 l = t.cast(exp.IntervalOp, l) 1059 a = l.this 1060 b = l.interval() 1061 else: 1062 l = t.cast(exp.Binary, l) 1063 a, b = l.left, l.right 1064 1065 if not a_predicate(a) and b_predicate(b): 1066 pass 1067 elif not a_predicate(b) and b_predicate(a): 1068 a, b = b, a 1069 else: 1070 return expression 1071 1072 return expression.__class__( 1073 this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b) 1074 ) 1075 return expression 1076 1077 @annotate_types_on_change 1078 def simplify_literals(self, expression, root=True): 1079 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 1080 return self._flat_simplify(expression, self._simplify_binary, root) 1081 1082 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 1083 return expression.this.this 1084 1085 if type(expression) in self.INVERSE_DATE_OPS: 1086 return ( 1087 self._simplify_binary(expression, expression.this, expression.interval()) 1088 or expression 1089 ) 1090 1091 return expression 1092 1093 def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression: 1094 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 1095 this = self._simplify_integer_cast(expr.this) 1096 else: 1097 this = expr.this 1098 1099 if isinstance(expr, exp.Cast) and this.is_int: 1100 num = this.to_py() 1101 1102 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 1103 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 1104 # engine-dependent 1105 if ( 1106 self.TINYINT_MIN <= num <= self.TINYINT_MAX 1107 and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 1108 ) or ( 1109 self.UTINYINT_MIN <= num <= self.UTINYINT_MAX 1110 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 1111 ): 1112 return this 1113 1114 return expr 1115 1116 def _simplify_binary(self, expression, a, b): 1117 if isinstance(expression, self.COMPARISONS): 1118 a = self._simplify_integer_cast(a) 1119 b = self._simplify_integer_cast(b) 1120 1121 if isinstance(expression, exp.Is): 1122 if isinstance(b, exp.Not): 1123 c = b.this 1124 not_ = True 1125 else: 1126 c = b 1127 not_ = False 1128 1129 if is_null(c): 1130 if isinstance(a, exp.Literal): 1131 return exp.true() if not_ else exp.false() 1132 if is_null(a): 1133 return exp.false() if not_ else exp.true() 1134 elif isinstance(expression, self.NULL_OK): 1135 return None 1136 elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If): 1137 return exp.null() 1138 1139 if a.is_number and b.is_number: 1140 num_a = a.to_py() 1141 num_b = b.to_py() 1142 1143 if isinstance(expression, exp.Add): 1144 return exp.Literal.number(num_a + num_b) 1145 if isinstance(expression, exp.Mul): 1146 return exp.Literal.number(num_a * num_b) 1147 1148 # We only simplify Sub, Div if a and b have the same parent because they're not associative 1149 if isinstance(expression, exp.Sub): 1150 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 1151 if isinstance(expression, exp.Div): 1152 # engines have differing int div behavior so intdiv is not safe 1153 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 1154 return None 1155 return exp.Literal.number(num_a / num_b) 1156 1157 boolean = eval_boolean(expression, num_a, num_b) 1158 1159 if boolean: 1160 return boolean 1161 elif a.is_string and b.is_string: 
1162 boolean = eval_boolean(expression, a.this, b.this) 1163 1164 if boolean: 1165 return boolean 1166 elif _is_date_literal(a) and isinstance(b, exp.Interval): 1167 date, b = extract_date(a), extract_interval(b) 1168 if date and b: 1169 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 1170 return date_literal(date + b, extract_type(a)) 1171 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 1172 return date_literal(date - b, extract_type(a)) 1173 elif isinstance(a, exp.Interval) and _is_date_literal(b): 1174 a, date = extract_interval(a), extract_date(b) 1175 # you cannot subtract a date from an interval 1176 if a and b and isinstance(expression, exp.Add): 1177 return date_literal(a + date, extract_type(b)) 1178 elif _is_date_literal(a) and _is_date_literal(b): 1179 if isinstance(expression, exp.Predicate): 1180 a, b = extract_date(a), extract_date(b) 1181 boolean = eval_boolean(expression, a, b) 1182 if boolean: 1183 return boolean 1184 1185 return None 1186 1187 @annotate_types_on_change 1188 def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression: 1189 # COALESCE(x) -> x 1190 if ( 1191 isinstance(expression, exp.Coalesce) 1192 and (not expression.expressions or _is_nonnull_constant(expression.this)) 1193 # COALESCE is also used as a Spark partitioning hint 1194 and not isinstance(expression.parent, exp.Hint) 1195 ): 1196 return expression.this 1197 1198 if self.dialect.COALESCE_COMPARISON_NON_STANDARD: 1199 return expression 1200 1201 if not isinstance(expression, self.COMPARISONS): 1202 return expression 1203 1204 if isinstance(expression.left, exp.Coalesce): 1205 coalesce = expression.left 1206 other = expression.right 1207 elif isinstance(expression.right, exp.Coalesce): 1208 coalesce = expression.right 1209 other = expression.left 1210 else: 1211 return expression 1212 1213 # This transformation is valid for non-constants, 1214 # but it really only does anything if they are both constants. 
1215 if not _is_constant(other): 1216 return expression 1217 1218 # Find the first constant arg 1219 for arg_index, arg in enumerate(coalesce.expressions): 1220 if _is_constant(arg): 1221 break 1222 else: 1223 return expression 1224 1225 coalesce.set("expressions", coalesce.expressions[:arg_index]) 1226 1227 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 1228 # since we already remove COALESCE at the top of this function. 1229 coalesce = coalesce if coalesce.expressions else coalesce.this 1230 1231 # This expression is more complex than when we started, but it will get simplified further 1232 return exp.paren( 1233 exp.or_( 1234 exp.and_( 1235 coalesce.is_(exp.null()).not_(copy=False), 1236 expression.copy(), 1237 copy=False, 1238 ), 1239 exp.and_( 1240 coalesce.is_(exp.null()), 1241 type(expression)(this=arg.copy(), expression=other.copy()), 1242 copy=False, 1243 ), 1244 copy=False, 1245 ), 1246 copy=False, 1247 ) 1248 1249 @annotate_types_on_change 1250 def simplify_concat(self, expression): 1251 """Reduces all groups that contain string literals by concatenating them.""" 1252 if not isinstance(expression, self.CONCATS) or ( 1253 # We can't reduce a CONCAT_WS call if we don't statically know the separator 1254 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 1255 ): 1256 return expression 1257 1258 if isinstance(expression, exp.ConcatWs): 1259 sep_expr, *expressions = expression.expressions 1260 sep = sep_expr.name 1261 concat_type = exp.ConcatWs 1262 args = {} 1263 else: 1264 expressions = expression.expressions 1265 sep = "" 1266 concat_type = exp.Concat 1267 args = { 1268 "safe": expression.args.get("safe"), 1269 "coalesce": expression.args.get("coalesce"), 1270 } 1271 1272 new_args = [] 1273 for is_string_group, group in itertools.groupby( 1274 expressions or expression.flatten(), lambda e: e.is_string 1275 ): 1276 if is_string_group: 1277 
new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 1278 else: 1279 new_args.extend(group) 1280 1281 if len(new_args) == 1 and new_args[0].is_string: 1282 return new_args[0] 1283 1284 if concat_type is exp.ConcatWs: 1285 new_args = [sep_expr] + new_args 1286 elif isinstance(expression, exp.DPipe): 1287 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 1288 1289 return concat_type(expressions=new_args, **args) 1290 1291 @annotate_types_on_change 1292 def simplify_conditionals(self, expression): 1293 """Simplifies expressions like IF, CASE if their condition is statically known.""" 1294 if isinstance(expression, exp.Case): 1295 this = expression.this 1296 for case in expression.args["ifs"]: 1297 cond = case.this 1298 if this: 1299 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 1300 cond = cond.replace(this.pop().eq(cond)) 1301 1302 if always_true(cond): 1303 return case.args["true"] 1304 1305 if always_false(cond): 1306 case.pop() 1307 if not expression.args["ifs"]: 1308 return expression.args.get("default") or exp.null() 1309 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 1310 if always_true(expression.this): 1311 return expression.args["true"] 1312 if always_false(expression.this): 1313 return expression.args.get("false") or exp.null() 1314 1315 return expression 1316 1317 @annotate_types_on_change 1318 def simplify_startswith(self, expression: exp.Expression) -> exp.Expression: 1319 """ 1320 Reduces a prefix check to either TRUE or FALSE if both the string and the 1321 prefix are statically known. 
1322 1323 Example: 1324 >>> from sqlglot import parse_one 1325 >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 1326 'TRUE' 1327 """ 1328 if ( 1329 isinstance(expression, exp.StartsWith) 1330 and expression.this.is_string 1331 and expression.expression.is_string 1332 ): 1333 return exp.convert(expression.name.startswith(expression.expression.name)) 1334 1335 return expression 1336 1337 def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool: 1338 return isinstance(left, self.DATETRUNCS) and _is_date_literal(right) 1339 1340 @annotate_types_on_change 1341 @catch(ModuleNotFoundError, UnsupportedUnit) 1342 def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression: 1343 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1344 comparison = expression.__class__ 1345 1346 if isinstance(expression, self.DATETRUNCS): 1347 this = expression.this 1348 trunc_type = extract_type(this) 1349 date = extract_date(this) 1350 if date and expression.unit: 1351 return date_literal( 1352 date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type 1353 ) 1354 elif comparison not in self.DATETRUNC_COMPARISONS: 1355 return expression 1356 1357 if isinstance(expression, exp.Binary): 1358 l, r = expression.left, expression.right 1359 1360 if not self._is_datetrunc_predicate(l, r): 1361 return expression 1362 1363 l = t.cast(exp.DateTrunc, l) 1364 trunc_arg = l.this 1365 unit = l.unit.name.lower() 1366 date = extract_date(r) 1367 1368 if not date: 1369 return expression 1370 1371 return ( 1372 self.DATETRUNC_BINARY_COMPARISONS[comparison]( 1373 trunc_arg, date, unit, self.dialect, extract_type(r) 1374 ) 1375 or expression 1376 ) 1377 1378 if isinstance(expression, exp.In): 1379 l = expression.this 1380 rs = expression.expressions 1381 1382 if rs and all(self._is_datetrunc_predicate(l, r) for r in rs): 1383 l = t.cast(exp.DateTrunc, l) 1384 unit = l.unit.name.lower() 
1385 1386 ranges = [] 1387 for r in rs: 1388 date = extract_date(r) 1389 if not date: 1390 return expression 1391 drange = _datetrunc_range(date, unit, self.dialect) 1392 if drange: 1393 ranges.append(drange) 1394 1395 if not ranges: 1396 return expression 1397 1398 ranges = merge_ranges(ranges) 1399 target_type = extract_type(*rs) 1400 1401 return exp.or_( 1402 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], 1403 copy=False, 1404 ) 1405 1406 return expression 1407 1408 @annotate_types_on_change 1409 def sort_comparison(self, expression: exp.Expression) -> exp.Expression: 1410 if expression.__class__ in self.COMPLEMENT_COMPARISONS: 1411 l, r = expression.this, expression.expression 1412 l_column = isinstance(l, exp.Column) 1413 r_column = isinstance(r, exp.Column) 1414 l_const = _is_constant(l) 1415 r_const = _is_constant(r) 1416 1417 if ( 1418 (l_column and not r_column) 1419 or (r_const and not l_const) 1420 or isinstance(r, exp.SubqueryPredicate) 1421 ): 1422 return expression 1423 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1424 return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1425 this=r, expression=l 1426 ) 1427 return expression 1428 1429 def _flat_simplify(self, expression, simplifier, root=True): 1430 if root or not expression.same_parent: 1431 operands = [] 1432 queue = deque(expression.flatten(unnest=False)) 1433 size = len(queue) 1434 1435 while queue: 1436 a = queue.popleft() 1437 1438 for b in queue: 1439 result = simplifier(expression, a, b) 1440 1441 if result and result is not expression: 1442 queue.remove(b) 1443 queue.appendleft(result) 1444 break 1445 else: 1446 operands.append(a) 1447 1448 if len(operands) < size: 1449 return functools.reduce( 1450 lambda a, b: expression.__class__(this=a, expression=b), operands 1451 ) 1452 return expression 1453 1454 1455def gen(expression: t.Any, comments: bool = False) -> str: 1456 """Simple pseudo sql generator 
for quickly generating sortable and uniq strings. 1457 1458 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1459 generator is expensive so we have a bare minimum sql generator here. 1460 1461 Args: 1462 expression: the expression to convert into a SQL string. 1463 comments: whether to include the expression's comments. 1464 """ 1465 return Gen().gen(expression, comments=comments) 1466 1467 1468class Gen: 1469 def __init__(self): 1470 self.stack = [] 1471 self.sqls = [] 1472 1473 def gen(self, expression: exp.Expression, comments: bool = False) -> str: 1474 self.stack = [expression] 1475 self.sqls.clear() 1476 1477 while self.stack: 1478 node = self.stack.pop() 1479 1480 if isinstance(node, exp.Expression): 1481 if comments and node.comments: 1482 self.stack.append(f" /*{','.join(node.comments)}*/") 1483 1484 exp_handler_name = f"{node.key}_sql" 1485 1486 if hasattr(self, exp_handler_name): 1487 getattr(self, exp_handler_name)(node) 1488 elif isinstance(node, exp.Func): 1489 self._function(node) 1490 else: 1491 key = node.key.upper() 1492 self.stack.append(f"{key} " if self._args(node) else key) 1493 elif type(node) is list: 1494 for n in reversed(node): 1495 if n is not None: 1496 self.stack.extend((n, ",")) 1497 if node: 1498 self.stack.pop() 1499 else: 1500 if node is not None: 1501 self.sqls.append(str(node)) 1502 1503 return "".join(self.sqls) 1504 1505 def add_sql(self, e: exp.Add) -> None: 1506 self._binary(e, " + ") 1507 1508 def alias_sql(self, e: exp.Alias) -> None: 1509 self.stack.extend( 1510 ( 1511 e.args.get("alias"), 1512 " AS ", 1513 e.args.get("this"), 1514 ) 1515 ) 1516 1517 def and_sql(self, e: exp.And) -> None: 1518 self._binary(e, " AND ") 1519 1520 def anonymous_sql(self, e: exp.Anonymous) -> None: 1521 this = e.this 1522 if isinstance(this, str): 1523 name = this.upper() 1524 elif isinstance(this, exp.Identifier): 1525 name = this.this 1526 name = f'"{name}"' if this.quoted else name.upper() 1527 else: 1528 
raise ValueError( 1529 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1530 ) 1531 1532 self.stack.extend( 1533 ( 1534 ")", 1535 e.expressions, 1536 "(", 1537 name, 1538 ) 1539 ) 1540 1541 def between_sql(self, e: exp.Between) -> None: 1542 self.stack.extend( 1543 ( 1544 e.args.get("high"), 1545 " AND ", 1546 e.args.get("low"), 1547 " BETWEEN ", 1548 e.this, 1549 ) 1550 ) 1551 1552 def boolean_sql(self, e: exp.Boolean) -> None: 1553 self.stack.append("TRUE" if e.this else "FALSE") 1554 1555 def bracket_sql(self, e: exp.Bracket) -> None: 1556 self.stack.extend( 1557 ( 1558 "]", 1559 e.expressions, 1560 "[", 1561 e.this, 1562 ) 1563 ) 1564 1565 def column_sql(self, e: exp.Column) -> None: 1566 for p in reversed(e.parts): 1567 self.stack.extend((p, ".")) 1568 self.stack.pop() 1569 1570 def datatype_sql(self, e: exp.DataType) -> None: 1571 self._args(e, 1) 1572 self.stack.append(f"{e.this.name} ") 1573 1574 def div_sql(self, e: exp.Div) -> None: 1575 self._binary(e, " / ") 1576 1577 def dot_sql(self, e: exp.Dot) -> None: 1578 self._binary(e, ".") 1579 1580 def eq_sql(self, e: exp.EQ) -> None: 1581 self._binary(e, " = ") 1582 1583 def from_sql(self, e: exp.From) -> None: 1584 self.stack.extend((e.this, "FROM ")) 1585 1586 def gt_sql(self, e: exp.GT) -> None: 1587 self._binary(e, " > ") 1588 1589 def gte_sql(self, e: exp.GTE) -> None: 1590 self._binary(e, " >= ") 1591 1592 def identifier_sql(self, e: exp.Identifier) -> None: 1593 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1594 1595 def ilike_sql(self, e: exp.ILike) -> None: 1596 self._binary(e, " ILIKE ") 1597 1598 def in_sql(self, e: exp.In) -> None: 1599 self.stack.append(")") 1600 self._args(e, 1) 1601 self.stack.extend( 1602 ( 1603 "(", 1604 " IN ", 1605 e.this, 1606 ) 1607 ) 1608 1609 def intdiv_sql(self, e: exp.IntDiv) -> None: 1610 self._binary(e, " DIV ") 1611 1612 def is_sql(self, e: exp.Is) -> None: 1613 self._binary(e, " IS ") 1614 1615 def like_sql(self, e: 
exp.Like) -> None: 1616 self._binary(e, " Like ") 1617 1618 def literal_sql(self, e: exp.Literal) -> None: 1619 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1620 1621 def lt_sql(self, e: exp.LT) -> None: 1622 self._binary(e, " < ") 1623 1624 def lte_sql(self, e: exp.LTE) -> None: 1625 self._binary(e, " <= ") 1626 1627 def mod_sql(self, e: exp.Mod) -> None: 1628 self._binary(e, " % ") 1629 1630 def mul_sql(self, e: exp.Mul) -> None: 1631 self._binary(e, " * ") 1632 1633 def neg_sql(self, e: exp.Neg) -> None: 1634 self._unary(e, "-") 1635 1636 def neq_sql(self, e: exp.NEQ) -> None: 1637 self._binary(e, " <> ") 1638 1639 def not_sql(self, e: exp.Not) -> None: 1640 self._unary(e, "NOT ") 1641 1642 def null_sql(self, e: exp.Null) -> None: 1643 self.stack.append("NULL") 1644 1645 def or_sql(self, e: exp.Or) -> None: 1646 self._binary(e, " OR ") 1647 1648 def paren_sql(self, e: exp.Paren) -> None: 1649 self.stack.extend( 1650 ( 1651 ")", 1652 e.this, 1653 "(", 1654 ) 1655 ) 1656 1657 def sub_sql(self, e: exp.Sub) -> None: 1658 self._binary(e, " - ") 1659 1660 def subquery_sql(self, e: exp.Subquery) -> None: 1661 self._args(e, 2) 1662 alias = e.args.get("alias") 1663 if alias: 1664 self.stack.append(alias) 1665 self.stack.extend((")", e.this, "(")) 1666 1667 def table_sql(self, e: exp.Table) -> None: 1668 self._args(e, 4) 1669 alias = e.args.get("alias") 1670 if alias: 1671 self.stack.append(alias) 1672 for p in reversed(e.parts): 1673 self.stack.extend((p, ".")) 1674 self.stack.pop() 1675 1676 def tablealias_sql(self, e: exp.TableAlias) -> None: 1677 columns = e.columns 1678 1679 if columns: 1680 self.stack.extend((")", columns, "(")) 1681 1682 self.stack.extend((e.this, " AS ")) 1683 1684 def var_sql(self, e: exp.Var) -> None: 1685 self.stack.append(e.this) 1686 1687 def _binary(self, e: exp.Binary, op: str) -> None: 1688 self.stack.extend((e.expression, op, e.this)) 1689 1690 def _unary(self, e: exp.Unary, op: str) -> None: 1691 self.stack.extend((e.this, 
op)) 1692 1693 def _function(self, e: exp.Func) -> None: 1694 self.stack.extend( 1695 ( 1696 ")", 1697 list(e.args.values()), 1698 "(", 1699 e.sql_name(), 1700 ) 1701 ) 1702 1703 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1704 kvs = [] 1705 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1706 1707 for k in arg_types: 1708 v = node.args.get(k) 1709 1710 if v is not None: 1711 kvs.append([f":{k}", v]) 1712 if kvs: 1713 self.stack.append(kvs) 1714 return True 1715 return False
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Convenience entry point: builds a `Simplifier` for `dialect` and runs its `simplify`
    method once with the given rule switches.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        coalesce_simplification: whether the simplify coalesce rule should be used.
            This rule tries to remove coalesce functions, which can be useful in certain analyses but
            can leave the query more verbose.
        dialect: forwarded to `Simplifier(dialect=...)`.

    Returns:
        sqlglot.Expression: simplified expression
    """
    simplifier = Simplifier(dialect=dialect)
    rule_switches = {
        "constant_propagation": constant_propagation,
        "coalesce_simplification": coalesce_simplification,
    }
    return simplifier.simplify(expression, **rule_switches)
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:
sqlglot.Expression: simplified expression
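Raised by the date helpers `interval` and `date_floor` when they are given a unit they do not support. Like every `Exception` subclass, it derives from the common base class for all non-exit exceptions.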
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    When one of `exceptions` escapes the decorated function, the call returns the expression
    it was given instead of propagating the error. Both calling conventions are supported:

    - free functions, whose first positional argument is the expression;
    - methods (first parameter named ``self`` or ``cls``), whose expression is the positional
      argument that follows the instance/class. Returning the first argument there would hand
      back the instance (e.g. a `Simplifier`) instead of an expression.

    The wrapper is built with `functools.wraps`, so the decorated function keeps its name and
    docstring (which also matters when other decorators stack on top of it).

    Args:
        exceptions: the exception types to swallow.

    Returns:
        A decorator applying this behavior to a simplification function.
    """
    import inspect

    def decorator(func):
        try:
            parameter_names = list(inspect.signature(func).parameters)
        except (TypeError, ValueError):
            # Some callables do not expose a signature; treat them as free functions
            parameter_names = []

        if parameter_names and parameter_names[0] in ("self", "cls"):

            @wraps(func)
            def wrapped_method(self, expression, *args, **kwargs):
                try:
                    return func(self, expression, *args, **kwargs)
                except exceptions:
                    return expression

            return wrapped_method

        @wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of the given `exceptions` is raised, instead of letting the exception propagate.
def annotate_types_on_change(func):
    """
    Decorate a `Simplifier` transformation so that new expressions it produces get type annotations.

    The transformation's result is passed through untouched when it is None, when the simplifier's
    `annotate_new_expressions` flag is off, or when it compares equal to the input. Otherwise the
    simplifier's `_annotator` is cleared and used to annotate the result (including its children),
    and the result is then given the input's type, so a rewrite cannot change the resulting schema.
    """

    @wraps(func)
    def _func(self, expression: exp.Expression, *args, **kwargs) -> t.Optional[exp.Expression]:
        result = func(self, expression, *args, **kwargs)

        if result is None:
            return result

        changed = self.annotate_new_expressions and expression != result
        if not changed:
            return result

        self._annotator.clear()
        # Annotate the whole new subtree so freshly created children are typed too
        annotated = self._annotator.annotate(expression=result, annotate_scope=False)
        # Keep the type of the node that was replaced
        annotated.type = expression.type
        return annotated

    return _func
def flatten(expression):
    """
    Inline directly nested connectors of the same kind:

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C

    Non-connector expressions are returned as-is. For an `exp.Connector`, each argument is
    unnested with `unnest()`; when the result is a connector of the same class, it replaces the
    argument. The (possibly mutated) input expression is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_parens(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """
    Remove a redundant pair of parentheses, returning the wrapped node instead.

    Returns `expression` unchanged when it is not an `exp.Paren`, or when the parentheses are kept:
    around a SELECT, under a subquery predicate or a bracket, or around the left side of a struct
    access (`Dot` with an identifier/star on the right) in dialects that set
    `REQUIRES_PARENTHESIZED_STRUCT_ACCESS`. The remaining rules decide from the types of the inner
    node and its parent whether the parentheses can be dropped.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent
    outer_is_predicate = isinstance(outer, exp.Predicate)

    if isinstance(inner, exp.Select) or isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    if (
        Dialect.get_or_raise(dialect).REQUIRES_PARENTHESIZED_STRUCT_ACCESS
        and isinstance(outer, exp.Dot)
        and isinstance(outer.right, (exp.Identifier, exp.Star))
    ):
        return expression

    # Parents that are neither conditions nor binary operators, and nested parens, never need them
    if not isinstance(outer, (exp.Condition, exp.Binary)) or isinstance(outer, exp.Paren):
        return inner

    # Non-binary inner nodes are unwrapped unless they are NOT/IS directly under a predicate
    if not isinstance(inner, exp.Binary) and not (
        outer_is_predicate and isinstance(inner, (exp.Not, exp.Is))
    ):
        return inner

    # Inner predicates keep parens only under another predicate or a negation
    if isinstance(inner, exp.Predicate) and not (outer_is_predicate or isinstance(outer, exp.Neg)):
        return inner

    # a + (b + c), a * (b * c), a + (b * c), a - (b * c)
    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner
    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

        SELECT * FROM t WHERE a = b AND b = 5 becomes
        SELECT * FROM t WHERE a = 5 AND b = 5

    Only runs on an `exp.And` that is the root (or whose `same_parent` is false) and is
    normalized to DNF. Every `column = literal` equality found in scope (without descending into
    IF nodes) registers the literal for that column. Every other occurrence of that column in
    scope is then replaced by a copy of the literal, except:

    - the occurrence that defined it;
    - columns whose parent is an `IS` test against NULL.

    The input expression is mutated in place and returned.

    Reference: https://www.sqlite.org/optoverview.html
    """
    eligible = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not eligible:
        return expression

    # column -> (id of the defining column node, literal it equals)
    literals_by_column = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue
        left, right = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(left, exp.Column) and isinstance(right, exp.Literal):
            literals_by_column[left] = (id(left), right)

    if not literals_by_column:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        entry = literals_by_column.get(column)
        if entry is None:
            continue

        defining_id, literal = entry
        if id(column) == defining_id:
            continue

        parent = column.parent
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
`SELECT * FROM t WHERE a = b AND b = 5` becomes `SELECT * FROM t WHERE a = 5 AND b = 5`
Reference: https://www.sqlite.org/optoverview.html
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison node `expression` on the Python values `a` and `b`.

    EQ/IS, NEQ, GT, GTE, LT and LTE nodes yield `boolean_literal` of the comparison result;
    any other node yields None.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`, or return None when that is not possible.

    - `datetime.datetime` values are returned as-is. This is checked first because datetime is a
      subclass of date.
    - `datetime.date` values become midnight of that day.
    - Anything else is parsed with `datetime.datetime.fromisoformat`. Inputs it rejects yield
      None instead of raising: malformed strings, and non-string values such as None or numbers.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: not a string; ValueError: string is not in ISO format
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python temporal value matching the target type `to`.

    Returns None for falsy values and non-temporal targets. DATE targets go through
    `cast_as_date`; every other temporal type goes through `cast_as_datetime`, which is why the
    result may be a `datetime.datetime` as well as a `datetime.date`.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a date-producing expression to a Python date or datetime.

    Two shapes are handled:

    - `CAST(... AS <type>)`;
    - `TsOrDsToDate` without a `format`, treated as a cast to DATE.

    The operand must be a literal or, recursively, another expression of one of these shapes.
    The value is converted with `cast_value`. Returns None for any other shape, or when the value
    cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: evaluate the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """
    Build a `CAST('<date>' AS <type>)` expression for a Python date or datetime.

    `target_type` is used when it is given and temporal. Otherwise the cast targets DATETIME
    for `datetime.datetime` values and DATE for anything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
    """
    Return a `dateutil` `relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.
    `dateutil` is imported before the unit is checked, so a missing install raises
    ModuleNotFoundError for any unit. An unsupported unit raises UnsupportedUnit.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier)
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, factor in unit_table:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: "Dialect") -> datetime.date:
    """
    Truncate `d` to the start of the `unit` period that contains it.

    Args:
        d: the date or datetime to truncate. For `datetime.datetime` values the time of day is
            reset to midnight too, since every supported unit spans whole days.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides `WEEK_OFFSET` for the "week" unit. The week starts on weekday
            `WEEK_OFFSET % 7` in `datetime.date.weekday()` numbering, where Monday is 0, so
            -1 means Sunday.

    Returns:
        The first instant of the period, of the same type as `d`.

    Raises:
        UnsupportedUnit: if `unit` is not supported.
    """
    if isinstance(d, datetime.datetime):
        # Truncating to a day or coarser never keeps the time of day
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days since the most recent week start. The modulo keeps a date that already is a week
        # start (e.g. a Sunday when WEEK_OFFSET is -1) mapped to itself rather than a week back.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
473class Simplifier: 474 def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True): 475 self.dialect = Dialect.get_or_raise(dialect) 476 self.annotate_new_expressions = annotate_new_expressions 477 478 self._annotator: TypeAnnotator = TypeAnnotator( 479 schema=ensure_schema(None, dialect=self.dialect), overwrite_types=False 480 ) 481 482 # Value ranges for byte-sized signed/unsigned integers 483 TINYINT_MIN = -128 484 TINYINT_MAX = 127 485 UTINYINT_MIN = 0 486 UTINYINT_MAX = 255 487 488 COMPLEMENT_COMPARISONS = { 489 exp.LT: exp.GTE, 490 exp.GT: exp.LTE, 491 exp.LTE: exp.GT, 492 exp.GTE: exp.LT, 493 exp.EQ: exp.NEQ, 494 exp.NEQ: exp.EQ, 495 } 496 497 COMPLEMENT_SUBQUERY_PREDICATES = { 498 exp.All: exp.Any, 499 exp.Any: exp.All, 500 } 501 502 LT_LTE = (exp.LT, exp.LTE) 503 GT_GTE = (exp.GT, exp.GTE) 504 505 COMPARISONS = ( 506 *LT_LTE, 507 *GT_GTE, 508 exp.EQ, 509 exp.NEQ, 510 exp.Is, 511 ) 512 513 INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 514 exp.LT: exp.GT, 515 exp.GT: exp.LT, 516 exp.LTE: exp.GTE, 517 exp.GTE: exp.LTE, 518 } 519 520 NONDETERMINISTIC = (exp.Rand, exp.Randn) 521 AND_OR = (exp.And, exp.Or) 522 523 INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 524 exp.DateAdd: exp.Sub, 525 exp.DateSub: exp.Add, 526 exp.DatetimeAdd: exp.Sub, 527 exp.DatetimeSub: exp.Add, 528 } 529 530 INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 531 **INVERSE_DATE_OPS, 532 exp.Add: exp.Sub, 533 exp.Sub: exp.Add, 534 } 535 536 NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 537 538 CONCATS = (exp.Concat, exp.DPipe) 539 540 DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 541 exp.LT: lambda l, dt, u, d, t: l 542 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 543 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 544 exp.LTE: 
lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 545 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 546 exp.EQ: _datetrunc_eq, 547 exp.NEQ: _datetrunc_neq, 548 } 549 550 DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 551 DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 552 553 SAFE_CONNECTOR_ELIMINATION_RESULT = (exp.Connector, exp.Boolean) 554 555 # CROSS joins result in an empty table if the right table is empty. 556 # So we can only simplify certain types of joins to CROSS. 557 # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 558 JOINS = { 559 ("", ""), 560 ("", "INNER"), 561 ("RIGHT", ""), 562 ("RIGHT", "OUTER"), 563 } 564 565 def simplify( 566 self, 567 expression: exp.Expression, 568 constant_propagation: bool = False, 569 coalesce_simplification: bool = False, 570 ): 571 wheres = [] 572 joins = [] 573 574 for node in expression.walk( 575 prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL)) 576 ): 577 if node.meta.get(FINAL): 578 continue 579 580 # group by expressions cannot be simplified, for example 581 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 582 # the projection must exactly match the group by key 583 group = node.args.get("group") 584 585 if group and hasattr(node, "selects"): 586 groups = set(group.expressions) 587 group.meta[FINAL] = True 588 589 for s in node.selects: 590 for n in s.walk(FINAL): 591 if n in groups: 592 s.meta[FINAL] = True 593 break 594 595 having = node.args.get("having") 596 597 if having: 598 for n in having.walk(): 599 if n in groups: 600 having.meta[FINAL] = True 601 break 602 603 if isinstance(node, exp.Condition): 604 simplified = while_changing( 605 node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification) 606 ) 607 608 if node is expression: 609 expression = simplified 610 elif isinstance(node, exp.Where): 611 wheres.append(node) 612 elif isinstance(node, exp.Join): 613 # snowflake 
match_conditions have very strict ordering rules 614 if match := node.args.get("match_condition"): 615 match.meta[FINAL] = True 616 617 joins.append(node) 618 619 for where in wheres: 620 if always_true(where.this): 621 where.pop() 622 for join in joins: 623 if ( 624 always_true(join.args.get("on")) 625 and not join.args.get("using") 626 and not join.args.get("method") 627 and (join.side, join.kind) in self.JOINS 628 ): 629 join.args["on"].pop() 630 join.set("side", None) 631 join.set("kind", "CROSS") 632 633 return expression 634 635 def _simplify( 636 self, expression: exp.Expression, constant_propagation: bool, coalesce_simplification: bool 637 ): 638 pre_transformation_stack = [expression] 639 post_transformation_stack = [] 640 641 while pre_transformation_stack: 642 original = pre_transformation_stack.pop() 643 node = original 644 645 if not isinstance(node, SIMPLIFIABLE): 646 if isinstance(node, exp.Query): 647 self.simplify(node, constant_propagation, coalesce_simplification) 648 continue 649 650 parent = node.parent 651 root = node is expression 652 653 node = self.rewrite_between(node) 654 node = self.uniq_sort(node, root) 655 node = self.absorb_and_eliminate(node, root) 656 node = self.simplify_concat(node) 657 node = self.simplify_conditionals(node) 658 659 if constant_propagation: 660 node = propagate_constants(node, root) 661 662 if node is not original: 663 original.replace(node) 664 665 for n in node.iter_expressions(reverse=True): 666 if n.meta.get(FINAL): 667 raise 668 pre_transformation_stack.extend( 669 n for n in node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 670 ) 671 post_transformation_stack.append((node, parent)) 672 673 while post_transformation_stack: 674 original, parent = post_transformation_stack.pop() 675 root = original is expression 676 677 # Resets parent, arg_key, index pointers– this is needed because some of the 678 # previous transformations mutate the AST, leading to an inconsistent state 679 for k, v in 
tuple(original.args.items()): 680 original.set(k, v) 681 682 # Post-order transformations 683 node = self.simplify_not(original) 684 node = flatten(node) 685 node = self.simplify_connectors(node, root) 686 node = self.remove_complements(node, root) 687 688 if coalesce_simplification: 689 node = self.simplify_coalesce(node) 690 node.parent = parent 691 692 node = self.simplify_literals(node, root) 693 node = self.simplify_equality(node) 694 node = simplify_parens(node, dialect=self.dialect) 695 node = self.simplify_datetrunc(node) 696 node = self.sort_comparison(node) 697 node = self.simplify_startswith(node) 698 699 if node is not original: 700 original.replace(node) 701 702 return node 703 704 @annotate_types_on_change 705 def rewrite_between(self, expression: exp.Expression) -> exp.Expression: 706 """Rewrite x between y and z to x >= y AND x <= z. 707 708 This is done because comparison simplification is only done on lt/lte/gt/gte. 709 """ 710 if isinstance(expression, exp.Between): 711 negate = isinstance(expression.parent, exp.Not) 712 713 expression = exp.and_( 714 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 715 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 716 copy=False, 717 ) 718 719 if negate: 720 expression = exp.paren(expression, copy=False) 721 722 return expression 723 724 @annotate_types_on_change 725 def simplify_not(self, expression: exp.Expression) -> exp.Expression: 726 """ 727 Demorgan's Law 728 NOT (x OR y) -> NOT x AND NOT y 729 NOT (x AND y) -> NOT x OR NOT y 730 """ 731 if isinstance(expression, exp.Not): 732 this = expression.this 733 if is_null(this): 734 return exp.and_(exp.null(), exp.true(), copy=False) 735 if this.__class__ in self.COMPLEMENT_COMPARISONS: 736 right = this.expression 737 complement_subquery_predicate = self.COMPLEMENT_SUBQUERY_PREDICATES.get( 738 right.__class__ 739 ) 740 if complement_subquery_predicate: 741 right = complement_subquery_predicate(this=right.this) 
742 743 return self.COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 744 if isinstance(this, exp.Paren): 745 condition = this.unnest() 746 if isinstance(condition, exp.And): 747 return exp.paren( 748 exp.or_( 749 exp.not_(condition.left, copy=False), 750 exp.not_(condition.right, copy=False), 751 copy=False, 752 ), 753 copy=False, 754 ) 755 if isinstance(condition, exp.Or): 756 return exp.paren( 757 exp.and_( 758 exp.not_(condition.left, copy=False), 759 exp.not_(condition.right, copy=False), 760 copy=False, 761 ), 762 copy=False, 763 ) 764 if is_null(condition): 765 return exp.and_(exp.null(), exp.true(), copy=False) 766 if always_true(this): 767 return exp.false() 768 if is_false(this): 769 return exp.true() 770 if isinstance(this, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION: 771 inner = this.this 772 if inner.is_type(exp.DataType.Type.BOOLEAN): 773 # double negation 774 # NOT NOT x -> x, if x is BOOLEAN type 775 return inner 776 return expression 777 778 @annotate_types_on_change 779 def simplify_connectors(self, expression, root=True): 780 def _simplify_connectors(expression, left, right): 781 if isinstance(expression, exp.And): 782 if is_false(left) or is_false(right): 783 return exp.false() 784 if is_zero(left) or is_zero(right): 785 return exp.false() 786 if ( 787 (is_null(left) and is_null(right)) 788 or (is_null(left) and always_true(right)) 789 or (always_true(left) and is_null(right)) 790 ): 791 return exp.null() 792 if always_true(left) and always_true(right): 793 return exp.true() 794 if always_true(left): 795 return right 796 if always_true(right): 797 return left 798 return self._simplify_comparison(expression, left, right) 799 elif isinstance(expression, exp.Or): 800 if always_true(left) or always_true(right): 801 return exp.true() 802 if ( 803 (is_null(left) and is_null(right)) 804 or (is_null(left) and always_false(right)) 805 or (always_false(left) and is_null(right)) 806 ): 807 return exp.null() 808 if 
is_false(left): 809 return right 810 if is_false(right): 811 return left 812 return self._simplify_comparison(expression, left, right, or_=True) 813 814 if isinstance(expression, exp.Connector): 815 original_parent = expression.parent 816 expression = self._flat_simplify(expression, _simplify_connectors, root) 817 818 # If we reduced a connector to, e.g., a column (t1 AND ... AND tn -> Tk), then we need 819 # to ensure that the resulting type is boolean. We know this is true only for connectors, 820 # boolean values and columns that are essentially operands to a connector: 821 # 822 # A AND (((B))) 823 # ~ this is safe to keep because it will eventually be part of another connector 824 if not isinstance( 825 expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT 826 ) and not expression.is_type(exp.DataType.Type.BOOLEAN): 827 while True: 828 if isinstance(original_parent, exp.Connector): 829 break 830 if not isinstance(original_parent, exp.Paren): 831 expression = expression.and_(exp.true(), copy=False) 832 break 833 834 original_parent = original_parent.parent 835 836 return expression 837 838 @annotate_types_on_change 839 def _simplify_comparison(self, expression, left, right, or_=False): 840 if isinstance(left, self.COMPARISONS) and isinstance(right, self.COMPARISONS): 841 ll, lr = left.args.values() 842 rl, rr = right.args.values() 843 844 largs = {ll, lr} 845 rargs = {rl, rr} 846 847 matching = largs & rargs 848 columns = { 849 m for m in matching if not _is_constant(m) and not m.find(*self.NONDETERMINISTIC) 850 } 851 852 if matching and columns: 853 try: 854 l = first(largs - columns) 855 r = first(rargs - columns) 856 except StopIteration: 857 return expression 858 859 if l.is_number and r.is_number: 860 l = l.to_py() 861 r = r.to_py() 862 elif l.is_string and r.is_string: 863 l = l.name 864 r = r.name 865 else: 866 l = extract_date(l) 867 if not l: 868 return None 869 r = extract_date(r) 870 if not r: 871 return None 872 # python won't compare date and 
datetime, but many engines will upcast 873 l, r = cast_as_datetime(l), cast_as_datetime(r) 874 875 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 876 if isinstance(a, self.LT_LTE) and isinstance(b, self.LT_LTE): 877 return left if (av > bv if or_ else av <= bv) else right 878 if isinstance(a, self.GT_GTE) and isinstance(b, self.GT_GTE): 879 return left if (av < bv if or_ else av >= bv) else right 880 881 # we can't ever shortcut to true because the column could be null 882 if not or_: 883 if isinstance(a, exp.LT) and isinstance(b, self.GT_GTE): 884 if av <= bv: 885 return exp.false() 886 elif isinstance(a, exp.GT) and isinstance(b, self.LT_LTE): 887 if av >= bv: 888 return exp.false() 889 elif isinstance(a, exp.EQ): 890 if isinstance(b, exp.LT): 891 return exp.false() if av >= bv else a 892 if isinstance(b, exp.LTE): 893 return exp.false() if av > bv else a 894 if isinstance(b, exp.GT): 895 return exp.false() if av <= bv else a 896 if isinstance(b, exp.GTE): 897 return exp.false() if av < bv else a 898 if isinstance(b, exp.NEQ): 899 return exp.false() if av == bv else a 900 return None 901 902 @annotate_types_on_change 903 def remove_complements(self, expression, root=True): 904 """ 905 Removing complements. 906 907 A AND NOT A -> FALSE (only for non-NULL A) 908 A OR NOT A -> TRUE (only for non-NULL A) 909 """ 910 if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): 911 ops = set(expression.flatten()) 912 for op in ops: 913 if isinstance(op, exp.Not) and op.this in ops: 914 if expression.meta.get("nonnull") is True: 915 return exp.false() if isinstance(expression, exp.And) else exp.true() 916 917 return expression 918 919 @annotate_types_on_change 920 def uniq_sort(self, expression, root=True): 921 """ 922 Uniq and sort a connector. 
923 924 C AND A AND B AND B -> A AND B AND C 925 """ 926 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 927 flattened = tuple(expression.flatten()) 928 929 if isinstance(expression, exp.Xor): 930 result_func = exp.xor 931 # Do not deduplicate XOR as A XOR A != A if A == True 932 deduped = None 933 arr = tuple((gen(e), e) for e in flattened) 934 else: 935 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 936 deduped = {gen(e): e for e in flattened} 937 arr = tuple(deduped.items()) 938 939 # check if the operands are already sorted, if not sort them 940 # A AND C AND B -> A AND B AND C 941 for i, (sql, e) in enumerate(arr[1:]): 942 if sql < arr[i][0]: 943 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 944 break 945 else: 946 # we didn't have to sort but maybe we need to dedup 947 if deduped and len(deduped) < len(flattened): 948 unique_operand = flattened[0] 949 if len(deduped) == 1: 950 expression = unique_operand.and_(exp.true(), copy=False) 951 else: 952 expression = result_func(*deduped.values(), copy=False) 953 954 return expression 955 956 @annotate_types_on_change 957 def absorb_and_eliminate(self, expression, root=True): 958 """ 959 absorption: 960 A AND (A OR B) -> A 961 A OR (A AND B) -> A 962 A AND (NOT A OR B) -> A AND B 963 A OR (NOT A AND B) -> A OR B 964 elimination: 965 (A AND B) OR (A AND NOT B) -> A 966 (A OR B) AND (A OR NOT B) -> A 967 """ 968 if isinstance(expression, self.AND_OR) and (root or not expression.same_parent): 969 kind = exp.Or if isinstance(expression, exp.And) else exp.And 970 971 ops = tuple(expression.flatten()) 972 973 # Initialize lookup tables: 974 # Set of all operands, used to find complements for absorption. 975 op_set = set() 976 # Sub-operands, used to find subsets for absorption. 977 subops = defaultdict(list) 978 # Pairs of complements, used for elimination. 
979 pairs = defaultdict(list) 980 981 # Populate the lookup tables 982 for op in ops: 983 op_set.add(op) 984 985 if not isinstance(op, kind): 986 # In cases like: A OR (A AND B) 987 # Subop will be: ^ 988 subops[op].append({op}) 989 continue 990 991 # In cases like: (A AND B) OR (A AND B AND C) 992 # Subops will be: ^ ^ 993 subset = set(op.flatten()) 994 for i in subset: 995 subops[i].append(subset) 996 997 a, b = op.unnest_operands() 998 if isinstance(a, exp.Not): 999 pairs[frozenset((a.this, b))].append((op, b)) 1000 if isinstance(b, exp.Not): 1001 pairs[frozenset((a, b.this))].append((op, a)) 1002 1003 for op in ops: 1004 if not isinstance(op, kind): 1005 continue 1006 1007 a, b = op.unnest_operands() 1008 1009 # Absorb 1010 if isinstance(a, exp.Not) and a.this in op_set: 1011 a.replace(exp.true() if kind == exp.And else exp.false()) 1012 continue 1013 if isinstance(b, exp.Not) and b.this in op_set: 1014 b.replace(exp.true() if kind == exp.And else exp.false()) 1015 continue 1016 superset = set(op.flatten()) 1017 if any(any(subset < superset for subset in subops[i]) for i in superset): 1018 op.replace(exp.false() if kind == exp.And else exp.true()) 1019 continue 1020 1021 # Eliminate 1022 for other, complement in pairs[frozenset((a, b))]: 1023 op.replace(complement) 1024 other.replace(complement) 1025 1026 return expression 1027 1028 @annotate_types_on_change 1029 @catch(ModuleNotFoundError, UnsupportedUnit) 1030 def simplify_equality(self, expression: exp.Expression) -> exp.Expression: 1031 """ 1032 Use the subtraction and addition properties of equality to simplify expressions: 1033 1034 x + 1 = 3 becomes x = 2 1035 1036 There are two binary operations in the above expression: + and = 1037 Here's how we reference all the operands in the code below: 1038 1039 l r 1040 x + 1 = 3 1041 a b 1042 """ 1043 if isinstance(expression, self.COMPARISONS): 1044 l, r = expression.left, expression.right 1045 1046 if l.__class__ not in self.INVERSE_OPS: 1047 return expression 
1048 1049 if r.is_number: 1050 a_predicate = _is_number 1051 b_predicate = _is_number 1052 elif _is_date_literal(r): 1053 a_predicate = _is_date_literal 1054 b_predicate = _is_interval 1055 else: 1056 return expression 1057 1058 if l.__class__ in self.INVERSE_DATE_OPS: 1059 l = t.cast(exp.IntervalOp, l) 1060 a = l.this 1061 b = l.interval() 1062 else: 1063 l = t.cast(exp.Binary, l) 1064 a, b = l.left, l.right 1065 1066 if not a_predicate(a) and b_predicate(b): 1067 pass 1068 elif not a_predicate(b) and b_predicate(a): 1069 a, b = b, a 1070 else: 1071 return expression 1072 1073 return expression.__class__( 1074 this=a, expression=self.INVERSE_OPS[l.__class__](this=r, expression=b) 1075 ) 1076 return expression 1077 1078 @annotate_types_on_change 1079 def simplify_literals(self, expression, root=True): 1080 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 1081 return self._flat_simplify(expression, self._simplify_binary, root) 1082 1083 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 1084 return expression.this.this 1085 1086 if type(expression) in self.INVERSE_DATE_OPS: 1087 return ( 1088 self._simplify_binary(expression, expression.this, expression.interval()) 1089 or expression 1090 ) 1091 1092 return expression 1093 1094 def _simplify_integer_cast(self, expr: exp.Expression) -> exp.Expression: 1095 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 1096 this = self._simplify_integer_cast(expr.this) 1097 else: 1098 this = expr.this 1099 1100 if isinstance(expr, exp.Cast) and this.is_int: 1101 num = this.to_py() 1102 1103 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 1104 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 1105 # engine-dependent 1106 if ( 1107 self.TINYINT_MIN <= num <= self.TINYINT_MAX 1108 and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 1109 ) or ( 1110 self.UTINYINT_MIN <= num <= self.UTINYINT_MAX 1111 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 1112 ): 1113 return this 1114 1115 return expr 1116 1117 def _simplify_binary(self, expression, a, b): 1118 if isinstance(expression, self.COMPARISONS): 1119 a = self._simplify_integer_cast(a) 1120 b = self._simplify_integer_cast(b) 1121 1122 if isinstance(expression, exp.Is): 1123 if isinstance(b, exp.Not): 1124 c = b.this 1125 not_ = True 1126 else: 1127 c = b 1128 not_ = False 1129 1130 if is_null(c): 1131 if isinstance(a, exp.Literal): 1132 return exp.true() if not_ else exp.false() 1133 if is_null(a): 1134 return exp.false() if not_ else exp.true() 1135 elif isinstance(expression, self.NULL_OK): 1136 return None 1137 elif (is_null(a) or is_null(b)) and isinstance(expression.parent, exp.If): 1138 return exp.null() 1139 1140 if a.is_number and b.is_number: 1141 num_a = a.to_py() 1142 num_b = b.to_py() 1143 1144 if isinstance(expression, exp.Add): 1145 return exp.Literal.number(num_a + num_b) 1146 if isinstance(expression, exp.Mul): 1147 return exp.Literal.number(num_a * num_b) 1148 1149 # We only simplify Sub, Div if a and b have the same parent because they're not associative 1150 if isinstance(expression, exp.Sub): 1151 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 1152 if isinstance(expression, exp.Div): 1153 # engines have differing int div behavior so intdiv is not safe 1154 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 1155 return None 1156 return exp.Literal.number(num_a / num_b) 1157 1158 boolean = eval_boolean(expression, num_a, num_b) 1159 1160 if boolean: 1161 return boolean 1162 elif a.is_string and b.is_string: 
1163 boolean = eval_boolean(expression, a.this, b.this) 1164 1165 if boolean: 1166 return boolean 1167 elif _is_date_literal(a) and isinstance(b, exp.Interval): 1168 date, b = extract_date(a), extract_interval(b) 1169 if date and b: 1170 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 1171 return date_literal(date + b, extract_type(a)) 1172 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 1173 return date_literal(date - b, extract_type(a)) 1174 elif isinstance(a, exp.Interval) and _is_date_literal(b): 1175 a, date = extract_interval(a), extract_date(b) 1176 # you cannot subtract a date from an interval 1177 if a and b and isinstance(expression, exp.Add): 1178 return date_literal(a + date, extract_type(b)) 1179 elif _is_date_literal(a) and _is_date_literal(b): 1180 if isinstance(expression, exp.Predicate): 1181 a, b = extract_date(a), extract_date(b) 1182 boolean = eval_boolean(expression, a, b) 1183 if boolean: 1184 return boolean 1185 1186 return None 1187 1188 @annotate_types_on_change 1189 def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression: 1190 # COALESCE(x) -> x 1191 if ( 1192 isinstance(expression, exp.Coalesce) 1193 and (not expression.expressions or _is_nonnull_constant(expression.this)) 1194 # COALESCE is also used as a Spark partitioning hint 1195 and not isinstance(expression.parent, exp.Hint) 1196 ): 1197 return expression.this 1198 1199 if self.dialect.COALESCE_COMPARISON_NON_STANDARD: 1200 return expression 1201 1202 if not isinstance(expression, self.COMPARISONS): 1203 return expression 1204 1205 if isinstance(expression.left, exp.Coalesce): 1206 coalesce = expression.left 1207 other = expression.right 1208 elif isinstance(expression.right, exp.Coalesce): 1209 coalesce = expression.right 1210 other = expression.left 1211 else: 1212 return expression 1213 1214 # This transformation is valid for non-constants, 1215 # but it really only does anything if they are both constants. 
1216 if not _is_constant(other): 1217 return expression 1218 1219 # Find the first constant arg 1220 for arg_index, arg in enumerate(coalesce.expressions): 1221 if _is_constant(arg): 1222 break 1223 else: 1224 return expression 1225 1226 coalesce.set("expressions", coalesce.expressions[:arg_index]) 1227 1228 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 1229 # since we already remove COALESCE at the top of this function. 1230 coalesce = coalesce if coalesce.expressions else coalesce.this 1231 1232 # This expression is more complex than when we started, but it will get simplified further 1233 return exp.paren( 1234 exp.or_( 1235 exp.and_( 1236 coalesce.is_(exp.null()).not_(copy=False), 1237 expression.copy(), 1238 copy=False, 1239 ), 1240 exp.and_( 1241 coalesce.is_(exp.null()), 1242 type(expression)(this=arg.copy(), expression=other.copy()), 1243 copy=False, 1244 ), 1245 copy=False, 1246 ), 1247 copy=False, 1248 ) 1249 1250 @annotate_types_on_change 1251 def simplify_concat(self, expression): 1252 """Reduces all groups that contain string literals by concatenating them.""" 1253 if not isinstance(expression, self.CONCATS) or ( 1254 # We can't reduce a CONCAT_WS call if we don't statically know the separator 1255 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 1256 ): 1257 return expression 1258 1259 if isinstance(expression, exp.ConcatWs): 1260 sep_expr, *expressions = expression.expressions 1261 sep = sep_expr.name 1262 concat_type = exp.ConcatWs 1263 args = {} 1264 else: 1265 expressions = expression.expressions 1266 sep = "" 1267 concat_type = exp.Concat 1268 args = { 1269 "safe": expression.args.get("safe"), 1270 "coalesce": expression.args.get("coalesce"), 1271 } 1272 1273 new_args = [] 1274 for is_string_group, group in itertools.groupby( 1275 expressions or expression.flatten(), lambda e: e.is_string 1276 ): 1277 if is_string_group: 1278 
new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 1279 else: 1280 new_args.extend(group) 1281 1282 if len(new_args) == 1 and new_args[0].is_string: 1283 return new_args[0] 1284 1285 if concat_type is exp.ConcatWs: 1286 new_args = [sep_expr] + new_args 1287 elif isinstance(expression, exp.DPipe): 1288 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 1289 1290 return concat_type(expressions=new_args, **args) 1291 1292 @annotate_types_on_change 1293 def simplify_conditionals(self, expression): 1294 """Simplifies expressions like IF, CASE if their condition is statically known.""" 1295 if isinstance(expression, exp.Case): 1296 this = expression.this 1297 for case in expression.args["ifs"]: 1298 cond = case.this 1299 if this: 1300 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 1301 cond = cond.replace(this.pop().eq(cond)) 1302 1303 if always_true(cond): 1304 return case.args["true"] 1305 1306 if always_false(cond): 1307 case.pop() 1308 if not expression.args["ifs"]: 1309 return expression.args.get("default") or exp.null() 1310 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 1311 if always_true(expression.this): 1312 return expression.args["true"] 1313 if always_false(expression.this): 1314 return expression.args.get("false") or exp.null() 1315 1316 return expression 1317 1318 @annotate_types_on_change 1319 def simplify_startswith(self, expression: exp.Expression) -> exp.Expression: 1320 """ 1321 Reduces a prefix check to either TRUE or FALSE if both the string and the 1322 prefix are statically known. 
1323 1324 Example: 1325 >>> from sqlglot import parse_one 1326 >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 1327 'TRUE' 1328 """ 1329 if ( 1330 isinstance(expression, exp.StartsWith) 1331 and expression.this.is_string 1332 and expression.expression.is_string 1333 ): 1334 return exp.convert(expression.name.startswith(expression.expression.name)) 1335 1336 return expression 1337 1338 def _is_datetrunc_predicate(self, left: exp.Expression, right: exp.Expression) -> bool: 1339 return isinstance(left, self.DATETRUNCS) and _is_date_literal(right) 1340 1341 @annotate_types_on_change 1342 @catch(ModuleNotFoundError, UnsupportedUnit) 1343 def simplify_datetrunc(self, expression: exp.Expression) -> exp.Expression: 1344 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1345 comparison = expression.__class__ 1346 1347 if isinstance(expression, self.DATETRUNCS): 1348 this = expression.this 1349 trunc_type = extract_type(this) 1350 date = extract_date(this) 1351 if date and expression.unit: 1352 return date_literal( 1353 date_floor(date, expression.unit.name.lower(), self.dialect), trunc_type 1354 ) 1355 elif comparison not in self.DATETRUNC_COMPARISONS: 1356 return expression 1357 1358 if isinstance(expression, exp.Binary): 1359 l, r = expression.left, expression.right 1360 1361 if not self._is_datetrunc_predicate(l, r): 1362 return expression 1363 1364 l = t.cast(exp.DateTrunc, l) 1365 trunc_arg = l.this 1366 unit = l.unit.name.lower() 1367 date = extract_date(r) 1368 1369 if not date: 1370 return expression 1371 1372 return ( 1373 self.DATETRUNC_BINARY_COMPARISONS[comparison]( 1374 trunc_arg, date, unit, self.dialect, extract_type(r) 1375 ) 1376 or expression 1377 ) 1378 1379 if isinstance(expression, exp.In): 1380 l = expression.this 1381 rs = expression.expressions 1382 1383 if rs and all(self._is_datetrunc_predicate(l, r) for r in rs): 1384 l = t.cast(exp.DateTrunc, l) 1385 unit = l.unit.name.lower() 
1386 1387 ranges = [] 1388 for r in rs: 1389 date = extract_date(r) 1390 if not date: 1391 return expression 1392 drange = _datetrunc_range(date, unit, self.dialect) 1393 if drange: 1394 ranges.append(drange) 1395 1396 if not ranges: 1397 return expression 1398 1399 ranges = merge_ranges(ranges) 1400 target_type = extract_type(*rs) 1401 1402 return exp.or_( 1403 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], 1404 copy=False, 1405 ) 1406 1407 return expression 1408 1409 @annotate_types_on_change 1410 def sort_comparison(self, expression: exp.Expression) -> exp.Expression: 1411 if expression.__class__ in self.COMPLEMENT_COMPARISONS: 1412 l, r = expression.this, expression.expression 1413 l_column = isinstance(l, exp.Column) 1414 r_column = isinstance(r, exp.Column) 1415 l_const = _is_constant(l) 1416 r_const = _is_constant(r) 1417 1418 if ( 1419 (l_column and not r_column) 1420 or (r_const and not l_const) 1421 or isinstance(r, exp.SubqueryPredicate) 1422 ): 1423 return expression 1424 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1425 return self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1426 this=r, expression=l 1427 ) 1428 return expression 1429 1430 def _flat_simplify(self, expression, simplifier, root=True): 1431 if root or not expression.same_parent: 1432 operands = [] 1433 queue = deque(expression.flatten(unnest=False)) 1434 size = len(queue) 1435 1436 while queue: 1437 a = queue.popleft() 1438 1439 for b in queue: 1440 result = simplifier(expression, a, b) 1441 1442 if result and result is not expression: 1443 queue.remove(b) 1444 queue.appendleft(result) 1445 break 1446 else: 1447 operands.append(a) 1448 1449 if len(operands) < size: 1450 return functools.reduce( 1451 lambda a, b: expression.__class__(this=a, expression=b), operands 1452 ) 1453 return expression
def __init__(self, dialect: DialectType = None, annotate_new_expressions: bool = True):
    """
    Set up a simplifier bound to a single dialect.

    Args:
        dialect: the dialect whose settings drive dialect-sensitive rewrites; it is
            resolved through `Dialect.get_or_raise`.
        annotate_new_expressions: whether nodes produced by a rewrite are re-annotated
            with types (read by the `annotate_types_on_change` decorator).
    """
    self.dialect = Dialect.get_or_raise(dialect)
    self.annotate_new_expressions = annotate_new_expressions

    # No schema is supplied (None); the annotator is created with overwrite_types=False,
    # presumably so types already present on nodes are kept.
    empty_schema = ensure_schema(None, dialect=self.dialect)
    self._annotator: TypeAnnotator = TypeAnnotator(schema=empty_schema, overwrite_types=False)
def simplify(
    self,
    expression: exp.Expression,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
):
    """
    Simplify `expression` (mutating it where possible) and return the resulting root.

    Every condition subtree is rewritten with `_simplify` until it stops changing
    (`while_changing`). Afterwards, WHERE clauses whose condition is always true are
    removed, and qualifying joins whose ON condition is always true become CROSS joins.

    Args:
        expression: the root expression to simplify.
        constant_propagation: whether the constant propagation rule should be used.
        coalesce_simplification: whether the COALESCE simplification rule should be used.

    Returns:
        The simplified expression. When the root itself is a condition, this is the
        result of simplifying it, which may be a different node.
    """
    wheres = []
    joins = []

    # Condition nodes are still yielded (and simplified as a unit below), but the walk
    # does not descend into them or into nodes marked FINAL.
    for node in expression.walk(
        prune=lambda n: bool(isinstance(n, exp.Condition) or n.meta.get(FINAL))
    ):
        if node.meta.get(FINAL):
            continue

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = node.args.get("group")

        if group and hasattr(node, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze any projection that contains a GROUP BY key.
            for s in node.selects:
                for n in s.walk():
                    if n in groups:
                        s.meta[FINAL] = True
                        break

            # Likewise freeze the HAVING clause if it references a GROUP BY key.
            having = node.args.get("having")

            if having:
                for n in having.walk():
                    if n in groups:
                        having.meta[FINAL] = True
                        break

        if isinstance(node, exp.Condition):
            simplified = while_changing(
                node, lambda e: self._simplify(e, constant_propagation, coalesce_simplification)
            )

            if node is expression:
                expression = simplified
        elif isinstance(node, exp.Where):
            wheres.append(node)
        elif isinstance(node, exp.Join):
            # snowflake match_conditions have very strict ordering rules
            if match := node.args.get("match_condition"):
                match.meta[FINAL] = True

            joins.append(node)

    for where in wheres:
        if always_true(where.this):
            where.pop()
    for join in joins:
        # Only join (side, kind) combinations listed in JOINS may become CROSS joins;
        # e.g. LEFT JOIN x ON TRUE is not equivalent to CROSS JOIN x.
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in self.JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")

    return expression
@annotate_types_on_change
def rewrite_between(self, expression: exp.Expression) -> exp.Expression:
    """
    Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands LT / LTE / GT / GTE, so BETWEEN is
    rewritten into those operators first. Anything that is not a BETWEEN is returned
    untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    # Under a NOT, wrap in parentheses so the negation keeps covering the whole conjunction
    return exp.paren(rewritten, copy=False) if under_not else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
@annotate_types_on_change
def simplify_not(self, expression: exp.Expression) -> exp.Expression:
    """
    Push NOT inwards and fold negations whose result is statically known.

    De Morgan's laws:
        NOT (x OR y)  -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition:
    - NOT of a comparison becomes the complementary comparison.
    - NOT NULL stays NULL.
    - NOT of an always-true operand becomes FALSE, and NOT FALSE becomes TRUE.
    - NOT NOT x collapses to x when the dialect allows it and x is BOOLEAN-typed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        # Stays NULL; written as NULL AND TRUE, presumably to keep a boolean connector shape
        return exp.and_(exp.null(), exp.true(), copy=False)

    if operand.__class__ in self.COMPLEMENT_COMPARISONS:
        complement = self.COMPLEMENT_COMPARISONS[operand.__class__]
        rhs = operand.expression
        # A subquery predicate on the right-hand side is swapped for its complement as well
        flipped = self.COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped:
            rhs = flipped(this=rhs.this)
        return complement(this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        for connector, dual in ((exp.And, exp.or_), (exp.Or, exp.and_)):
            if isinstance(inner, connector):
                negated = dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
                return exp.paren(negated, copy=False)
        if is_null(inner):
            return exp.and_(exp.null(), exp.true(), copy=False)

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not) and self.dialect.SAFE_TO_ELIMINATE_DOUBLE_NEGATION:
        doubly_negated = operand.this
        # NOT NOT x -> x is only applied when x is known to be BOOLEAN
        if doubly_negated.is_type(exp.DataType.Type.BOOLEAN):
            return doubly_negated

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
@annotate_types_on_change
def simplify_connectors(self, expression, root=True):
    """
    Fold AND / OR chains whose operands have statically known truth values.

    AND:
    - A FALSE or zero operand makes the chain FALSE.
    - Always-true operands are dropped.
    - NULL with NULL or with an always-true operand yields NULL.

    OR:
    - An always-true operand makes the chain TRUE.
    - FALSE operands are dropped.
    - NULL with NULL or with an always-false operand yields NULL.

    Any other operand pair is handed to `_simplify_comparison`.

    Args:
        expression: the candidate connector.
        root: whether `expression` is the root of the tree being simplified.

    Returns:
        The folded expression, or `expression` unchanged when it is not a connector.
    """

    def _fold_and(connector, left, right):
        if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
            return exp.false()
        if (is_null(left) and (is_null(right) or always_true(right))) or (
            always_true(left) and is_null(right)
        ):
            return exp.null()
        if always_true(left):
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return self._simplify_comparison(connector, left, right)

    def _fold_or(connector, left, right):
        if always_true(left) or always_true(right):
            return exp.true()
        if (is_null(left) and (is_null(right) or always_false(right))) or (
            always_false(left) and is_null(right)
        ):
            return exp.null()
        if is_false(left):
            return right
        if is_false(right):
            return left
        return self._simplify_comparison(connector, left, right, or_=True)

    def _fold_pair(connector, left, right):
        # Pairwise callback for _flat_simplify; other connectors (e.g. XOR) are not folded
        if isinstance(connector, exp.And):
            return _fold_and(connector, left, right)
        if isinstance(connector, exp.Or):
            return _fold_or(connector, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression

    parent_before = expression.parent
    expression = self._flat_simplify(expression, _fold_pair, root)

    # If the chain collapsed to a lone operand (t1 AND ... AND tn -> tk), that operand may
    # not be boolean-typed. It is left bare only when it is a connector, a boolean literal,
    # already BOOLEAN-typed, or still sits (possibly through parentheses) inside another
    # connector, e.g. the B in A AND (((B))). Otherwise it is turned into `tk AND TRUE`.
    if not isinstance(expression, self.SAFE_CONNECTOR_ELIMINATION_RESULT) and not expression.is_type(
        exp.DataType.Type.BOOLEAN
    ):
        ancestor = parent_before
        while isinstance(ancestor, exp.Paren):
            ancestor = ancestor.parent
        if not isinstance(ancestor, exp.Connector):
            expression = expression.and_(exp.true(), copy=False)

    return expression
@annotate_types_on_change
def remove_complements(self, expression, root=True):
    """
    Remove complementary operand pairs from an AND / OR chain.

        A AND NOT A -> FALSE (only for non-NULL A)
        A OR NOT A  -> TRUE  (only for non-NULL A)

    The rule only fires when the connector carries `meta["nonnull"] is True`, because
    for a NULL A both forms evaluate to NULL rather than FALSE / TRUE.

    Args:
        expression: the candidate connector.
        root: whether `expression` is the root of the tree being simplified; non-root
            connectors are skipped when `same_parent` is true (presumably the enclosing
            connector of the same type handles the whole chain).

    Returns:
        FALSE (for AND) or TRUE (for OR) when some operand's negation is also an operand;
        otherwise `expression` unchanged.
    """
    if not isinstance(expression, self.AND_OR) or (not root and expression.same_parent):
        return expression

    # The nonnull marker is independent of the operands, so check it once up front
    # instead of once per complementary pair found
    if expression.meta.get("nonnull") is not True:
        return expression

    operands = set(expression.flatten())
    if any(isinstance(op, exp.Not) and op.this in operands for op in operands):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE (only for non-NULL A)
A OR NOT A -> TRUE (only for non-NULL A)
@annotate_types_on_change
def uniq_sort(self, expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by their `gen` pseudo-SQL. AND / OR operands are also deduplicated
    on that key. XOR operands are never deduplicated, because A XOR A != A when A is TRUE.
    If deduplication leaves a single operand, it is returned as `<operand> AND TRUE` so the
    result remains a boolean connector.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        if any(curr[0] < prev[0] for prev, curr in zip(arr, arr[1:])):
            # Sort on the generated SQL only. Sorting the raw (sql, expression) tuples would
            # compare Expression objects whenever two (non-deduplicated XOR) operands render to
            # the same string. The sort is stable, so such ties keep their original order.
            ordered = sorted(arr, key=lambda pair: pair[0])
            expression = result_func(*(e for _, e in ordered), copy=False)
        elif deduped and len(deduped) < len(flattened):
            # we didn't have to sort but maybe we need to dedup
            unique_operand = flattened[0]
            if len(deduped) == 1:
                expression = unique_operand.and_(exp.true(), copy=False)
            else:
                expression = result_func(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
@annotate_types_on_change
def absorb_and_eliminate(self, expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place via `replace`, and `expression` itself is returned.
    """
    if isinstance(expression, self.AND_OR) and (root or not expression.same_parent):
        # The connector type of the nested operands: OR inside an AND, AND inside an OR.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^ ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb: a complemented operand becomes the identity of the nested connector
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # Absorb: a nested operand that strictly contains another operand's set becomes
            # the identity of the outer connector
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                # `complement` is a child of `other`. Attach a copy at `op`'s position so the
                # same node object is never placed under two parents in the tree.
                op.replace(complement.copy())
                other.replace(complement)

    return expression
Absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
Elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
@wraps(func)
def wrapped(expression, *args, **kwargs):
    # Run the simplification. If it raises one of the configured `exceptions`, the rule is
    # skipped and the input expression is returned unchanged. `@wraps` keeps the decorated
    # rule's name, docstring and `__wrapped__` (so its real source can be located).
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # When the decorated callable is an instance method, the first positional argument
        # is `self`, not the expression. In that case the expression is the next argument.
        if not isinstance(expression, exp.Expression) and args:
            return args[0]
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:
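      l     r
    x + 1 = 3
    a   b

That is, `l` is the left side of `=` (`x + 1`) and `r` is its right side (`3`). `a` and `b` are the left and right operands of `+` (`x` and `1`).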
@annotate_types_on_change
def simplify_literals(self, expression, root=True):
    """
    Fold arithmetic involving literals.

    - Binary nodes that are not connectors are folded pairwise through
      `self._flat_simplify` using `self._simplify_binary`.
    - A double negation `-(-x)` collapses to `x`.
    - Nodes whose type is in `self.INVERSE_DATE_OPS` are folded with their interval
      operand when `self._simplify_binary` yields a result. Otherwise they are returned
      unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return self._flat_simplify(expression, self._simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if isinstance(operand, exp.Neg):
            return operand.this

    if type(expression) not in self.INVERSE_DATE_OPS:
        return expression

    folded = self._simplify_binary(expression, expression.this, expression.interval())
    return folded if folded else expression
@annotate_types_on_change
def simplify_coalesce(self, expression: exp.Expression) -> exp.Expression:
    """
    Remove or expand COALESCE calls.

    - A COALESCE with no extra arguments, or whose first argument passes
      `_is_nonnull_constant`, is replaced by its first argument. A COALESCE used as a
      Spark partitioning hint (parent is `exp.Hint`) is left untouched.
    - Unless the dialect sets `COALESCE_COMPARISON_NON_STANDARD`, a comparison between a
      COALESCE and a constant is expanded. The COALESCE is truncated before its first
      constant argument, e.g.

          COALESCE(x, 1) = 2  ->  (x IS NOT NULL AND COALESCE(x) = 2) OR (x IS NULL AND 1 = 2)

      The original operand order is preserved in both branches, so non-symmetric
      comparisons keep their meaning.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if self.dialect.COALESCE_COMPARISON_NON_STANDARD:
        return expression

    if not isinstance(expression, self.COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # When the truncated COALESCE is NULL, the comparison is decided by the constant `arg`.
    # Keep `arg` on the side the COALESCE occupied, e.g. `5 < COALESCE(x, 1)` must fall back
    # to `5 < 1`, not `1 < 5`.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        ),
        copy=False,
    )
@annotate_types_on_change
def simplify_concat(self, expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string-literal arguments of CONCAT / `||` / CONCAT_WS are merged into
    a single literal, joined with the separator for CONCAT_WS. Non-literal arguments are
    kept as-is. If everything merges into one literal, that literal is returned directly.
    """
    if not isinstance(expression, self.CONCATS):
        return expression

    if isinstance(expression, exp.ConcatWs):
        ws_args = expression.expressions
        # We can't reduce a CONCAT_WS call if we don't statically know the separator. With no
        # value arguments there is nothing to fold, and rewriting CONCAT_WS(sep) to `sep` (what
        # the generic path below would do) is not a correct rewrite.
        if len(ws_args) < 2 or not ws_args[0].is_string:
            return expression

        sep_expr, *expressions = ws_args
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Binary `||` chains (DPipe) have no `expressions`, so their operands come from flatten().
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
@annotate_types_on_change
def simplify_conditionals(self, expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE: a `CASE x WHEN v ...` form is first rewritten to `CASE WHEN x = v ...`. The first
    always-true branch's result is returned, and always-false branches are dropped. If no
    branches remain, the ELSE value (or NULL) is returned.
    IF (not nested directly in a CASE): returns the true / false branch (or NULL) when the
    condition is always true / always false.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below removes the branch from
        # `expression.args["ifs"]` in place, which would otherwise skip the next branch.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
@annotate_types_on_change
def simplify_startswith(self, expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    prefix = expression.expression
    if not (expression.this.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> Simplifier().simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
@wraps(func)
def wrapped(expression, *args, **kwargs):
    # Run the simplification. If it raises one of the configured `exceptions`, the rule is
    # skipped and the input expression is returned unchanged. `@wraps` keeps the decorated
    # rule's name, docstring and `__wrapped__` (so its real source can be located).
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # When the decorated callable is an instance method, the first positional argument
        # is `self`, not the expression. In that case the expression is the next argument.
        if not isinstance(expression, exp.Expression) and args:
            return args[0]
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
@annotate_types_on_change
def sort_comparison(self, expression: exp.Expression) -> exp.Expression:
    """
    Put a comparison's operands into a canonical order.

    Only applies to classes in `self.COMPLEMENT_COMPARISONS`. The comparison is kept as-is
    when a column is already on the left, a constant is already on the right, or the right
    side is a subquery predicate. It is flipped (using `self.INVERSE_COMPARISONS` for the
    mirrored operator) when a column or constant sits only on the wrong side, or when the
    left operand's `gen` string sorts after the right one's.
    """
    if expression.__class__ not in self.COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_canonical = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_canonical:
        return expression

    should_flip = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not should_flip:
        return expression

    flipped_cls = self.INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return flipped_cls(this=right, expression=left)
def gen(expression: t.Any, comments: bool = False) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.

    Returns:
        The pseudo-SQL string produced by a fresh `Gen` instance.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
Arguments:
- expression: the expression to convert into a SQL string.
- comments: whether to include the expression's comments.
class Gen:
    """
    Bare-minimum, stack-based pseudo-SQL generator used to build sort / dedup keys.

    Work items are pushed onto `self.stack` in reverse order and popped one at a time:
    - `Expression` nodes dispatch to a `<key>_sql` handler, to `_function` for functions,
      or to a generic `KEY :arg,value,...` fallback;
    - Python lists become comma-separated sequences, with `None` items skipped;
    - any other non-None item is appended to the output via `str()`.
    """

    def __init__(self):
        # Pending work items (expressions, lists and string fragments), processed LIFO.
        self.stack = []
        # Rendered string fragments, joined at the end of `gen`.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """
        Render `expression` into a pseudo-SQL string.

        Args:
            expression: the expression to render.
            comments: whether to emit each node's comments as ` /*...*/` after the node.

        Returns:
            The concatenation of all emitted fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's parts, so it is emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator queued after the first element. Only do this if something
                # was pushed: for a non-empty list of all-None items it would otherwise pop an
                # unrelated item (such as a closing ")") off the stack.
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the leading "." queued after the first part. Skip this when there are no
        # parts, so that nothing unrelated is popped.
        if parts:
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): the casing differs from the other operators (e.g. " ILIKE "). It is
        # kept as-is because changing it would alter the sort order of generated keys.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        # Drop the leading "." queued after the first part. With no parts this would pop
        # the alias or the argument list pushed above instead.
        if parts:
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed in reverse so the output reads `<this><op><expression>`.
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the node's non-None arguments as a list of `[":key", value]` pairs.

        Args:
            node: the expression whose arguments are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if any argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """
    Render `expression` into a pseudo-SQL string.

    Args:
        expression: the expression to render.
        comments: whether to emit each node's comments as ` /*...*/` after the node.

    Returns:
        The concatenation of all emitted fragments.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            if comments and node.comments:
                # Pushed before the node's parts, so it is emitted after them.
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the separator queued after the first element. Only do this if something
            # was pushed: for a non-empty list of all-None items it would otherwise pop an
            # unrelated item (such as a closing ")") off the stack.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Queue `NAME(arg1,arg2,...)` for an anonymous function call.

    A plain-string name is upper-cased. An unquoted Identifier name is also upper-cased,
    while a quoted one is kept verbatim inside double quotes. Any other name type raises
    ValueError.
    """
    target = e.this
    if isinstance(target, exp.Identifier):
        raw_name = target.this
        label = f'"{raw_name}"' if target.quoted else raw_name.upper()
    elif isinstance(target, str):
        label = target.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )

    # Pushed in reverse: the name is emitted first and ")" last.
    self.stack.extend((")", e.expressions, "(", label))