sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 coalesce_simplification: bool = False, 43 dialect: DialectType = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 coalesce_simplification: whether the simplify coalesce rule should be used. 58 This rule tries to remove coalesce functions, which can be useful in certain analyses but 59 can leave the query more verbose. 
60 Returns: 61 sqlglot.Expression: simplified expression 62 """ 63 64 dialect = Dialect.get_or_raise(dialect) 65 66 def _simplify(expression): 67 pre_transformation_stack = [expression] 68 post_transformation_stack = [] 69 70 while pre_transformation_stack: 71 node = pre_transformation_stack.pop() 72 73 if node.meta.get(FINAL): 74 continue 75 76 # group by expressions cannot be simplified, for example 77 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 78 # the projection must exactly match the group by key 79 group = node.args.get("group") 80 81 if group and hasattr(node, "selects"): 82 groups = set(group.expressions) 83 group.meta[FINAL] = True 84 85 for s in node.selects: 86 for n in s.walk(): 87 if n in groups: 88 s.meta[FINAL] = True 89 break 90 91 having = node.args.get("having") 92 if having: 93 for n in having.walk(): 94 if n in groups: 95 having.meta[FINAL] = True 96 break 97 98 parent = node.parent 99 root = node is expression 100 101 new_node = rewrite_between(node) 102 new_node = uniq_sort(new_node, root) 103 new_node = absorb_and_eliminate(new_node, root) 104 new_node = simplify_concat(new_node) 105 new_node = simplify_conditionals(new_node) 106 107 if constant_propagation: 108 new_node = propagate_constants(new_node, root) 109 110 if new_node is not node: 111 node.replace(new_node) 112 113 pre_transformation_stack.extend( 114 n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 115 ) 116 post_transformation_stack.append((new_node, parent)) 117 118 while post_transformation_stack: 119 node, parent = post_transformation_stack.pop() 120 root = node is expression 121 122 # Resets parent, arg_key, index pointers– this is needed because some of the 123 # previous transformations mutate the AST, leading to an inconsistent state 124 for k, v in tuple(node.args.items()): 125 node.set(k, v) 126 127 # Post-order transformations 128 new_node = simplify_not(node) 129 new_node = flatten(new_node) 130 new_node = simplify_connectors(new_node, root) 
131 new_node = remove_complements(new_node, root) 132 133 if coalesce_simplification: 134 new_node = simplify_coalesce(new_node, dialect) 135 136 new_node.parent = parent 137 138 new_node = simplify_literals(new_node, root) 139 new_node = simplify_equality(new_node) 140 new_node = simplify_parens(new_node, dialect) 141 new_node = simplify_datetrunc(new_node, dialect) 142 new_node = sort_comparison(new_node) 143 new_node = simplify_startswith(new_node) 144 145 if new_node is not node: 146 node.replace(new_node) 147 148 return new_node 149 150 expression = while_changing(expression, _simplify) 151 remove_where_true(expression) 152 return expression 153 154 155def catch(*exceptions): 156 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 157 158 def decorator(func): 159 def wrapped(expression, *args, **kwargs): 160 try: 161 return func(expression, *args, **kwargs) 162 except exceptions: 163 return expression 164 165 return wrapped 166 167 return decorator 168 169 170def rewrite_between(expression: exp.Expression) -> exp.Expression: 171 """Rewrite x between y and z to x >= y AND x <= z. 172 173 This is done because comparison simplification is only done on lt/lte/gt/gte. 
174 """ 175 if isinstance(expression, exp.Between): 176 negate = isinstance(expression.parent, exp.Not) 177 178 expression = exp.and_( 179 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 180 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 181 copy=False, 182 ) 183 184 if negate: 185 expression = exp.paren(expression, copy=False) 186 187 return expression 188 189 190COMPLEMENT_COMPARISONS = { 191 exp.LT: exp.GTE, 192 exp.GT: exp.LTE, 193 exp.LTE: exp.GT, 194 exp.GTE: exp.LT, 195 exp.EQ: exp.NEQ, 196 exp.NEQ: exp.EQ, 197} 198 199COMPLEMENT_SUBQUERY_PREDICATES = { 200 exp.All: exp.Any, 201 exp.Any: exp.All, 202} 203 204 205def simplify_not(expression): 206 """ 207 Demorgan's Law 208 NOT (x OR y) -> NOT x AND NOT y 209 NOT (x AND y) -> NOT x OR NOT y 210 """ 211 if isinstance(expression, exp.Not): 212 this = expression.this 213 if is_null(this): 214 return exp.null() 215 if this.__class__ in COMPLEMENT_COMPARISONS: 216 right = this.expression 217 complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) 218 if complement_subquery_predicate: 219 right = complement_subquery_predicate(this=right.this) 220 221 return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 222 if isinstance(this, exp.Paren): 223 condition = this.unnest() 224 if isinstance(condition, exp.And): 225 return exp.paren( 226 exp.or_( 227 exp.not_(condition.left, copy=False), 228 exp.not_(condition.right, copy=False), 229 copy=False, 230 ) 231 ) 232 if isinstance(condition, exp.Or): 233 return exp.paren( 234 exp.and_( 235 exp.not_(condition.left, copy=False), 236 exp.not_(condition.right, copy=False), 237 copy=False, 238 ) 239 ) 240 if is_null(condition): 241 return exp.null() 242 if always_true(this): 243 return exp.false() 244 if is_false(this): 245 return exp.true() 246 if isinstance(this, exp.Not): 247 # double negation 248 # NOT NOT x -> x 249 return this.this 250 return expression 251 252 253def 
flatten(expression): 254 """ 255 A AND (B AND C) -> A AND B AND C 256 A OR (B OR C) -> A OR B OR C 257 """ 258 if isinstance(expression, exp.Connector): 259 for node in expression.args.values(): 260 child = node.unnest() 261 if isinstance(child, expression.__class__): 262 node.replace(child) 263 return expression 264 265 266def simplify_connectors(expression, root=True): 267 def _simplify_connectors(expression, left, right): 268 if isinstance(expression, exp.And): 269 if is_false(left) or is_false(right): 270 return exp.false() 271 if is_zero(left) or is_zero(right): 272 return exp.false() 273 if is_null(left) or is_null(right): 274 return exp.null() 275 if always_true(left) and always_true(right): 276 return exp.true() 277 if always_true(left): 278 return right 279 if always_true(right): 280 return left 281 return _simplify_comparison(expression, left, right) 282 elif isinstance(expression, exp.Or): 283 if always_true(left) or always_true(right): 284 return exp.true() 285 if ( 286 (is_null(left) and is_null(right)) 287 or (is_null(left) and always_false(right)) 288 or (always_false(left) and is_null(right)) 289 ): 290 return exp.null() 291 if is_false(left): 292 return right 293 if is_false(right): 294 return left 295 return _simplify_comparison(expression, left, right, or_=True) 296 elif isinstance(expression, exp.Xor): 297 if left == right: 298 return exp.false() 299 300 if isinstance(expression, exp.Connector): 301 return _flat_simplify(expression, _simplify_connectors, root) 302 return expression 303 304 305LT_LTE = (exp.LT, exp.LTE) 306GT_GTE = (exp.GT, exp.GTE) 307 308COMPARISONS = ( 309 *LT_LTE, 310 *GT_GTE, 311 exp.EQ, 312 exp.NEQ, 313 exp.Is, 314) 315 316INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 317 exp.LT: exp.GT, 318 exp.GT: exp.LT, 319 exp.LTE: exp.GTE, 320 exp.GTE: exp.LTE, 321} 322 323NONDETERMINISTIC = (exp.Rand, exp.Randn) 324AND_OR = (exp.And, exp.Or) 325 326 327def _simplify_comparison(expression, left, 
right, or_=False): 328 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 329 ll, lr = left.args.values() 330 rl, rr = right.args.values() 331 332 largs = {ll, lr} 333 rargs = {rl, rr} 334 335 matching = largs & rargs 336 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 337 338 if matching and columns: 339 try: 340 l = first(largs - columns) 341 r = first(rargs - columns) 342 except StopIteration: 343 return expression 344 345 if l.is_number and r.is_number: 346 l = l.to_py() 347 r = r.to_py() 348 elif l.is_string and r.is_string: 349 l = l.name 350 r = r.name 351 else: 352 l = extract_date(l) 353 if not l: 354 return None 355 r = extract_date(r) 356 if not r: 357 return None 358 # python won't compare date and datetime, but many engines will upcast 359 l, r = cast_as_datetime(l), cast_as_datetime(r) 360 361 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 362 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 363 return left if (av > bv if or_ else av <= bv) else right 364 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 365 return left if (av < bv if or_ else av >= bv) else right 366 367 # we can't ever shortcut to true because the column could be null 368 if not or_: 369 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 370 if av <= bv: 371 return exp.false() 372 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 373 if av >= bv: 374 return exp.false() 375 elif isinstance(a, exp.EQ): 376 if isinstance(b, exp.LT): 377 return exp.false() if av >= bv else a 378 if isinstance(b, exp.LTE): 379 return exp.false() if av > bv else a 380 if isinstance(b, exp.GT): 381 return exp.false() if av <= bv else a 382 if isinstance(b, exp.GTE): 383 return exp.false() if av < bv else a 384 if isinstance(b, exp.NEQ): 385 return exp.false() if av == bv else a 386 return None 387 388 389def remove_complements(expression, root=True): 390 """ 391 Removing complements. 
392 393 A AND NOT A -> FALSE 394 A OR NOT A -> TRUE 395 """ 396 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 397 ops = set(expression.flatten()) 398 for op in ops: 399 if isinstance(op, exp.Not) and op.this in ops: 400 return exp.false() if isinstance(expression, exp.And) else exp.true() 401 402 return expression 403 404 405def uniq_sort(expression, root=True): 406 """ 407 Uniq and sort a connector. 408 409 C AND A AND B AND B -> A AND B AND C 410 """ 411 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 412 flattened = tuple(expression.flatten()) 413 414 if isinstance(expression, exp.Xor): 415 result_func = exp.xor 416 # Do not deduplicate XOR as A XOR A != A if A == True 417 deduped = None 418 arr = tuple((gen(e), e) for e in flattened) 419 else: 420 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 421 deduped = {gen(e): e for e in flattened} 422 arr = tuple(deduped.items()) 423 424 # check if the operands are already sorted, if not sort them 425 # A AND C AND B -> A AND B AND C 426 for i, (sql, e) in enumerate(arr[1:]): 427 if sql < arr[i][0]: 428 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 429 break 430 else: 431 # we didn't have to sort but maybe we need to dedup 432 if deduped and len(deduped) < len(flattened): 433 expression = result_func(*deduped.values(), copy=False) 434 435 return expression 436 437 438def absorb_and_eliminate(expression, root=True): 439 """ 440 absorption: 441 A AND (A OR B) -> A 442 A OR (A AND B) -> A 443 A AND (NOT A OR B) -> A AND B 444 A OR (NOT A AND B) -> A OR B 445 elimination: 446 (A AND B) OR (A AND NOT B) -> A 447 (A OR B) AND (A OR NOT B) -> A 448 """ 449 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 450 kind = exp.Or if isinstance(expression, exp.And) else exp.And 451 452 ops = tuple(expression.flatten()) 453 454 # Initialize lookup tables: 455 # Set of all operands, used to find 
complements for absorption. 456 op_set = set() 457 # Sub-operands, used to find subsets for absorption. 458 subops = defaultdict(list) 459 # Pairs of complements, used for elimination. 460 pairs = defaultdict(list) 461 462 # Populate the lookup tables 463 for op in ops: 464 op_set.add(op) 465 466 if not isinstance(op, kind): 467 # In cases like: A OR (A AND B) 468 # Subop will be: ^ 469 subops[op].append({op}) 470 continue 471 472 # In cases like: (A AND B) OR (A AND B AND C) 473 # Subops will be: ^ ^ 474 subset = set(op.flatten()) 475 for i in subset: 476 subops[i].append(subset) 477 478 a, b = op.unnest_operands() 479 if isinstance(a, exp.Not): 480 pairs[frozenset((a.this, b))].append((op, b)) 481 if isinstance(b, exp.Not): 482 pairs[frozenset((a, b.this))].append((op, a)) 483 484 for op in ops: 485 if not isinstance(op, kind): 486 continue 487 488 a, b = op.unnest_operands() 489 490 # Absorb 491 if isinstance(a, exp.Not) and a.this in op_set: 492 a.replace(exp.true() if kind == exp.And else exp.false()) 493 continue 494 if isinstance(b, exp.Not) and b.this in op_set: 495 b.replace(exp.true() if kind == exp.And else exp.false()) 496 continue 497 superset = set(op.flatten()) 498 if any(any(subset < superset for subset in subops[i]) for i in superset): 499 op.replace(exp.false() if kind == exp.And else exp.true()) 500 continue 501 502 # Eliminate 503 for other, complement in pairs[frozenset((a, b))]: 504 op.replace(complement) 505 other.replace(complement) 506 507 return expression 508 509 510def propagate_constants(expression, root=True): 511 """ 512 Propagate constants for conjunctions in DNF: 513 514 SELECT * FROM t WHERE a = b AND b = 5 becomes 515 SELECT * FROM t WHERE a = 5 AND b = 5 516 517 Reference: https://www.sqlite.org/optoverview.html 518 """ 519 520 if ( 521 isinstance(expression, exp.And) 522 and (root or not expression.same_parent) 523 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 524 ): 525 constant_mapping = {} 526 for expr in 
walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 527 if isinstance(expr, exp.EQ): 528 l, r = expr.left, expr.right 529 530 # TODO: create a helper that can be used to detect nested literal expressions such 531 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 532 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 533 constant_mapping[l] = (id(l), r) 534 535 if constant_mapping: 536 for column in find_all_in_scope(expression, exp.Column): 537 parent = column.parent 538 column_id, constant = constant_mapping.get(column) or (None, None) 539 if ( 540 column_id is not None 541 and id(column) != column_id 542 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 543 ): 544 column.replace(constant.copy()) 545 546 return expression 547 548 549INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 550 exp.DateAdd: exp.Sub, 551 exp.DateSub: exp.Add, 552 exp.DatetimeAdd: exp.Sub, 553 exp.DatetimeSub: exp.Add, 554} 555 556INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 557 **INVERSE_DATE_OPS, 558 exp.Add: exp.Sub, 559 exp.Sub: exp.Add, 560} 561 562 563def _is_number(expression: exp.Expression) -> bool: 564 return expression.is_number 565 566 567def _is_interval(expression: exp.Expression) -> bool: 568 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 569 570 571@catch(ModuleNotFoundError, UnsupportedUnit) 572def simplify_equality(expression: exp.Expression) -> exp.Expression: 573 """ 574 Use the subtraction and addition properties of equality to simplify expressions: 575 576 x + 1 = 3 becomes x = 2 577 578 There are two binary operations in the above expression: + and = 579 Here's how we reference all the operands in the code below: 580 581 l r 582 x + 1 = 3 583 a b 584 """ 585 if isinstance(expression, COMPARISONS): 586 l, r = expression.left, expression.right 587 588 if l.__class__ not in INVERSE_OPS: 
589 return expression 590 591 if r.is_number: 592 a_predicate = _is_number 593 b_predicate = _is_number 594 elif _is_date_literal(r): 595 a_predicate = _is_date_literal 596 b_predicate = _is_interval 597 else: 598 return expression 599 600 if l.__class__ in INVERSE_DATE_OPS: 601 l = t.cast(exp.IntervalOp, l) 602 a = l.this 603 b = l.interval() 604 else: 605 l = t.cast(exp.Binary, l) 606 a, b = l.left, l.right 607 608 if not a_predicate(a) and b_predicate(b): 609 pass 610 elif not a_predicate(b) and b_predicate(a): 611 a, b = b, a 612 else: 613 return expression 614 615 return expression.__class__( 616 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 617 ) 618 return expression 619 620 621def simplify_literals(expression, root=True): 622 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 623 return _flat_simplify(expression, _simplify_binary, root) 624 625 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 626 return expression.this.this 627 628 if type(expression) in INVERSE_DATE_OPS: 629 return _simplify_binary(expression, expression.this, expression.interval()) or expression 630 631 return expression 632 633 634NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 635 636 637def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 638 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 639 this = _simplify_integer_cast(expr.this) 640 else: 641 this = expr.this 642 643 if isinstance(expr, exp.Cast) and this.is_int: 644 num = this.to_py() 645 646 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 647 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 648 # engine-dependent 649 if ( 650 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 651 ) or ( 652 UTINYINT_MIN <= num <= UTINYINT_MAX 653 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 654 ): 655 return this 656 657 return expr 658 659 660def _simplify_binary(expression, a, b): 661 if isinstance(expression, COMPARISONS): 662 a = _simplify_integer_cast(a) 663 b = _simplify_integer_cast(b) 664 665 if isinstance(expression, exp.Is): 666 if isinstance(b, exp.Not): 667 c = b.this 668 not_ = True 669 else: 670 c = b 671 not_ = False 672 673 if is_null(c): 674 if isinstance(a, exp.Literal): 675 return exp.true() if not_ else exp.false() 676 if is_null(a): 677 return exp.false() if not_ else exp.true() 678 elif isinstance(expression, NULL_OK): 679 return None 680 elif is_null(a) or is_null(b): 681 return exp.null() 682 683 if a.is_number and b.is_number: 684 num_a = a.to_py() 685 num_b = b.to_py() 686 687 if isinstance(expression, exp.Add): 688 return exp.Literal.number(num_a + num_b) 689 if isinstance(expression, exp.Mul): 690 return exp.Literal.number(num_a * num_b) 691 692 # We only simplify Sub, Div if a and b have the same parent because they're not associative 693 if isinstance(expression, exp.Sub): 694 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 695 if isinstance(expression, exp.Div): 696 # engines have differing int div behavior so intdiv is not safe 697 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 698 return None 699 return exp.Literal.number(num_a / num_b) 700 701 boolean = eval_boolean(expression, num_a, num_b) 702 703 if boolean: 704 return boolean 705 elif a.is_string and b.is_string: 706 boolean = eval_boolean(expression, a.this, b.this) 707 708 if boolean: 709 return boolean 710 elif _is_date_literal(a) and isinstance(b, 
exp.Interval): 711 date, b = extract_date(a), extract_interval(b) 712 if date and b: 713 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 714 return date_literal(date + b, extract_type(a)) 715 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 716 return date_literal(date - b, extract_type(a)) 717 elif isinstance(a, exp.Interval) and _is_date_literal(b): 718 a, date = extract_interval(a), extract_date(b) 719 # you cannot subtract a date from an interval 720 if a and b and isinstance(expression, exp.Add): 721 return date_literal(a + date, extract_type(b)) 722 elif _is_date_literal(a) and _is_date_literal(b): 723 if isinstance(expression, exp.Predicate): 724 a, b = extract_date(a), extract_date(b) 725 boolean = eval_boolean(expression, a, b) 726 if boolean: 727 return boolean 728 729 return None 730 731 732def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression: 733 if not isinstance(expression, exp.Paren): 734 return expression 735 736 this = expression.this 737 parent = expression.parent 738 parent_is_predicate = isinstance(parent, exp.Predicate) 739 740 if isinstance(this, exp.Select): 741 return expression 742 743 if isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)): 744 return expression 745 746 # Handle risingwave struct columns 747 # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct 748 if ( 749 dialect == "risingwave" 750 and isinstance(parent, exp.Dot) 751 and (isinstance(parent.right, (exp.Identifier, exp.Star))) 752 ): 753 return expression 754 755 if ( 756 not isinstance(parent, (exp.Condition, exp.Binary)) 757 or isinstance(parent, exp.Paren) 758 or ( 759 not isinstance(this, exp.Binary) 760 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 761 ) 762 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 763 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 764 or (isinstance(this, exp.Mul) and 
isinstance(parent, exp.Mul)) 765 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 766 ): 767 return this 768 769 return expression 770 771 772def _is_nonnull_constant(expression: exp.Expression) -> bool: 773 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 774 775 776def _is_constant(expression: exp.Expression) -> bool: 777 return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 778 779 780def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: 781 # COALESCE(x) -> x 782 if ( 783 isinstance(expression, exp.Coalesce) 784 and (not expression.expressions or _is_nonnull_constant(expression.this)) 785 # COALESCE is also used as a Spark partitioning hint 786 and not isinstance(expression.parent, exp.Hint) 787 ): 788 return expression.this 789 790 # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, 791 # because they are not always equivalent. For example, if `x` is `NULL` and it comes 792 # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` 793 if dialect == "redshift": 794 return expression 795 796 if not isinstance(expression, COMPARISONS): 797 return expression 798 799 if isinstance(expression.left, exp.Coalesce): 800 coalesce = expression.left 801 other = expression.right 802 elif isinstance(expression.right, exp.Coalesce): 803 coalesce = expression.right 804 other = expression.left 805 else: 806 return expression 807 808 # This transformation is valid for non-constants, 809 # but it really only does anything if they are both constants. 810 if not _is_constant(other): 811 return expression 812 813 # Find the first constant arg 814 for arg_index, arg in enumerate(coalesce.expressions): 815 if _is_constant(arg): 816 break 817 else: 818 return expression 819 820 coalesce.set("expressions", coalesce.expressions[:arg_index]) 821 822 # Remove the COALESCE function. 
This is an optimization, skipping a simplify iteration, 823 # since we already remove COALESCE at the top of this function. 824 coalesce = coalesce if coalesce.expressions else coalesce.this 825 826 # This expression is more complex than when we started, but it will get simplified further 827 return exp.paren( 828 exp.or_( 829 exp.and_( 830 coalesce.is_(exp.null()).not_(copy=False), 831 expression.copy(), 832 copy=False, 833 ), 834 exp.and_( 835 coalesce.is_(exp.null()), 836 type(expression)(this=arg.copy(), expression=other.copy()), 837 copy=False, 838 ), 839 copy=False, 840 ) 841 ) 842 843 844CONCATS = (exp.Concat, exp.DPipe) 845 846 847def simplify_concat(expression): 848 """Reduces all groups that contain string literals by concatenating them.""" 849 if not isinstance(expression, CONCATS) or ( 850 # We can't reduce a CONCAT_WS call if we don't statically know the separator 851 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 852 ): 853 return expression 854 855 if isinstance(expression, exp.ConcatWs): 856 sep_expr, *expressions = expression.expressions 857 sep = sep_expr.name 858 concat_type = exp.ConcatWs 859 args = {} 860 else: 861 expressions = expression.expressions 862 sep = "" 863 concat_type = exp.Concat 864 args = { 865 "safe": expression.args.get("safe"), 866 "coalesce": expression.args.get("coalesce"), 867 } 868 869 new_args = [] 870 for is_string_group, group in itertools.groupby( 871 expressions or expression.flatten(), lambda e: e.is_string 872 ): 873 if is_string_group: 874 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 875 else: 876 new_args.extend(group) 877 878 if len(new_args) == 1 and new_args[0].is_string: 879 return new_args[0] 880 881 if concat_type is exp.ConcatWs: 882 new_args = [sep_expr] + new_args 883 elif isinstance(expression, exp.DPipe): 884 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 885 886 return concat_type(expressions=new_args, **args) 
887 888 889def simplify_conditionals(expression): 890 """Simplifies expressions like IF, CASE if their condition is statically known.""" 891 if isinstance(expression, exp.Case): 892 this = expression.this 893 for case in expression.args["ifs"]: 894 cond = case.this 895 if this: 896 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 897 cond = cond.replace(this.pop().eq(cond)) 898 899 if always_true(cond): 900 return case.args["true"] 901 902 if always_false(cond): 903 case.pop() 904 if not expression.args["ifs"]: 905 return expression.args.get("default") or exp.null() 906 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 907 if always_true(expression.this): 908 return expression.args["true"] 909 if always_false(expression.this): 910 return expression.args.get("false") or exp.null() 911 912 return expression 913 914 915def simplify_startswith(expression: exp.Expression) -> exp.Expression: 916 """ 917 Reduces a prefix check to either TRUE or FALSE if both the string and the 918 prefix are statically known. 
919 920 Example: 921 >>> from sqlglot import parse_one 922 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 923 'TRUE' 924 """ 925 if ( 926 isinstance(expression, exp.StartsWith) 927 and expression.this.is_string 928 and expression.expression.is_string 929 ): 930 return exp.convert(expression.name.startswith(expression.expression.name)) 931 932 return expression 933 934 935DateRange = t.Tuple[datetime.date, datetime.date] 936 937 938def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 939 """ 940 Get the date range for a DATE_TRUNC equality comparison: 941 942 Example: 943 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 944 Returns: 945 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 946 """ 947 floor = date_floor(date, unit, dialect) 948 949 if date != floor: 950 # This will always be False, except for NULL values. 951 return None 952 953 return floor, floor + interval(unit) 954 955 956def _datetrunc_eq_expression( 957 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 958) -> exp.Expression: 959 """Get the logical expression for a date range""" 960 return exp.and_( 961 left >= date_literal(drange[0], target_type), 962 left < date_literal(drange[1], target_type), 963 copy=False, 964 ) 965 966 967def _datetrunc_eq( 968 left: exp.Expression, 969 date: datetime.date, 970 unit: str, 971 dialect: Dialect, 972 target_type: t.Optional[exp.DataType], 973) -> t.Optional[exp.Expression]: 974 drange = _datetrunc_range(date, unit, dialect) 975 if not drange: 976 return None 977 978 return _datetrunc_eq_expression(left, drange, target_type) 979 980 981def _datetrunc_neq( 982 left: exp.Expression, 983 date: datetime.date, 984 unit: str, 985 dialect: Dialect, 986 target_type: t.Optional[exp.DataType], 987) -> t.Optional[exp.Expression]: 988 drange = _datetrunc_range(date, unit, dialect) 989 if not drange: 990 return None 991 
992 return exp.and_( 993 left < date_literal(drange[0], target_type), 994 left >= date_literal(drange[1], target_type), 995 copy=False, 996 ) 997 998 999DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 1000 exp.LT: lambda l, dt, u, d, t: l 1001 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 1002 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 1003 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 1004 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 1005 exp.EQ: _datetrunc_eq, 1006 exp.NEQ: _datetrunc_neq, 1007} 1008DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 1009DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 1010 1011 1012def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 1013 return isinstance(left, DATETRUNCS) and _is_date_literal(right) 1014 1015 1016@catch(ModuleNotFoundError, UnsupportedUnit) 1017def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1018 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1019 comparison = expression.__class__ 1020 1021 if isinstance(expression, DATETRUNCS): 1022 this = expression.this 1023 trunc_type = extract_type(this) 1024 date = extract_date(this) 1025 if date and expression.unit: 1026 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1027 elif comparison not in DATETRUNC_COMPARISONS: 1028 return expression 1029 1030 if isinstance(expression, exp.Binary): 1031 l, r = expression.left, expression.right 1032 1033 if not _is_datetrunc_predicate(l, r): 1034 return expression 1035 1036 l = t.cast(exp.DateTrunc, l) 1037 trunc_arg = l.this 1038 unit = l.unit.name.lower() 1039 date = extract_date(r) 1040 1041 if not date: 1042 return expression 1043 1044 return ( 1045 
DATETRUNC_BINARY_COMPARISONS[comparison]( 1046 trunc_arg, date, unit, dialect, extract_type(r) 1047 ) 1048 or expression 1049 ) 1050 1051 if isinstance(expression, exp.In): 1052 l = expression.this 1053 rs = expression.expressions 1054 1055 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1056 l = t.cast(exp.DateTrunc, l) 1057 unit = l.unit.name.lower() 1058 1059 ranges = [] 1060 for r in rs: 1061 date = extract_date(r) 1062 if not date: 1063 return expression 1064 drange = _datetrunc_range(date, unit, dialect) 1065 if drange: 1066 ranges.append(drange) 1067 1068 if not ranges: 1069 return expression 1070 1071 ranges = merge_ranges(ranges) 1072 target_type = extract_type(*rs) 1073 1074 return exp.or_( 1075 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1076 ) 1077 1078 return expression 1079 1080 1081def sort_comparison(expression: exp.Expression) -> exp.Expression: 1082 if expression.__class__ in COMPLEMENT_COMPARISONS: 1083 l, r = expression.this, expression.expression 1084 l_column = isinstance(l, exp.Column) 1085 r_column = isinstance(r, exp.Column) 1086 l_const = _is_constant(l) 1087 r_const = _is_constant(r) 1088 1089 if ( 1090 (l_column and not r_column) 1091 or (r_const and not l_const) 1092 or isinstance(r, exp.SubqueryPredicate) 1093 ): 1094 return expression 1095 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1096 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1097 this=r, expression=l 1098 ) 1099 return expression 1100 1101 1102# CROSS joins result in an empty table if the right table is empty. 1103# So we can only simplify certain types of joins to CROSS. 
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# Each entry is a (side, kind) pair that may be rewritten as a CROSS join.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible always-true joins into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a TRUE boolean or for a literal that is not zero."""
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not is_zero(expression)
    )


