sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 coalesce_simplification: bool = False, 43 dialect: DialectType = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 coalesce_simplification: whether the simplify coalesce rule should be used. 58 This rule tries to remove coalesce functions, which can be useful in certain analyses but 59 can leave the query more verbose. 
60 Returns: 61 sqlglot.Expression: simplified expression 62 """ 63 64 dialect = Dialect.get_or_raise(dialect) 65 66 def _simplify(expression): 67 pre_transformation_stack = [expression] 68 post_transformation_stack = [] 69 70 while pre_transformation_stack: 71 node = pre_transformation_stack.pop() 72 73 if node.meta.get(FINAL): 74 continue 75 76 # group by expressions cannot be simplified, for example 77 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 78 # the projection must exactly match the group by key 79 group = node.args.get("group") 80 81 if group and hasattr(node, "selects"): 82 groups = set(group.expressions) 83 group.meta[FINAL] = True 84 85 for s in node.selects: 86 for n in s.walk(): 87 if n in groups: 88 s.meta[FINAL] = True 89 break 90 91 having = node.args.get("having") 92 if having: 93 for n in having.walk(): 94 if n in groups: 95 having.meta[FINAL] = True 96 break 97 98 parent = node.parent 99 root = node is expression 100 101 new_node = rewrite_between(node) 102 new_node = uniq_sort(new_node, root) 103 new_node = absorb_and_eliminate(new_node, root) 104 new_node = simplify_concat(new_node) 105 new_node = simplify_conditionals(new_node) 106 107 if constant_propagation: 108 new_node = propagate_constants(new_node, root) 109 110 if new_node is not node: 111 node.replace(new_node) 112 113 pre_transformation_stack.extend( 114 n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 115 ) 116 post_transformation_stack.append((new_node, parent)) 117 118 while post_transformation_stack: 119 node, parent = post_transformation_stack.pop() 120 root = node is expression 121 122 # Resets parent, arg_key, index pointers– this is needed because some of the 123 # previous transformations mutate the AST, leading to an inconsistent state 124 for k, v in tuple(node.args.items()): 125 node.set(k, v) 126 127 # Post-order transformations 128 new_node = simplify_not(node) 129 new_node = flatten(new_node) 130 new_node = simplify_connectors(new_node, root) 
131 new_node = remove_complements(new_node, root) 132 133 if coalesce_simplification: 134 new_node = simplify_coalesce(new_node, dialect) 135 136 new_node.parent = parent 137 138 new_node = simplify_literals(new_node, root) 139 new_node = simplify_equality(new_node) 140 new_node = simplify_parens(new_node) 141 new_node = simplify_datetrunc(new_node, dialect) 142 new_node = sort_comparison(new_node) 143 new_node = simplify_startswith(new_node) 144 145 if new_node is not node: 146 node.replace(new_node) 147 148 return new_node 149 150 expression = while_changing(expression, _simplify) 151 remove_where_true(expression) 152 return expression 153 154 155def catch(*exceptions): 156 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 157 158 def decorator(func): 159 def wrapped(expression, *args, **kwargs): 160 try: 161 return func(expression, *args, **kwargs) 162 except exceptions: 163 return expression 164 165 return wrapped 166 167 return decorator 168 169 170def rewrite_between(expression: exp.Expression) -> exp.Expression: 171 """Rewrite x between y and z to x >= y AND x <= z. 172 173 This is done because comparison simplification is only done on lt/lte/gt/gte. 
174 """ 175 if isinstance(expression, exp.Between): 176 negate = isinstance(expression.parent, exp.Not) 177 178 expression = exp.and_( 179 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 180 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 181 copy=False, 182 ) 183 184 if negate: 185 expression = exp.paren(expression, copy=False) 186 187 return expression 188 189 190COMPLEMENT_COMPARISONS = { 191 exp.LT: exp.GTE, 192 exp.GT: exp.LTE, 193 exp.LTE: exp.GT, 194 exp.GTE: exp.LT, 195 exp.EQ: exp.NEQ, 196 exp.NEQ: exp.EQ, 197} 198 199COMPLEMENT_SUBQUERY_PREDICATES = { 200 exp.All: exp.Any, 201 exp.Any: exp.All, 202} 203 204 205def simplify_not(expression): 206 """ 207 Demorgan's Law 208 NOT (x OR y) -> NOT x AND NOT y 209 NOT (x AND y) -> NOT x OR NOT y 210 """ 211 if isinstance(expression, exp.Not): 212 this = expression.this 213 if is_null(this): 214 return exp.null() 215 if this.__class__ in COMPLEMENT_COMPARISONS: 216 right = this.expression 217 complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) 218 if complement_subquery_predicate: 219 right = complement_subquery_predicate(this=right.this) 220 221 return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 222 if isinstance(this, exp.Paren): 223 condition = this.unnest() 224 if isinstance(condition, exp.And): 225 return exp.paren( 226 exp.or_( 227 exp.not_(condition.left, copy=False), 228 exp.not_(condition.right, copy=False), 229 copy=False, 230 ) 231 ) 232 if isinstance(condition, exp.Or): 233 return exp.paren( 234 exp.and_( 235 exp.not_(condition.left, copy=False), 236 exp.not_(condition.right, copy=False), 237 copy=False, 238 ) 239 ) 240 if is_null(condition): 241 return exp.null() 242 if always_true(this): 243 return exp.false() 244 if is_false(this): 245 return exp.true() 246 if isinstance(this, exp.Not): 247 # double negation 248 # NOT NOT x -> x 249 return this.this 250 return expression 251 252 253def 
flatten(expression): 254 """ 255 A AND (B AND C) -> A AND B AND C 256 A OR (B OR C) -> A OR B OR C 257 """ 258 if isinstance(expression, exp.Connector): 259 for node in expression.args.values(): 260 child = node.unnest() 261 if isinstance(child, expression.__class__): 262 node.replace(child) 263 return expression 264 265 266def simplify_connectors(expression, root=True): 267 def _simplify_connectors(expression, left, right): 268 if isinstance(expression, exp.And): 269 if is_false(left) or is_false(right): 270 return exp.false() 271 if is_zero(left) or is_zero(right): 272 return exp.false() 273 if is_null(left) or is_null(right): 274 return exp.null() 275 if always_true(left) and always_true(right): 276 return exp.true() 277 if always_true(left): 278 return right 279 if always_true(right): 280 return left 281 return _simplify_comparison(expression, left, right) 282 elif isinstance(expression, exp.Or): 283 if always_true(left) or always_true(right): 284 return exp.true() 285 if ( 286 (is_null(left) and is_null(right)) 287 or (is_null(left) and always_false(right)) 288 or (always_false(left) and is_null(right)) 289 ): 290 return exp.null() 291 if is_false(left): 292 return right 293 if is_false(right): 294 return left 295 return _simplify_comparison(expression, left, right, or_=True) 296 elif isinstance(expression, exp.Xor): 297 if left == right: 298 return exp.false() 299 300 if isinstance(expression, exp.Connector): 301 return _flat_simplify(expression, _simplify_connectors, root) 302 return expression 303 304 305LT_LTE = (exp.LT, exp.LTE) 306GT_GTE = (exp.GT, exp.GTE) 307 308COMPARISONS = ( 309 *LT_LTE, 310 *GT_GTE, 311 exp.EQ, 312 exp.NEQ, 313 exp.Is, 314) 315 316INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 317 exp.LT: exp.GT, 318 exp.GT: exp.LT, 319 exp.LTE: exp.GTE, 320 exp.GTE: exp.LTE, 321} 322 323NONDETERMINISTIC = (exp.Rand, exp.Randn) 324AND_OR = (exp.And, exp.Or) 325 326 327def _simplify_comparison(expression, left, 
right, or_=False): 328 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 329 ll, lr = left.args.values() 330 rl, rr = right.args.values() 331 332 largs = {ll, lr} 333 rargs = {rl, rr} 334 335 matching = largs & rargs 336 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 337 338 if matching and columns: 339 try: 340 l = first(largs - columns) 341 r = first(rargs - columns) 342 except StopIteration: 343 return expression 344 345 if l.is_number and r.is_number: 346 l = l.to_py() 347 r = r.to_py() 348 elif l.is_string and r.is_string: 349 l = l.name 350 r = r.name 351 else: 352 l = extract_date(l) 353 if not l: 354 return None 355 r = extract_date(r) 356 if not r: 357 return None 358 # python won't compare date and datetime, but many engines will upcast 359 l, r = cast_as_datetime(l), cast_as_datetime(r) 360 361 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 362 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 363 return left if (av > bv if or_ else av <= bv) else right 364 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 365 return left if (av < bv if or_ else av >= bv) else right 366 367 # we can't ever shortcut to true because the column could be null 368 if not or_: 369 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 370 if av <= bv: 371 return exp.false() 372 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 373 if av >= bv: 374 return exp.false() 375 elif isinstance(a, exp.EQ): 376 if isinstance(b, exp.LT): 377 return exp.false() if av >= bv else a 378 if isinstance(b, exp.LTE): 379 return exp.false() if av > bv else a 380 if isinstance(b, exp.GT): 381 return exp.false() if av <= bv else a 382 if isinstance(b, exp.GTE): 383 return exp.false() if av < bv else a 384 if isinstance(b, exp.NEQ): 385 return exp.false() if av == bv else a 386 return None 387 388 389def remove_complements(expression, root=True): 390 """ 391 Removing complements. 
392 393 A AND NOT A -> FALSE 394 A OR NOT A -> TRUE 395 """ 396 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 397 ops = set(expression.flatten()) 398 for op in ops: 399 if isinstance(op, exp.Not) and op.this in ops: 400 return exp.false() if isinstance(expression, exp.And) else exp.true() 401 402 return expression 403 404 405def uniq_sort(expression, root=True): 406 """ 407 Uniq and sort a connector. 408 409 C AND A AND B AND B -> A AND B AND C 410 """ 411 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 412 flattened = tuple(expression.flatten()) 413 414 if isinstance(expression, exp.Xor): 415 result_func = exp.xor 416 # Do not deduplicate XOR as A XOR A != A if A == True 417 deduped = None 418 arr = tuple((gen(e), e) for e in flattened) 419 else: 420 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 421 deduped = {gen(e): e for e in flattened} 422 arr = tuple(deduped.items()) 423 424 # check if the operands are already sorted, if not sort them 425 # A AND C AND B -> A AND B AND C 426 for i, (sql, e) in enumerate(arr[1:]): 427 if sql < arr[i][0]: 428 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 429 break 430 else: 431 # we didn't have to sort but maybe we need to dedup 432 if deduped and len(deduped) < len(flattened): 433 expression = result_func(*deduped.values(), copy=False) 434 435 return expression 436 437 438def absorb_and_eliminate(expression, root=True): 439 """ 440 absorption: 441 A AND (A OR B) -> A 442 A OR (A AND B) -> A 443 A AND (NOT A OR B) -> A AND B 444 A OR (NOT A AND B) -> A OR B 445 elimination: 446 (A AND B) OR (A AND NOT B) -> A 447 (A OR B) AND (A OR NOT B) -> A 448 """ 449 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 450 kind = exp.Or if isinstance(expression, exp.And) else exp.And 451 452 ops = tuple(expression.flatten()) 453 454 # Initialize lookup tables: 455 # Set of all operands, used to find 
complements for absorption. 456 op_set = set() 457 # Sub-operands, used to find subsets for absorption. 458 subops = defaultdict(list) 459 # Pairs of complements, used for elimination. 460 pairs = defaultdict(list) 461 462 # Populate the lookup tables 463 for op in ops: 464 op_set.add(op) 465 466 if not isinstance(op, kind): 467 # In cases like: A OR (A AND B) 468 # Subop will be: ^ 469 subops[op].append({op}) 470 continue 471 472 # In cases like: (A AND B) OR (A AND B AND C) 473 # Subops will be: ^ ^ 474 subset = set(op.flatten()) 475 for i in subset: 476 subops[i].append(subset) 477 478 a, b = op.unnest_operands() 479 if isinstance(a, exp.Not): 480 pairs[frozenset((a.this, b))].append((op, b)) 481 if isinstance(b, exp.Not): 482 pairs[frozenset((a, b.this))].append((op, a)) 483 484 for op in ops: 485 if not isinstance(op, kind): 486 continue 487 488 a, b = op.unnest_operands() 489 490 # Absorb 491 if isinstance(a, exp.Not) and a.this in op_set: 492 a.replace(exp.true() if kind == exp.And else exp.false()) 493 continue 494 if isinstance(b, exp.Not) and b.this in op_set: 495 b.replace(exp.true() if kind == exp.And else exp.false()) 496 continue 497 superset = set(op.flatten()) 498 if any(any(subset < superset for subset in subops[i]) for i in superset): 499 op.replace(exp.false() if kind == exp.And else exp.true()) 500 continue 501 502 # Eliminate 503 for other, complement in pairs[frozenset((a, b))]: 504 op.replace(complement) 505 other.replace(complement) 506 507 return expression 508 509 510def propagate_constants(expression, root=True): 511 """ 512 Propagate constants for conjunctions in DNF: 513 514 SELECT * FROM t WHERE a = b AND b = 5 becomes 515 SELECT * FROM t WHERE a = 5 AND b = 5 516 517 Reference: https://www.sqlite.org/optoverview.html 518 """ 519 520 if ( 521 isinstance(expression, exp.And) 522 and (root or not expression.same_parent) 523 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 524 ): 525 constant_mapping = {} 526 for expr in 
walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 527 if isinstance(expr, exp.EQ): 528 l, r = expr.left, expr.right 529 530 # TODO: create a helper that can be used to detect nested literal expressions such 531 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 532 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 533 constant_mapping[l] = (id(l), r) 534 535 if constant_mapping: 536 for column in find_all_in_scope(expression, exp.Column): 537 parent = column.parent 538 column_id, constant = constant_mapping.get(column) or (None, None) 539 if ( 540 column_id is not None 541 and id(column) != column_id 542 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 543 ): 544 column.replace(constant.copy()) 545 546 return expression 547 548 549INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 550 exp.DateAdd: exp.Sub, 551 exp.DateSub: exp.Add, 552 exp.DatetimeAdd: exp.Sub, 553 exp.DatetimeSub: exp.Add, 554} 555 556INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 557 **INVERSE_DATE_OPS, 558 exp.Add: exp.Sub, 559 exp.Sub: exp.Add, 560} 561 562 563def _is_number(expression: exp.Expression) -> bool: 564 return expression.is_number 565 566 567def _is_interval(expression: exp.Expression) -> bool: 568 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 569 570 571@catch(ModuleNotFoundError, UnsupportedUnit) 572def simplify_equality(expression: exp.Expression) -> exp.Expression: 573 """ 574 Use the subtraction and addition properties of equality to simplify expressions: 575 576 x + 1 = 3 becomes x = 2 577 578 There are two binary operations in the above expression: + and = 579 Here's how we reference all the operands in the code below: 580 581 l r 582 x + 1 = 3 583 a b 584 """ 585 if isinstance(expression, COMPARISONS): 586 l, r = expression.left, expression.right 587 588 if l.__class__ not in INVERSE_OPS: 
589 return expression 590 591 if r.is_number: 592 a_predicate = _is_number 593 b_predicate = _is_number 594 elif _is_date_literal(r): 595 a_predicate = _is_date_literal 596 b_predicate = _is_interval 597 else: 598 return expression 599 600 if l.__class__ in INVERSE_DATE_OPS: 601 l = t.cast(exp.IntervalOp, l) 602 a = l.this 603 b = l.interval() 604 else: 605 l = t.cast(exp.Binary, l) 606 a, b = l.left, l.right 607 608 if not a_predicate(a) and b_predicate(b): 609 pass 610 elif not a_predicate(b) and b_predicate(a): 611 a, b = b, a 612 else: 613 return expression 614 615 return expression.__class__( 616 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 617 ) 618 return expression 619 620 621def simplify_literals(expression, root=True): 622 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 623 return _flat_simplify(expression, _simplify_binary, root) 624 625 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 626 return expression.this.this 627 628 if type(expression) in INVERSE_DATE_OPS: 629 return _simplify_binary(expression, expression.this, expression.interval()) or expression 630 631 return expression 632 633 634NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 635 636 637def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 638 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 639 this = _simplify_integer_cast(expr.this) 640 else: 641 this = expr.this 642 643 if isinstance(expr, exp.Cast) and this.is_int: 644 num = this.to_py() 645 646 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 647 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 648 # engine-dependent 649 if ( 650 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 651 ) or ( 652 UTINYINT_MIN <= num <= UTINYINT_MAX 653 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 654 ): 655 return this 656 657 return expr 658 659 660def _simplify_binary(expression, a, b): 661 if isinstance(expression, COMPARISONS): 662 a = _simplify_integer_cast(a) 663 b = _simplify_integer_cast(b) 664 665 if isinstance(expression, exp.Is): 666 if isinstance(b, exp.Not): 667 c = b.this 668 not_ = True 669 else: 670 c = b 671 not_ = False 672 673 if is_null(c): 674 if isinstance(a, exp.Literal): 675 return exp.true() if not_ else exp.false() 676 if is_null(a): 677 return exp.false() if not_ else exp.true() 678 elif isinstance(expression, NULL_OK): 679 return None 680 elif is_null(a) or is_null(b): 681 return exp.null() 682 683 if a.is_number and b.is_number: 684 num_a = a.to_py() 685 num_b = b.to_py() 686 687 if isinstance(expression, exp.Add): 688 return exp.Literal.number(num_a + num_b) 689 if isinstance(expression, exp.Mul): 690 return exp.Literal.number(num_a * num_b) 691 692 # We only simplify Sub, Div if a and b have the same parent because they're not associative 693 if isinstance(expression, exp.Sub): 694 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 695 if isinstance(expression, exp.Div): 696 # engines have differing int div behavior so intdiv is not safe 697 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 698 return None 699 return exp.Literal.number(num_a / num_b) 700 701 boolean = eval_boolean(expression, num_a, num_b) 702 703 if boolean: 704 return boolean 705 elif a.is_string and b.is_string: 706 boolean = eval_boolean(expression, a.this, b.this) 707 708 if boolean: 709 return boolean 710 elif _is_date_literal(a) and isinstance(b, 
exp.Interval): 711 date, b = extract_date(a), extract_interval(b) 712 if date and b: 713 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 714 return date_literal(date + b, extract_type(a)) 715 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 716 return date_literal(date - b, extract_type(a)) 717 elif isinstance(a, exp.Interval) and _is_date_literal(b): 718 a, date = extract_interval(a), extract_date(b) 719 # you cannot subtract a date from an interval 720 if a and b and isinstance(expression, exp.Add): 721 return date_literal(a + date, extract_type(b)) 722 elif _is_date_literal(a) and _is_date_literal(b): 723 if isinstance(expression, exp.Predicate): 724 a, b = extract_date(a), extract_date(b) 725 boolean = eval_boolean(expression, a, b) 726 if boolean: 727 return boolean 728 729 return None 730 731 732def simplify_parens(expression): 733 if not isinstance(expression, exp.Paren): 734 return expression 735 736 this = expression.this 737 parent = expression.parent 738 parent_is_predicate = isinstance(parent, exp.Predicate) 739 740 if ( 741 not isinstance(this, exp.Select) 742 and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)) 743 and ( 744 not isinstance(parent, (exp.Condition, exp.Binary)) 745 or isinstance(parent, exp.Paren) 746 or ( 747 not isinstance(this, exp.Binary) 748 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 749 ) 750 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 751 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 752 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 753 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 754 ) 755 ): 756 return this 757 return expression 758 759 760def _is_nonnull_constant(expression: exp.Expression) -> bool: 761 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 762 763 764def _is_constant(expression: exp.Expression) -> bool: 765 return 
isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 766 767 768def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: 769 # COALESCE(x) -> x 770 if ( 771 isinstance(expression, exp.Coalesce) 772 and (not expression.expressions or _is_nonnull_constant(expression.this)) 773 # COALESCE is also used as a Spark partitioning hint 774 and not isinstance(expression.parent, exp.Hint) 775 ): 776 return expression.this 777 778 # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, 779 # because they are not always equivalent. For example, if `x` is `NULL` and it comes 780 # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` 781 if dialect == "redshift": 782 return expression 783 784 if not isinstance(expression, COMPARISONS): 785 return expression 786 787 if isinstance(expression.left, exp.Coalesce): 788 coalesce = expression.left 789 other = expression.right 790 elif isinstance(expression.right, exp.Coalesce): 791 coalesce = expression.right 792 other = expression.left 793 else: 794 return expression 795 796 # This transformation is valid for non-constants, 797 # but it really only does anything if they are both constants. 798 if not _is_constant(other): 799 return expression 800 801 # Find the first constant arg 802 for arg_index, arg in enumerate(coalesce.expressions): 803 if _is_constant(arg): 804 break 805 else: 806 return expression 807 808 coalesce.set("expressions", coalesce.expressions[:arg_index]) 809 810 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 811 # since we already remove COALESCE at the top of this function. 
812 coalesce = coalesce if coalesce.expressions else coalesce.this 813 814 # This expression is more complex than when we started, but it will get simplified further 815 return exp.paren( 816 exp.or_( 817 exp.and_( 818 coalesce.is_(exp.null()).not_(copy=False), 819 expression.copy(), 820 copy=False, 821 ), 822 exp.and_( 823 coalesce.is_(exp.null()), 824 type(expression)(this=arg.copy(), expression=other.copy()), 825 copy=False, 826 ), 827 copy=False, 828 ) 829 ) 830 831 832CONCATS = (exp.Concat, exp.DPipe) 833 834 835def simplify_concat(expression): 836 """Reduces all groups that contain string literals by concatenating them.""" 837 if not isinstance(expression, CONCATS) or ( 838 # We can't reduce a CONCAT_WS call if we don't statically know the separator 839 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 840 ): 841 return expression 842 843 if isinstance(expression, exp.ConcatWs): 844 sep_expr, *expressions = expression.expressions 845 sep = sep_expr.name 846 concat_type = exp.ConcatWs 847 args = {} 848 else: 849 expressions = expression.expressions 850 sep = "" 851 concat_type = exp.Concat 852 args = { 853 "safe": expression.args.get("safe"), 854 "coalesce": expression.args.get("coalesce"), 855 } 856 857 new_args = [] 858 for is_string_group, group in itertools.groupby( 859 expressions or expression.flatten(), lambda e: e.is_string 860 ): 861 if is_string_group: 862 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 863 else: 864 new_args.extend(group) 865 866 if len(new_args) == 1 and new_args[0].is_string: 867 return new_args[0] 868 869 if concat_type is exp.ConcatWs: 870 new_args = [sep_expr] + new_args 871 elif isinstance(expression, exp.DPipe): 872 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 873 874 return concat_type(expressions=new_args, **args) 875 876 877def simplify_conditionals(expression): 878 """Simplifies expressions like IF, CASE if their condition is statically 
known.""" 879 if isinstance(expression, exp.Case): 880 this = expression.this 881 for case in expression.args["ifs"]: 882 cond = case.this 883 if this: 884 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 885 cond = cond.replace(this.pop().eq(cond)) 886 887 if always_true(cond): 888 return case.args["true"] 889 890 if always_false(cond): 891 case.pop() 892 if not expression.args["ifs"]: 893 return expression.args.get("default") or exp.null() 894 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 895 if always_true(expression.this): 896 return expression.args["true"] 897 if always_false(expression.this): 898 return expression.args.get("false") or exp.null() 899 900 return expression 901 902 903def simplify_startswith(expression: exp.Expression) -> exp.Expression: 904 """ 905 Reduces a prefix check to either TRUE or FALSE if both the string and the 906 prefix are statically known. 907 908 Example: 909 >>> from sqlglot import parse_one 910 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 911 'TRUE' 912 """ 913 if ( 914 isinstance(expression, exp.StartsWith) 915 and expression.this.is_string 916 and expression.expression.is_string 917 ): 918 return exp.convert(expression.name.startswith(expression.expression.name)) 919 920 return expression 921 922 923DateRange = t.Tuple[datetime.date, datetime.date] 924 925 926def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 927 """ 928 Get the date range for a DATE_TRUNC equality comparison: 929 930 Example: 931 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 932 Returns: 933 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 934 """ 935 floor = date_floor(date, unit, dialect) 936 937 if date != floor: 938 # This will always be False, except for NULL values. 
939 return None 940 941 return floor, floor + interval(unit) 942 943 944def _datetrunc_eq_expression( 945 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 946) -> exp.Expression: 947 """Get the logical expression for a date range""" 948 return exp.and_( 949 left >= date_literal(drange[0], target_type), 950 left < date_literal(drange[1], target_type), 951 copy=False, 952 ) 953 954 955def _datetrunc_eq( 956 left: exp.Expression, 957 date: datetime.date, 958 unit: str, 959 dialect: Dialect, 960 target_type: t.Optional[exp.DataType], 961) -> t.Optional[exp.Expression]: 962 drange = _datetrunc_range(date, unit, dialect) 963 if not drange: 964 return None 965 966 return _datetrunc_eq_expression(left, drange, target_type) 967 968 969def _datetrunc_neq( 970 left: exp.Expression, 971 date: datetime.date, 972 unit: str, 973 dialect: Dialect, 974 target_type: t.Optional[exp.DataType], 975) -> t.Optional[exp.Expression]: 976 drange = _datetrunc_range(date, unit, dialect) 977 if not drange: 978 return None 979 980 return exp.and_( 981 left < date_literal(drange[0], target_type), 982 left >= date_literal(drange[1], target_type), 983 copy=False, 984 ) 985 986 987DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 988 exp.LT: lambda l, dt, u, d, t: l 989 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 990 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 991 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 992 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 993 exp.EQ: _datetrunc_eq, 994 exp.NEQ: _datetrunc_neq, 995} 996DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 997DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 998 999 1000def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 1001 return isinstance(left, 
DATETRUNCS) and _is_date_literal(right) 1002 1003 1004@catch(ModuleNotFoundError, UnsupportedUnit) 1005def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1006 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1007 comparison = expression.__class__ 1008 1009 if isinstance(expression, DATETRUNCS): 1010 this = expression.this 1011 trunc_type = extract_type(this) 1012 date = extract_date(this) 1013 if date and expression.unit: 1014 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1015 elif comparison not in DATETRUNC_COMPARISONS: 1016 return expression 1017 1018 if isinstance(expression, exp.Binary): 1019 l, r = expression.left, expression.right 1020 1021 if not _is_datetrunc_predicate(l, r): 1022 return expression 1023 1024 l = t.cast(exp.DateTrunc, l) 1025 trunc_arg = l.this 1026 unit = l.unit.name.lower() 1027 date = extract_date(r) 1028 1029 if not date: 1030 return expression 1031 1032 return ( 1033 DATETRUNC_BINARY_COMPARISONS[comparison]( 1034 trunc_arg, date, unit, dialect, extract_type(r) 1035 ) 1036 or expression 1037 ) 1038 1039 if isinstance(expression, exp.In): 1040 l = expression.this 1041 rs = expression.expressions 1042 1043 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1044 l = t.cast(exp.DateTrunc, l) 1045 unit = l.unit.name.lower() 1046 1047 ranges = [] 1048 for r in rs: 1049 date = extract_date(r) 1050 if not date: 1051 return expression 1052 drange = _datetrunc_range(date, unit, dialect) 1053 if drange: 1054 ranges.append(drange) 1055 1056 if not ranges: 1057 return expression 1058 1059 ranges = merge_ranges(ranges) 1060 target_type = extract_type(*rs) 1061 1062 return exp.or_( 1063 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1064 ) 1065 1066 return expression 1067 1068 1069def sort_comparison(expression: exp.Expression) -> exp.Expression: 1070 if expression.__class__ in 
COMPLEMENT_COMPARISONS: 1071 l, r = expression.this, expression.expression 1072 l_column = isinstance(l, exp.Column) 1073 r_column = isinstance(r, exp.Column) 1074 l_const = _is_constant(l) 1075 r_const = _is_constant(r) 1076 1077 if ( 1078 (l_column and not r_column) 1079 or (r_const and not l_const) 1080 or isinstance(r, exp.SubqueryPredicate) 1081 ): 1082 return expression 1083 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1084 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1085 this=r, expression=l 1086 ) 1087 return expression 1088 1089 1090# CROSS joins result in an empty table if the right table is empty. 1091# So we can only simplify certain types of joins to CROSS. 1092# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1093JOINS = { 1094 ("", ""), 1095 ("", "INNER"), 1096 ("RIGHT", ""), 1097 ("RIGHT", "OUTER"), 1098} 1099 1100 1101def remove_where_true(expression): 1102 for where in expression.find_all(exp.Where): 1103 if always_true(where.this): 1104 where.pop() 1105 for join in expression.find_all(exp.Join): 1106 if ( 1107 always_true(join.args.get("on")) 1108 and not join.args.get("using") 1109 and not join.args.get("method") 1110 and (join.side, join.kind) in JOINS 1111 ): 1112 join.args["on"].pop() 1113 join.set("side", None) 1114 join.set("kind", "CROSS") 1115 1116 1117def always_true(expression): 1118 return (isinstance(expression, exp.Boolean) and expression.this) or ( 1119 isinstance(expression, exp.Literal) and not is_zero(expression) 1120 ) 1121 1122 1123def always_false(expression): 1124 return is_false(expression) or is_null(expression) or is_zero(expression) 1125 1126 1127def is_zero(expression): 1128 return isinstance(expression, exp.Literal) and expression.to_py() == 0 1129 1130 1131def is_complement(a, b): 1132 return isinstance(b, exp.Not) and b.this == a 1133 1134 1135def is_false(a: exp.Expression) -> bool: 1136 return type(a) is exp.Boolean and not a.this 1137 1138 
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly a NULL node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal node, or None if the comparison type is not handled.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, datetime or ISO-format string to a date; None if the string is invalid."""
    # datetime must be checked first because it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a date, datetime or ISO-format string to a datetime; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to a date or datetime according to the target type `to`.

    Returns None for falsy values and for non-temporal target types.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the Python date/datetime from a (possibly nested) cast of a literal.

    Handles `CAST(... AS <type>)` and format-less `TS_OR_DS_TO_DATE(...)`; returns None
    for anything else.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if a date can be extracted from `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an interval expression into a relativedelta, or None if that is not possible."""
    try:
        n = int(expression.this.to_py())
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def extract_type(*expressions):
    """Return the first available type among `expressions` (cast target or annotated type)."""
    target_type = None
    for expression in expressions:
        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
        if target_type:
            break

    return target_type


def date_literal(date, target_type=None):
    """Build `CAST('<date>' AS <type>)`, defaulting to DATE/DATETIME for non-temporal targets."""
    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        target_type = (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        )

    return exp.cast(exp.Literal.string(date), target_type)


def interval(unit: str, n: int = 1):
    """Return a relativedelta of `n` units.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled units.
    """
    # Imported lazily: dateutil is only needed when date arithmetic is actually performed
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of its year, quarter, month, week or day.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday; WEEK_OFFSET shifts the first day of
        # the week. The modulo keeps the step within 0..6 days, so a date that already
        # is the first day of its week floors to itself for any offset.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next unit boundary; dates already on a boundary are unchanged."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return a TRUE or FALSE node depending on the truthiness of `condition`."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Apply a pairwise `simplifier` across the flattened operands of `expression`.

    Each operand is tried against the remaining ones; when the simplifier returns a new
    node, the pair is replaced by it and processing restarts from that result. The
    expression is rebuilt only if at least one pair was merged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # No partner could be merged with `a`, so it is kept as-is
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any, comments: bool = False) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    return Gen().gen(expression, comments=comments)


class Gen:
    """Stack-based pseudo SQL generator used by `gen`.

    Handlers push the pieces of a node onto `self.stack` in reverse order, because the
    stack is consumed from the end; plain strings popped from it are appended to `self.sqls`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Return the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                # Prefer a dedicated `<key>_sql` handler, then the generic function form,
                # then a fallback made of the upper-cased key followed by the node's args
                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                if node:
                    # Drop the separator pushed after the last element
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the dot pushed after the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Drop the dot pushed after the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of `node` as `:key value` pairs, skipping the first `arg_index` arg types.

        Returns True if at least one arg was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
Common base class for all non-exit exceptions.
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    coalesce_simplification: bool = False,
    dialect: DialectType = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        coalesce_simplification: whether the simplify coalesce rule should be used.
            This rule tries to remove coalesce functions, which can be useful in certain analyses but
            can leave the query more verbose.
        dialect: the dialect passed to the dialect-dependent rules (coalesce and date
            truncation simplification).
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression):
        pre_transformation_stack = [expression]
        post_transformation_stack = []

        # If the root itself is marked FINAL nothing is visited below, so the
        # result must default to the untouched input instead of being unbound.
        new_node = expression

        while pre_transformation_stack:
            node = pre_transformation_stack.pop()

            if node.meta.get(FINAL):
                continue

            # group by expressions cannot be simplified, for example
            # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
            # the projection must exactly match the group by key
            group = node.args.get("group")

            if group and hasattr(node, "selects"):
                groups = set(group.expressions)
                group.meta[FINAL] = True

                for s in node.selects:
                    for n in s.walk():
                        if n in groups:
                            s.meta[FINAL] = True
                            break

                having = node.args.get("having")
                if having:
                    for n in having.walk():
                        if n in groups:
                            having.meta[FINAL] = True
                            break

            parent = node.parent
            root = node is expression

            # Pre-order transformations
            new_node = rewrite_between(node)
            new_node = uniq_sort(new_node, root)
            new_node = absorb_and_eliminate(new_node, root)
            new_node = simplify_concat(new_node)
            new_node = simplify_conditionals(new_node)

            if constant_propagation:
                new_node = propagate_constants(new_node, root)

            if new_node is not node:
                node.replace(new_node)

            pre_transformation_stack.extend(
                n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL)
            )
            post_transformation_stack.append((new_node, parent))

        while post_transformation_stack:
            node, parent = post_transformation_stack.pop()
            root = node is expression

            # Resets parent, arg_key, index pointers– this is needed because some of the
            # previous transformations mutate the AST, leading to an inconsistent state
            for k, v in tuple(node.args.items()):
                node.set(k, v)

            # Post-order transformations
            new_node = simplify_not(node)
            new_node = flatten(new_node)
            new_node = simplify_connectors(new_node, root)
            new_node = remove_complements(new_node, root)

            if coalesce_simplification:
                new_node = simplify_coalesce(new_node, dialect)

            new_node.parent = parent

            new_node = simplify_literals(new_node, root)
            new_node = simplify_equality(new_node)
            new_node = simplify_parens(new_node)
            new_node = simplify_datetrunc(new_node, dialect)
            new_node = sort_comparison(new_node)
            new_node = simplify_startswith(new_node)

            if new_node is not node:
                node.replace(new_node)

        return new_node

    # Repeat the whole pass until the expression stops changing
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- coalesce_simplification: whether the simplify coalesce rule should be used. This rule tries to remove coalesce functions, which can be useful in certain analyses but can leave the query more verbose.
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the wrapped function raises one of `exceptions`, its first argument
    (the expression) is returned unchanged. Other exceptions propagate.
    """

    def decorator(func):
        # Preserve the wrapped function's name and docstring
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    lowered to those operators. If the BETWEEN sits directly under a NOT, the
    conjunction is wrapped in parentheses.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if under_not else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify negations, including De Morgan's laws:

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        # NOT over a comparison becomes the complementary comparison
        rhs = operand.expression
        flipped_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flipped_predicate:
            rhs = flipped_predicate(this=rhs.this)

        return COMPLEMENT_COMPARISONS[operand_type](this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # AND turns into OR and vice versa, with both operands negated
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Unwrap nested connectors of the same kind:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold AND / OR / XOR connectors whose operands are statically known."""

    def _simplify_connectors(expression, left, right):
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right) or is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true:
                return exp.true() if right_true else right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or always_false(right))) or (
                always_false(left) and right_null
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        if isinstance(expression, exp.Xor) and left == right:
            return exp.false()

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if not has_complement:
        return expression

    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        build = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        build = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # Operands are out of order as soon as one key sorts before its predecessor
    # A AND C AND B -> A AND B AND C
    out_of_order = any(nxt[0] < cur[0] for cur, nxt in zip(keyed, keyed[1:]))

    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still need to be dropped
    if unique and len(unique) < len(operands):
        return build(*unique.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND / OR connector.

    The operands are rewritten in place; the connector node itself is returned.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # The connector kind of the nested operands is the opposite of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # Subop will be: ^
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # Subops will be: ^     ^
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated sub-operand whose positive form is a sibling operand is replaced by
            # the neutral element of the nested connector
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # An operand that strictly contains another operand's sub-operands is redundant
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Substitute columns with the literal they are equated to, for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node that defines the constant, literal)
    known = {}
    for node in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            known[lhs] = (id(lhs), rhs)

    if not known:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        owner = column.parent
        entry = known.get(column)
        if not entry:
            continue

        defining_id, literal = entry
        # Leave the defining `col = literal` occurrence alone
        if defining_id is None or id(column) == defining_id:
            continue
        # Leave `col IS NULL` checks alone
        if isinstance(owner, exp.Is) and isinstance(owner.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
160 def wrapped(expression, *args, **kwargs): 161 try: 162 return func(expression, *args, **kwargs) 163 except exceptions: 164 return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
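In `x + 1 = 3`, `l` is the left-hand side of `=` (`x + 1`) and `r` is its right-hand side (`3`); `a` and `b` are the left and right operands of `+` (`x` and `1`).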
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary, double-negation and date operations."""
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    # -(-x) -> x
    if isinstance(expression, exp.Neg):
        negated = expression.this
        if isinstance(negated, exp.Neg):
            return negated.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove a pair of parentheses when the checks below consider it redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    # Parenthesized selects and parens under subquery predicates / brackets are kept
    if isinstance(inner, exp.Select) or isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )

    return inner if removable else expression
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Remove or expand COALESCE calls.

    A COALESCE with no fallback arguments, or whose first argument is a non-null constant,
    is replaced by its first argument. A comparison between a COALESCE and a constant is
    expanded into explicit NULL checks, except for the redshift dialect.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example, if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift":
        return expression

    if not isinstance(expression, COMPARISONS):
        return expression

    # The COALESCE may be on either side of the comparison
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the fallback arguments that precede the first constant one
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Consecutive string literals are merged into one literal; other operands are kept as-is.
    # When there is no `expressions` list, the flattened operands are used instead.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Everything collapsed into a single string literal
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild the chain of `||` operators from the reduced operands
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    Args:
        expression: the node to inspect. Only `exp.Case` nodes and `exp.If` nodes
            that are not direct children of a `exp.Case` are rewritten.

    Returns:
        - the result of the first branch whose condition is always true,
        - the default (or `NULL`) once every branch has been removed as always false,
        - the false branch (or `NULL`) of an always-false stand-alone IF,
        - otherwise `expression` itself, possibly with dead CASE branches removed in place.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot of the branches: `case.pop()` below detaches a
        # branch from the CASE, and editing the live list while looping over it
        # would skip the branch that follows every removed one.
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Folds a STARTSWITH call into a boolean when both of its operands are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    subject = expression.name
    prefix = expression.expression.name
    return exp.convert(subject.startswith(prefix))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
def wrapped(expression, *args, **kwargs):
    """Call `func` on `expression`, returning `expression` unchanged if it raises.

    NOTE(review): `func` and `exceptions` are free variables here; presumably they
    are bound by an enclosing decorator that is not visible in this excerpt.

    Args:
        expression: the node handed to `func`; also the fallback return value.
        *args: extra positional arguments forwarded to `func`.
        **kwargs: extra keyword arguments forwarded to `func`.

    Returns:
        Whatever `func` returns, or `expression` itself when one of `exceptions` is raised.
    """
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # Only the listed exception types are swallowed; anything else propagates.
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are preferred on the left and constants on the right. When neither
    rule decides, the operands are ordered by their `gen` string. Swapping the
    operands also swaps the operator for its inverse (per INVERSE_COMPARISONS).
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    already_ordered = (
        (lhs_is_column and not rhs_is_column)
        or (rhs_is_const and not lhs_is_const)
        or isinstance(rhs, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not needs_swap:
        return expression

    comparison = expression.__class__
    flipped = INVERSE_COMPARISONS.get(comparison, comparison)
    return flipped(this=rhs, expression=lhs)
def remove_where_true(expression):
    """Strip WHERE clauses and join ON conditions that are always true.

    The tree is modified in place and nothing is returned. A join whose ON
    condition is dropped is turned into a CROSS join, provided it has no USING
    clause, no join method, and its (side, kind) pair is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns the result wrapped by `boolean_literal`, or None when `expression`
    is not one of the supported comparison node types.
    """
    # Checked in order, mirroring an if-chain over the node types
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, if possible.

    Args:
        value: a `datetime.datetime`, a `datetime.date`, or an ISO-8601 string.

    Returns:
        The date part of `value`, or None when `value` cannot be interpreted as a date.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a `datetime.datetime`, a `datetime.date`, or an ISO-8601 string.

    Returns:
        `value` as a datetime (a plain date becomes midnight of that day), or
        None when `value` cannot be interpreted as a datetime.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: not a string at all
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python temporal type matching the data type `to`.

    Args:
        value: the value to convert; any falsy value yields None.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is
        any other temporal type, and None for falsy values or other types.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a (possibly nested) cast of a literal to a date or datetime.

    Args:
        cast: a `exp.Cast`, or a `exp.TsOrDsToDate` without a `format` argument,
            whose operand is a literal or another such cast.

    Returns:
        The Python date/datetime the cast evaluates to, or None when `cast` has
        a different shape or the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without a format argument, TsOrDsToDate is handled like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are resolved inside-out
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """Build a CAST of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type. Otherwise the type is
    DATETIME for `datetime.datetime` values and DATE for everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        cast_to = target_type
    elif isinstance(date, datetime.datetime):
        cast_to = exp.DataType.Type.DATETIME
    else:
        cast_to = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), cast_to)
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` of the given `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of year, quarter, month, week,
            day, hour, minute or second.
    """
    from dateutil.relativedelta import relativedelta

    # (unit, relativedelta keyword, number of keyword steps per unit)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )

    for name, keyword, factor in spans:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`. In `date.weekday()` numbering
            (Monday = 0), weeks are taken to start on weekday `WEEK_OFFSET % 7`.

    Returns:
        The first day of the enclosing year/quarter/month/week, or `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Number of days elapsed since the most recent week start. The modulo keeps
        # it within 0..6, so a date that already is a week start maps to itself and
        # the result never lies after `d`, whatever the sign of WEEK_OFFSET.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any, comments: bool = False) -> str:
    """Render `expression` as a cheap pseudo-SQL string suitable for sorting and deduplication.

    The optimizer needs to sort and dedupe SQL. Running the real generator for
    that is expensive, so a bare-minimum generator is used instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.
    """
    generator = Gen()
    return generator.gen(expression, comments=comments)
Simple pseudo SQL generator for quickly generating sortable and unique strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
Arguments:
- expression: the expression to convert into a SQL string.
- comments: whether to include the expression's comments.
class Gen:
    """Bare-minimum, stack-driven pseudo-SQL generator.

    Work items live on a LIFO stack (`self.stack`): expressions are expanded by
    their `<key>_sql` handler (or a generic fallback), lists are expanded into
    comma-separated items, and everything else is stringified into `self.sqls`.
    Since the stack is LIFO, handlers push the pieces of a node in the REVERSE
    of the order in which they should appear in the output.
    """

    def __init__(self):
        # Pending work items: expressions, lists and plain text fragments
        self.stack = []
        # Output fragments, in emission order
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """Generate the pseudo-SQL string for `expression`.

        Args:
            expression: the root node to render.
            comments: whether to emit each node's comments (as ` /*...*/`) after the node.

        Returns:
            The concatenation of all emitted fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's own pieces, hence emitted after them
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: upper-cased key followed by the set args, if any.
                    # `_args` pushes first, so the key pushed afterwards is emitted first.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the "," pushed after the first emitted item. Only pop when
                # something was pushed: a non-empty list holding nothing but None
                # would otherwise discard an unrelated pending fragment.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emitted as: <this> AS <alias>
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(args)`; the name is upper-cased unless it is a quoted identifier.

        Raises:
            ValueError: if `e.this` is neither a str nor an Identifier.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Emitted as: <this> BETWEEN <low> AND <high>
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Emitted as: <this>[<expressions>]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Parts joined by "."; the final pop removes the "." pushed after the first part
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Type name first, then the set args other than the first declared one
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emitted as: <this> IN (<set args other than the first declared one>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE(review): mixed-case " Like " differs from the upper-case " ILIKE " above.
        # Left untouched because the generated strings are used as sort/dedup keys.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Emitted as: (<this>) <alias> <set args from the third declared one on>
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Emitted as: <dotted parts> <alias> <set args from the fifth declared one on>
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Removes the "." pushed after the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Emitted as: " AS " <this> followed by "(<columns>)" when columns are present
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Emit `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Emit `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Emit `<sql name>(<arg values>)`, passing every arg value through the list branch."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the `:key,value` pairs of the args that are set on `node`.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one pair was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
1346 def gen(self, expression: exp.Expression, comments: bool = False) -> str: 1347 self.stack = [expression] 1348 self.sqls.clear() 1349 1350 while self.stack: 1351 node = self.stack.pop() 1352 1353 if isinstance(node, exp.Expression): 1354 if comments and node.comments: 1355 self.stack.append(f" /*{','.join(node.comments)}*/") 1356 1357 exp_handler_name = f"{node.key}_sql" 1358 1359 if hasattr(self, exp_handler_name): 1360 getattr(self, exp_handler_name)(node) 1361 elif isinstance(node, exp.Func): 1362 self._function(node) 1363 else: 1364 key = node.key.upper() 1365 self.stack.append(f"{key} " if self._args(node) else key) 1366 elif type(node) is list: 1367 for n in reversed(node): 1368 if n is not None: 1369 self.stack.extend((n, ",")) 1370 if node: 1371 self.stack.pop() 1372 else: 1373 if node is not None: 1374 self.sqls.append(str(node)) 1375 1376 return "".join(self.sqls)
1393 def anonymous_sql(self, e: exp.Anonymous) -> None: 1394 this = e.this 1395 if isinstance(this, str): 1396 name = this.upper() 1397 elif isinstance(this, exp.Identifier): 1398 name = this.this 1399 name = f'"{name}"' if this.quoted else name.upper() 1400 else: 1401 raise ValueError( 1402 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1403 ) 1404 1405 self.stack.extend( 1406 ( 1407 ")", 1408 e.expressions, 1409 "(", 1410 name, 1411 ) 1412 )