def always_false(expression):
    """Return True for a FALSE boolean, a NULL, or a zero literal."""
    return is_false(expression) or is_null(expression) or is_zero(expression)


def is_zero(expression):
    """Return True if `expression` is a literal whose Python value equals 0."""
    return isinstance(expression, exp.Literal) and expression.to_py() == 0


def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Boolean node holding a falsy value."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a Null node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns a boolean literal expression, or None when `expression` is not one
    of the handled comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a datetime, date, or ISO-format string to a date; None if the string is invalid."""
    # datetime must be tested first because it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a datetime, date, or ISO-format string to a datetime; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        # A plain date becomes midnight of that day.
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """Convert `value` to a date or datetime according to the target type `to`.

    Returns None when `value` is falsy or `to` is not a temporal type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """Extract the Python date/datetime from a (possibly nested) cast of a literal.

    Handles `Cast` nodes and `TsOrDsToDate` nodes that have no format argument;
    anything else yields None.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: resolve the inner one first, then apply the outer type.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if `extract_date` can obtain a date from `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an interval expression into the object returned by `interval`.

    Returns None if the unit is unsupported, dateutil is not installed, or the
    interval amount cannot be converted to an int.
    """
    try:
        n = int(expression.this.to_py())
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def extract_type(*expressions):
    """Return the first truthy type found among `expressions`.

    For a `Cast` the target type (`to`) is used, otherwise the node's `type`.
    Returns None if no expression carries a type.
    """
    target_type = None
    for expression in expressions:
        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
        if target_type:
            break

    return target_type


def date_literal(date, target_type=None):
    """Build a cast of the string form of `date` to `target_type`.

    When `target_type` is missing or not temporal, DATETIME is used for
    datetime values and DATE otherwise.
    """
    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        target_type = (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        )

    return exp.cast(exp.Literal.string(date), target_type)


def interval(unit: str, n: int = 1):
    """Return a `relativedelta` of `n` units.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled unit names.
        ModuleNotFoundError: if dateutil is not installed.
    """
    # Imported lazily so that the module does not require dateutil at import time.
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of its year, quarter, month, week or day.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next unit boundary; a value already on a boundary is returned as is."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return a TRUE expression if `condition` is truthy, otherwise a FALSE expression."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of a chain of binary operators.

    `simplifier(expression, a, b)` is called for pairs of operands; a truthy
    result that is not `expression` itself replaces both operands. A new chain
    of `expression`'s class is built only if the number of operands shrank.
    Nothing is done for a non-root node whose `same_parent` is truthy.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # `a` and `b` collapsed into `result`; put it back at the
                    # front so it is tried against the remaining operands.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # `a` could not be combined with any remaining operand.
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any, comments: bool = False) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    return Gen().gen(expression, comments=comments)


class Gen:
    """Minimal stack-based pseudo-SQL generator used by `gen`.

    Items are popped from the end of `stack`, so handlers push the pieces of a
    node in reverse output order.
    """

    def __init__(self):
        # Pending work: expressions, lists and plain string fragments.
        self.stack = []
        # Output fragments, joined at the end of `gen`.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Return the pseudo-SQL string for `expression`, optionally including comments."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, so it is emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the node's args.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    # Drop the separator on top of the stack so the list does not
                    # start with a comma.
                    # NOTE(review): for a non-empty list whose items are all None this
                    # pops an unrelated entry -- confirm such lists cannot occur.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(args)`; a quoted identifier name keeps its case, others are upper-cased."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the dot that would otherwise precede the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Skip the first arg type; the type name is emitted explicitly below.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        # Skip the first arg type; `e.this` is emitted explicitly below.
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Skip the first two arg types; `this` and `alias` are emitted explicitly below.
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Skip the first four arg types -- presumably the name parts and alias,
        # which are emitted explicitly below; verify against exp.Table.arg_types.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the dot that would otherwise precede the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression` (in reverse, so `this` is emitted first)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this` (in reverse, so `op` is emitted first)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(arg, ...)` using the function's SQL name and all of its arg values."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the non-None args of `node` as `:key value` pairs.

        Args:
            node: the expression whose args are emitted.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one arg was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types or arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
Common base class for all non-exit exceptions.
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        coalesce_simplification: whether the simplify coalesce rule should be used.
            This rule tries to remove coalesce functions, which can be useful in certain analyses but
            can leave the query more verbose.
        dialect: dialect forwarded to the dialect-aware rules (coalesce, parens and
            date-truncation simplification).
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression):
        # If the root itself is marked FINAL, neither loop below assigns `new_node`;
        # start from the input so that it is returned unchanged instead of raising
        # UnboundLocalError.
        new_node = expression

        pre_transformation_stack = [expression]
        post_transformation_stack = []

        # Pre-order pass: rewrite a node first, then schedule its children.
        while pre_transformation_stack:
            node = pre_transformation_stack.pop()

            if node.meta.get(FINAL):
                continue

            # group by expressions cannot be simplified, for example
            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
            # the projection must exactly match the group by key
            group = node.args.get("group")

            if group and hasattr(node, "selects"):
                groups = set(group.expressions)
                group.meta[FINAL] = True

                for s in node.selects:
                    for n in s.walk():
                        if n in groups:
                            s.meta[FINAL] = True
                            break

                having = node.args.get("having")
                if having:
                    for n in having.walk():
                        if n in groups:
                            having.meta[FINAL] = True
                            break

            parent = node.parent
            root = node is expression

            new_node = rewrite_between(node)
            new_node = uniq_sort(new_node, root)
            new_node = absorb_and_eliminate(new_node, root)
            new_node = simplify_concat(new_node)
            new_node = simplify_conditionals(new_node)

            if constant_propagation:
                new_node = propagate_constants(new_node, root)

            if new_node is not node:
                node.replace(new_node)

            pre_transformation_stack.extend(
                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
            )
            post_transformation_stack.append((new_node, parent))

        # Post-order pass: nodes are popped in reverse visiting order, so the root
        # is handled last and is the value returned below.
        while post_transformation_stack:
            node, parent = post_transformation_stack.pop()
            root = node is expression

            # Resets parent, arg_key, index pointers– this is needed because some of the
            # previous transformations mutate the AST, leading to an inconsistent state
            for k, v in tuple(node.args.items()):
                node.set(k, v)

            # Post-order transformations
            new_node = simplify_not(node)
            new_node = flatten(new_node)
            new_node = simplify_connectors(new_node, root)
            new_node = remove_complements(new_node, root)

            if coalesce_simplification:
                new_node = simplify_coalesce(new_node, dialect)

            # Some of the rules below inspect the parent of the node they receive.
            new_node.parent = parent

            new_node = simplify_literals(new_node, root)
            new_node = simplify_equality(new_node)
            new_node = simplify_parens(new_node, dialect)
            new_node = simplify_datetrunc(new_node, dialect)
            new_node = sort_comparison(new_node)
            new_node = simplify_startswith(new_node)

            if new_node is not node:
                node.replace(new_node)

        return new_node

    # Repeat until a pass no longer changes the expression.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function behaves like the original, except that
        when one of `exceptions` is raised its first positional argument (the
        expression) is returned unchanged.
    """

    def decorator(func):
        # functools.wraps keeps the decorated function's __name__, __doc__ and
        # __wrapped__, so introspection and generated docs show the real function
        # rather than the generic `wrapped` closure.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded first. When the BETWEEN sits directly under a NOT, the resulting
    conjunction is wrapped in parentheses. Any other expression is returned
    unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    operand = expression.this

    lower_bound = exp.GTE(this=operand.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=operand.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws are applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition, NOT NULL becomes NULL, a negated comparison becomes its
    complement, negated constants are folded and a double negation is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        rhs = operand.expression
        flipped_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            rhs = flipped_predicate(this=rhs.this)

        complement = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement(this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # The connector flips: AND becomes OR and vice versa.
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Remove redundant nesting of identical connectors:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Parentheses are dropped only around a connector of the same type.
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold constant operands of AND / OR / XOR chains.

    Pairs of operands are reduced through `_flat_simplify`; anything that is
    not a connector is returned unchanged.
    """

    def _reduce_pair(connector, left, right):
        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true:
                return exp.true() if right_true else right
            if right_true:
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or always_false(right))) or (
                always_false(left) and right_null
            ):
                return exp.null()

            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        if isinstance(connector, exp.Xor) and left == right:
            return exp.false()

        return None

    if not isinstance(expression, exp.Connector):
        return expression

    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if not has_complement:
        return expression

    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Sort the operands of a connector and drop duplicates.

    C AND A AND B AND B -> A AND B AND C

    XOR operands are sorted but never deduplicated.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still have to go.
    if unique and len(unique) < len(operands):
        return build(*unique.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place inside `expression`, which is then
    returned. Nothing is done for a non-root node whose `same_parent` is truthy.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands
        # that can be absorbed or eliminated.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            # Index the operand under the pair it would form with its negated
            # side un-negated, remembering the side that is kept.
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            if isinstance(a, exp.Not) and a.this in op_set:
                # The negated side is replaced by the neutral element of `kind`.
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            superset = set(op.flatten())
            # `<` is a strict-subset test: another operand made up of fewer of
            # the same sub-operands makes `op` redundant.
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns are replaced in place and `expression` is returned. Only AND
    expressions that are normalized (DNF) are processed.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in the defining equality, literal).
        constant_mapping = {}
        # The walk is pruned at IF nodes, so equalities inside them are ignored.
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # Do not replace the column inside the equality that defines it.
                    and id(column) != column_id
                    # `col IS NULL` is left untouched.
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the wrapped simplification; if one of the configured exceptions is
    # raised, hand the expression back untouched.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    else:
        return result
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """Fold literal operands.

    Non-connector binary expressions are reduced pairwise, a double negation
    (`- - x`) is unwrapped, and expressions whose type is listed in
    `INVERSE_DATE_OPS` are folded with their interval. Anything else is
    returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg):
        return expression.this.this

    if type(expression) not in INVERSE_DATE_OPS:
        return expression

    folded = _simplify_binary(expression, expression.this, expression.interval())
    return folded or expression
def simplify_parens(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Remove parentheses that are not needed.

    Returns the wrapped expression when the parentheses can be dropped and the
    Paren node itself when they have to stay. Non-Paren input is returned
    unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent
    outer_is_predicate = isinstance(outer, exp.Predicate)

    if isinstance(inner, exp.Select):
        return expression

    if isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    # Handle risingwave struct columns
    # see https://docs.risingwave.com/sql/data-types/struct#retrieve-data-in-a-struct
    if (
        dialect == "risingwave"
        and isinstance(outer, exp.Dot)
        and isinstance(outer.right, (exp.Identifier, exp.Star))
    ):
        return expression

    # Each of the following conditions on its own is enough to drop the parentheses.
    if not isinstance(outer, (exp.Condition, exp.Binary)) or isinstance(outer, exp.Paren):
        return inner

    if not isinstance(inner, exp.Binary) and not (
        isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate
    ):
        return inner

    if isinstance(inner, exp.Predicate) and not outer_is_predicate:
        return inner

    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner

    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Remove or rewrite COALESCE calls.

    A COALESCE with a single argument, or whose first argument is a non-null
    constant, is replaced by that argument. A comparison between a COALESCE and
    a constant is expanded into an explicit NULL check (skipped for redshift).

    Args:
        expression: the expression to simplify.
        dialect: compared against "redshift" to disable the comparison rewrite.

    Returns:
        The simplified expression, or `expression` itself when no rule applies.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example, if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments before the first constant. This mutates the
    # COALESCE inside `expression`, so the `expression.copy()` below already
    # contains the shortened call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge adjacent string literals inside CONCAT, CONCAT_WS and `||` expressions."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *operands = expression.expressions
        separator = sep_expr.name
        extra_args = {}
    else:
        operands = expression.expressions
        separator = ""
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    grouped = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in grouped:
        if not all_strings:
            merged.extend(run)
            continue
        # A run of consecutive string literals becomes a single literal.
        joined = separator.join(literal.name for literal in run)
        merged.append(exp.Literal.string(joined))

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        return exp.ConcatWs(expressions=[sep_expr] + merged, **extra_args)

    if isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return exp.Concat(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed. A branch whose
    condition is always true replaces the whole CASE only if it is the first remaining
    branch, because an earlier branch with an undecided condition may still match first.
    If every branch is removed, the default (or NULL) is returned.

    For a standalone IF (one that is not a branch of a CASE), the matching arm is
    returned when the condition is always true or always false.

    Args:
        expression: the expression to simplify; CASE nodes may be mutated in place

    Returns:
        The simplified expression, or the input expression if nothing could be decided.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` below removes branches from the live list,
        # which would otherwise make the loop skip the branch following a removed one
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                # Only the leading branch can stand in for the whole CASE; a later one is
                # reached only when all the undecided branches before it do not match
                if expression.args["ifs"][0] is case:
                    return case.args["true"]
            elif always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Folds a STARTSWITH call into a boolean when both of its operands are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not expression.this.is_string:
        return expression

    prefix = expression.expression
    if not prefix.is_string:
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def wrapped(expression, *args, **kwargs):
    # Calls `func` with the given arguments and returns its result. If the call raises one
    # of `exceptions`, the input expression is returned unchanged instead.
    # NOTE(review): `func` and `exceptions` are free variables here, presumably bound by an
    # enclosing decorator that is not visible in this excerpt -- confirm
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Puts the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right. When neither rule decides,
    the operands are ordered by their `gen` string. Swapping the operands also swaps the
    operator for the one registered in INVERSE_COMPARISONS, if any.
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return swapped_type(this=rhs, expression=lhs)
def remove_where_true(expression):
    """
    Drops WHERE clauses whose condition is always true and turns eligible joins whose
    ON condition is always true into CROSS joins. The expression is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        # Joins with a USING clause or a join method are left alone
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join_args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluates the comparison that `expression` represents on the Python values `a` and `b`.

    Returns the boolean literal for the result, or None if `expression` is not one of the
    supported comparison types.
    """
    # Comparisons are wrapped in lambdas so that only the matching one is ever evaluated
    comparators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for comparison_types, compare in comparators:
        if isinstance(expression, comparison_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Converts a value into a date.

    Args:
        value: a datetime (its date part is used), a date (returned as is), or an
            ISO-8601 string.

    Returns:
        The resulting date, or None if the value cannot be interpreted as one.
    """
    # datetime must be checked first because it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: not a string at all (e.g. None); ValueError: not ISO formatted
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Converts a value into a datetime.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of that day),
            or an ISO-8601 string.

    Returns:
        The resulting datetime, or None if the value cannot be interpreted as one.
    """
    # datetime must be checked first because it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: not a string at all (e.g. None); ValueError: not ISO formatted
        return None
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Converts `value` into a Python date or datetime according to the target type `to`.

    Returns None when the value is falsy, when `to` is not a temporal type, or when the
    conversion itself yields None.
    """
    if not value:
        return None

    converter = None
    if to.is_type(exp.DataType.Type.DATE):
        converter = cast_as_date
    elif to.is_type(*exp.DataType.TEMPORAL_TYPES):
        converter = cast_as_datetime

    return converter(value) if converter else None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Evaluates a (possibly nested) cast of a literal into a Python date or datetime.

    Supports CAST and format-less TS_OR_DS_TO_DATE nodes wrapping either a literal or
    another such node. Returns None for anything else.
    """
    if isinstance(cast, exp.Cast):
        target_type = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        target_type = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        raw: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first, then convert its result
        raw = extract_date(operand)
    else:
        return None

    return cast_value(raw, target_type)
def date_literal(date, target_type=None):
    """
    Builds a CAST of the string form of `date` to a temporal type.

    If `target_type` is missing or is not a temporal type, DATETIME is used for datetime
    values and DATE for everything else.
    """
    if not (target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)):
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """
    Returns a relativedelta spanning `n` of the given unit.

    Raises:
        UnsupportedUnit: if the unit is not one of year, quarter, month, week, day, hour,
            minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, number of that keyword per unit)
    supported = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in supported:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Rounds a date down to the start of the given unit.

    Args:
        d: the date to floor
        unit: one of "year", "quarter", "month", "week" or "day"
        dialect: provides WEEK_OFFSET, the first day of the week relative to Monday
            (0 floors to Monday, -1 floors to Sunday)

    Returns:
        The first day of the unit that contains `d`.

    Raises:
        UnsupportedUnit: if the unit is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday. The modulo keeps the distance to
        # the start of the week within 0..6, so a date that already is the first day of
        # the week maps to itself and the result is never later than `d`
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any, comments: bool = False) -> str:
    """Lightweight pseudo SQL generator that produces strings suitable for sorting and deduping.

    The optimizer needs to sort and dedupe SQL, and the real generator is too expensive
    to call for that, so only a bare minimum of SQL generation is implemented here.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
Simple pseudo SQL generator for quickly generating sortable and unique strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
Arguments:
- expression: the expression to convert into a SQL string.
- comments: whether to include the expression's comments.
class Gen:
    """
    Minimal, stack-based pseudo SQL generator.

    `stack` holds the items that still have to be emitted and is consumed from the end,
    so every handler pushes its pieces in the reverse of the order in which they should
    appear in the output. `sqls` collects the emitted string fragments.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """
        Converts an expression into a pseudo SQL string.

        Args:
            expression: the expression to convert.
            comments: whether to include the comments attached to the expressions.

        Returns:
            The generated string.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                # Pushed before the node's own pieces, so it is emitted after them
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the node's args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the comma that would otherwise lead the list. Nothing is popped when
                # every element was None, since the top of the stack then belongs to the
                # enclosing expression
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        # Unquoted names are upper-cased, quoted identifiers keep their case
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the dot that would otherwise lead the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # The first declared arg type is skipped; `this` is rendered through its name instead
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Output order: this, " IN ", "(", remaining args, ")"
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): the mixed-case " Like " differs from the upper-case " ILIKE "; it is
        # kept as is because changing it would change the generated strings
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # The first two declared arg types are skipped by _args; `this` and the alias are
        # pushed explicitly below
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # The first four declared arg types are skipped by _args; the name parts and the
        # alias are pushed explicitly below
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the dot that would otherwise lead the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Output order: this, op, expression
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        # Output order: op, this
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Output order: NAME(arg values separated by commas)
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Pushes the node's non-None args as a list of [":key", value] pairs.

        Args:
            node: the expression whose args are pushed.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one arg was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """
    Converts an expression into a pseudo SQL string.

    `self.stack` is consumed from the end, so the pieces of every node are pushed in the
    reverse of the order in which they should appear in the output.

    Args:
        expression: the expression to convert.
        comments: whether to include the comments attached to the expressions.

    Returns:
        The generated string.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            # Pushed before the node's own pieces, so it is emitted after them
            if comments and node.comments:
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                # Generic fallback: the upper-cased key followed by the node's args
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the comma that would otherwise lead the list. Nothing is popped when
            # every element was None, since the top of the stack then belongs to the
            # enclosing expression
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Pushes the pieces of an anonymous function call: NAME(expressions).

    Plain string names and unquoted identifiers are upper-cased, quoted identifiers keep
    their case and are wrapped in double quotes.

    Raises:
        ValueError: if the function name is neither a str nor an Identifier.
    """
    target = e.this

    if isinstance(target, str):
        rendered_name = target.upper()
    elif isinstance(target, exp.Identifier):
        raw_name = target.this
        if target.quoted:
            rendered_name = f'"{raw_name}"'
        else:
            rendered_name = raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )

    # The stack is consumed from the end, so the pieces go in reverse output order
    self.stack.extend([")", e.expressions, "(", rendered_name])