sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 dialect: DialectType = None, 43 max_depth: t.Optional[int] = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 
47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped 58 Returns: 59 sqlglot.Expression: simplified expression 60 """ 61 62 dialect = Dialect.get_or_raise(dialect) 63 64 def _simplify(expression, root=True): 65 if ( 66 max_depth 67 and isinstance(expression, exp.Connector) 68 and not isinstance(expression.parent, exp.Connector) 69 ): 70 depth = connector_depth(expression) 71 if depth > max_depth: 72 logger.info( 73 f"Skipping simplification because connector depth {depth} exceeds max {max_depth}" 74 ) 75 return expression 76 77 if expression.meta.get(FINAL): 78 return expression 79 80 # group by expressions cannot be simplified, for example 81 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 82 # the projection must exactly match the group by key 83 group = expression.args.get("group") 84 85 if group and hasattr(expression, "selects"): 86 groups = set(group.expressions) 87 group.meta[FINAL] = True 88 89 for e in expression.selects: 90 for node in e.walk(): 91 if node in groups: 92 e.meta[FINAL] = True 93 break 94 95 having = expression.args.get("having") 96 if having: 97 for node in having.walk(): 98 if node in groups: 99 having.meta[FINAL] = True 100 break 101 102 # Pre-order transformations 103 node = expression 104 node = rewrite_between(node) 105 node = uniq_sort(node, root) 106 node = absorb_and_eliminate(node, root) 107 node = simplify_concat(node) 108 node = simplify_conditionals(node) 109 110 if constant_propagation: 111 node = propagate_constants(node, root) 112 113 exp.replace_children(node, lambda e: _simplify(e, False)) 114 115 # Post-order transformations 116 node = simplify_not(node) 117 node = flatten(node) 118 node = simplify_connectors(node, root) 119 node = 
remove_complements(node, root) 120 node = simplify_coalesce(node, dialect) 121 node.parent = expression.parent 122 node = simplify_literals(node, root) 123 node = simplify_equality(node) 124 node = simplify_parens(node) 125 node = simplify_datetrunc(node, dialect) 126 node = sort_comparison(node) 127 node = simplify_startswith(node) 128 129 if root: 130 expression.replace(node) 131 return node 132 133 expression = while_changing(expression, _simplify) 134 remove_where_true(expression) 135 return expression 136 137 138def connector_depth(expression: exp.Expression) -> int: 139 """ 140 Determine the maximum depth of a tree of Connectors. 141 142 For example: 143 >>> from sqlglot import parse_one 144 >>> connector_depth(parse_one("a AND b AND c AND d")) 145 3 146 """ 147 stack = deque([(expression, 0)]) 148 max_depth = 0 149 150 while stack: 151 expression, depth = stack.pop() 152 153 if not isinstance(expression, exp.Connector): 154 continue 155 156 depth += 1 157 max_depth = max(depth, max_depth) 158 159 stack.append((expression.left, depth)) 160 stack.append((expression.right, depth)) 161 162 return max_depth 163 164 165def catch(*exceptions): 166 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 167 168 def decorator(func): 169 def wrapped(expression, *args, **kwargs): 170 try: 171 return func(expression, *args, **kwargs) 172 except exceptions: 173 return expression 174 175 return wrapped 176 177 return decorator 178 179 180def rewrite_between(expression: exp.Expression) -> exp.Expression: 181 """Rewrite x between y and z to x >= y AND x <= z. 182 183 This is done because comparison simplification is only done on lt/lte/gt/gte. 
184 """ 185 if isinstance(expression, exp.Between): 186 negate = isinstance(expression.parent, exp.Not) 187 188 expression = exp.and_( 189 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 190 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 191 copy=False, 192 ) 193 194 if negate: 195 expression = exp.paren(expression, copy=False) 196 197 return expression 198 199 200COMPLEMENT_COMPARISONS = { 201 exp.LT: exp.GTE, 202 exp.GT: exp.LTE, 203 exp.LTE: exp.GT, 204 exp.GTE: exp.LT, 205 exp.EQ: exp.NEQ, 206 exp.NEQ: exp.EQ, 207} 208 209COMPLEMENT_SUBQUERY_PREDICATES = { 210 exp.All: exp.Any, 211 exp.Any: exp.All, 212} 213 214 215def simplify_not(expression): 216 """ 217 Demorgan's Law 218 NOT (x OR y) -> NOT x AND NOT y 219 NOT (x AND y) -> NOT x OR NOT y 220 """ 221 if isinstance(expression, exp.Not): 222 this = expression.this 223 if is_null(this): 224 return exp.null() 225 if this.__class__ in COMPLEMENT_COMPARISONS: 226 right = this.expression 227 complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) 228 if complement_subquery_predicate: 229 right = complement_subquery_predicate(this=right.this) 230 231 return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 232 if isinstance(this, exp.Paren): 233 condition = this.unnest() 234 if isinstance(condition, exp.And): 235 return exp.paren( 236 exp.or_( 237 exp.not_(condition.left, copy=False), 238 exp.not_(condition.right, copy=False), 239 copy=False, 240 ) 241 ) 242 if isinstance(condition, exp.Or): 243 return exp.paren( 244 exp.and_( 245 exp.not_(condition.left, copy=False), 246 exp.not_(condition.right, copy=False), 247 copy=False, 248 ) 249 ) 250 if is_null(condition): 251 return exp.null() 252 if always_true(this): 253 return exp.false() 254 if is_false(this): 255 return exp.true() 256 if isinstance(this, exp.Not): 257 # double negation 258 # NOT NOT x -> x 259 return this.this 260 return expression 261 262 263def 
flatten(expression): 264 """ 265 A AND (B AND C) -> A AND B AND C 266 A OR (B OR C) -> A OR B OR C 267 """ 268 if isinstance(expression, exp.Connector): 269 for node in expression.args.values(): 270 child = node.unnest() 271 if isinstance(child, expression.__class__): 272 node.replace(child) 273 return expression 274 275 276def simplify_connectors(expression, root=True): 277 def _simplify_connectors(expression, left, right): 278 if isinstance(expression, exp.And): 279 if is_false(left) or is_false(right): 280 return exp.false() 281 if is_zero(left) or is_zero(right): 282 return exp.false() 283 if is_null(left) or is_null(right): 284 return exp.null() 285 if always_true(left) and always_true(right): 286 return exp.true() 287 if always_true(left): 288 return right 289 if always_true(right): 290 return left 291 return _simplify_comparison(expression, left, right) 292 elif isinstance(expression, exp.Or): 293 if always_true(left) or always_true(right): 294 return exp.true() 295 if ( 296 (is_null(left) and is_null(right)) 297 or (is_null(left) and always_false(right)) 298 or (always_false(left) and is_null(right)) 299 ): 300 return exp.null() 301 if is_false(left): 302 return right 303 if is_false(right): 304 return left 305 return _simplify_comparison(expression, left, right, or_=True) 306 elif isinstance(expression, exp.Xor): 307 if left == right: 308 return exp.false() 309 310 if isinstance(expression, exp.Connector): 311 return _flat_simplify(expression, _simplify_connectors, root) 312 return expression 313 314 315LT_LTE = (exp.LT, exp.LTE) 316GT_GTE = (exp.GT, exp.GTE) 317 318COMPARISONS = ( 319 *LT_LTE, 320 *GT_GTE, 321 exp.EQ, 322 exp.NEQ, 323 exp.Is, 324) 325 326INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 327 exp.LT: exp.GT, 328 exp.GT: exp.LT, 329 exp.LTE: exp.GTE, 330 exp.GTE: exp.LTE, 331} 332 333NONDETERMINISTIC = (exp.Rand, exp.Randn) 334AND_OR = (exp.And, exp.Or) 335 336 337def _simplify_comparison(expression, left, 
right, or_=False): 338 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 339 ll, lr = left.args.values() 340 rl, rr = right.args.values() 341 342 largs = {ll, lr} 343 rargs = {rl, rr} 344 345 matching = largs & rargs 346 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 347 348 if matching and columns: 349 try: 350 l = first(largs - columns) 351 r = first(rargs - columns) 352 except StopIteration: 353 return expression 354 355 if l.is_number and r.is_number: 356 l = l.to_py() 357 r = r.to_py() 358 elif l.is_string and r.is_string: 359 l = l.name 360 r = r.name 361 else: 362 l = extract_date(l) 363 if not l: 364 return None 365 r = extract_date(r) 366 if not r: 367 return None 368 # python won't compare date and datetime, but many engines will upcast 369 l, r = cast_as_datetime(l), cast_as_datetime(r) 370 371 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 372 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 373 return left if (av > bv if or_ else av <= bv) else right 374 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 375 return left if (av < bv if or_ else av >= bv) else right 376 377 # we can't ever shortcut to true because the column could be null 378 if not or_: 379 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 380 if av <= bv: 381 return exp.false() 382 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 383 if av >= bv: 384 return exp.false() 385 elif isinstance(a, exp.EQ): 386 if isinstance(b, exp.LT): 387 return exp.false() if av >= bv else a 388 if isinstance(b, exp.LTE): 389 return exp.false() if av > bv else a 390 if isinstance(b, exp.GT): 391 return exp.false() if av <= bv else a 392 if isinstance(b, exp.GTE): 393 return exp.false() if av < bv else a 394 if isinstance(b, exp.NEQ): 395 return exp.false() if av == bv else a 396 return None 397 398 399def remove_complements(expression, root=True): 400 """ 401 Removing complements. 
402 403 A AND NOT A -> FALSE 404 A OR NOT A -> TRUE 405 """ 406 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 407 ops = set(expression.flatten()) 408 for op in ops: 409 if isinstance(op, exp.Not) and op.this in ops: 410 return exp.false() if isinstance(expression, exp.And) else exp.true() 411 412 return expression 413 414 415def uniq_sort(expression, root=True): 416 """ 417 Uniq and sort a connector. 418 419 C AND A AND B AND B -> A AND B AND C 420 """ 421 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 422 flattened = tuple(expression.flatten()) 423 424 if isinstance(expression, exp.Xor): 425 result_func = exp.xor 426 # Do not deduplicate XOR as A XOR A != A if A == True 427 deduped = None 428 arr = tuple((gen(e), e) for e in flattened) 429 else: 430 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 431 deduped = {gen(e): e for e in flattened} 432 arr = tuple(deduped.items()) 433 434 # check if the operands are already sorted, if not sort them 435 # A AND C AND B -> A AND B AND C 436 for i, (sql, e) in enumerate(arr[1:]): 437 if sql < arr[i][0]: 438 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 439 break 440 else: 441 # we didn't have to sort but maybe we need to dedup 442 if deduped and len(deduped) < len(flattened): 443 expression = result_func(*deduped.values(), copy=False) 444 445 return expression 446 447 448def absorb_and_eliminate(expression, root=True): 449 """ 450 absorption: 451 A AND (A OR B) -> A 452 A OR (A AND B) -> A 453 A AND (NOT A OR B) -> A AND B 454 A OR (NOT A AND B) -> A OR B 455 elimination: 456 (A AND B) OR (A AND NOT B) -> A 457 (A OR B) AND (A OR NOT B) -> A 458 """ 459 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 460 kind = exp.Or if isinstance(expression, exp.And) else exp.And 461 462 ops = tuple(expression.flatten()) 463 464 # Initialize lookup tables: 465 # Set of all operands, used to find 
complements for absorption. 466 op_set = set() 467 # Sub-operands, used to find subsets for absorption. 468 subops = defaultdict(list) 469 # Pairs of complements, used for elimination. 470 pairs = defaultdict(list) 471 472 # Populate the lookup tables 473 for op in ops: 474 op_set.add(op) 475 476 if not isinstance(op, kind): 477 # In cases like: A OR (A AND B) 478 # Subop will be: ^ 479 subops[op].append({op}) 480 continue 481 482 # In cases like: (A AND B) OR (A AND B AND C) 483 # Subops will be: ^ ^ 484 subset = set(op.flatten()) 485 for i in subset: 486 subops[i].append(subset) 487 488 a, b = op.unnest_operands() 489 if isinstance(a, exp.Not): 490 pairs[frozenset((a.this, b))].append((op, b)) 491 if isinstance(b, exp.Not): 492 pairs[frozenset((a, b.this))].append((op, a)) 493 494 for op in ops: 495 if not isinstance(op, kind): 496 continue 497 498 a, b = op.unnest_operands() 499 500 # Absorb 501 if isinstance(a, exp.Not) and a.this in op_set: 502 a.replace(exp.true() if kind == exp.And else exp.false()) 503 continue 504 if isinstance(b, exp.Not) and b.this in op_set: 505 b.replace(exp.true() if kind == exp.And else exp.false()) 506 continue 507 superset = set(op.flatten()) 508 if any(any(subset < superset for subset in subops[i]) for i in superset): 509 op.replace(exp.false() if kind == exp.And else exp.true()) 510 continue 511 512 # Eliminate 513 for other, complement in pairs[frozenset((a, b))]: 514 op.replace(complement) 515 other.replace(complement) 516 517 return expression 518 519 520def propagate_constants(expression, root=True): 521 """ 522 Propagate constants for conjunctions in DNF: 523 524 SELECT * FROM t WHERE a = b AND b = 5 becomes 525 SELECT * FROM t WHERE a = 5 AND b = 5 526 527 Reference: https://www.sqlite.org/optoverview.html 528 """ 529 530 if ( 531 isinstance(expression, exp.And) 532 and (root or not expression.same_parent) 533 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 534 ): 535 constant_mapping = {} 536 for expr in 
walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 537 if isinstance(expr, exp.EQ): 538 l, r = expr.left, expr.right 539 540 # TODO: create a helper that can be used to detect nested literal expressions such 541 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 542 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 543 constant_mapping[l] = (id(l), r) 544 545 if constant_mapping: 546 for column in find_all_in_scope(expression, exp.Column): 547 parent = column.parent 548 column_id, constant = constant_mapping.get(column) or (None, None) 549 if ( 550 column_id is not None 551 and id(column) != column_id 552 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 553 ): 554 column.replace(constant.copy()) 555 556 return expression 557 558 559INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 560 exp.DateAdd: exp.Sub, 561 exp.DateSub: exp.Add, 562 exp.DatetimeAdd: exp.Sub, 563 exp.DatetimeSub: exp.Add, 564} 565 566INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 567 **INVERSE_DATE_OPS, 568 exp.Add: exp.Sub, 569 exp.Sub: exp.Add, 570} 571 572 573def _is_number(expression: exp.Expression) -> bool: 574 return expression.is_number 575 576 577def _is_interval(expression: exp.Expression) -> bool: 578 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 579 580 581@catch(ModuleNotFoundError, UnsupportedUnit) 582def simplify_equality(expression: exp.Expression) -> exp.Expression: 583 """ 584 Use the subtraction and addition properties of equality to simplify expressions: 585 586 x + 1 = 3 becomes x = 2 587 588 There are two binary operations in the above expression: + and = 589 Here's how we reference all the operands in the code below: 590 591 l r 592 x + 1 = 3 593 a b 594 """ 595 if isinstance(expression, COMPARISONS): 596 l, r = expression.left, expression.right 597 598 if l.__class__ not in INVERSE_OPS: 
599 return expression 600 601 if r.is_number: 602 a_predicate = _is_number 603 b_predicate = _is_number 604 elif _is_date_literal(r): 605 a_predicate = _is_date_literal 606 b_predicate = _is_interval 607 else: 608 return expression 609 610 if l.__class__ in INVERSE_DATE_OPS: 611 l = t.cast(exp.IntervalOp, l) 612 a = l.this 613 b = l.interval() 614 else: 615 l = t.cast(exp.Binary, l) 616 a, b = l.left, l.right 617 618 if not a_predicate(a) and b_predicate(b): 619 pass 620 elif not a_predicate(b) and b_predicate(a): 621 a, b = b, a 622 else: 623 return expression 624 625 return expression.__class__( 626 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 627 ) 628 return expression 629 630 631def simplify_literals(expression, root=True): 632 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 633 return _flat_simplify(expression, _simplify_binary, root) 634 635 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 636 return expression.this.this 637 638 if type(expression) in INVERSE_DATE_OPS: 639 return _simplify_binary(expression, expression.this, expression.interval()) or expression 640 641 return expression 642 643 644NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 645 646 647def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 648 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 649 this = _simplify_integer_cast(expr.this) 650 else: 651 this = expr.this 652 653 if isinstance(expr, exp.Cast) and this.is_int: 654 num = this.to_py() 655 656 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 657 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 658 # engine-dependent 659 if ( 660 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 661 ) or ( 662 UTINYINT_MIN <= num <= UTINYINT_MAX 663 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 664 ): 665 return this 666 667 return expr 668 669 670def _simplify_binary(expression, a, b): 671 if isinstance(expression, COMPARISONS): 672 a = _simplify_integer_cast(a) 673 b = _simplify_integer_cast(b) 674 675 if isinstance(expression, exp.Is): 676 if isinstance(b, exp.Not): 677 c = b.this 678 not_ = True 679 else: 680 c = b 681 not_ = False 682 683 if is_null(c): 684 if isinstance(a, exp.Literal): 685 return exp.true() if not_ else exp.false() 686 if is_null(a): 687 return exp.false() if not_ else exp.true() 688 elif isinstance(expression, NULL_OK): 689 return None 690 elif is_null(a) or is_null(b): 691 return exp.null() 692 693 if a.is_number and b.is_number: 694 num_a = a.to_py() 695 num_b = b.to_py() 696 697 if isinstance(expression, exp.Add): 698 return exp.Literal.number(num_a + num_b) 699 if isinstance(expression, exp.Mul): 700 return exp.Literal.number(num_a * num_b) 701 702 # We only simplify Sub, Div if a and b have the same parent because they're not associative 703 if isinstance(expression, exp.Sub): 704 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 705 if isinstance(expression, exp.Div): 706 # engines have differing int div behavior so intdiv is not safe 707 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 708 return None 709 return exp.Literal.number(num_a / num_b) 710 711 boolean = eval_boolean(expression, num_a, num_b) 712 713 if boolean: 714 return boolean 715 elif a.is_string and b.is_string: 716 boolean = eval_boolean(expression, a.this, b.this) 717 718 if boolean: 719 return boolean 720 elif _is_date_literal(a) and isinstance(b, 
exp.Interval): 721 date, b = extract_date(a), extract_interval(b) 722 if date and b: 723 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 724 return date_literal(date + b, extract_type(a)) 725 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 726 return date_literal(date - b, extract_type(a)) 727 elif isinstance(a, exp.Interval) and _is_date_literal(b): 728 a, date = extract_interval(a), extract_date(b) 729 # you cannot subtract a date from an interval 730 if a and b and isinstance(expression, exp.Add): 731 return date_literal(a + date, extract_type(b)) 732 elif _is_date_literal(a) and _is_date_literal(b): 733 if isinstance(expression, exp.Predicate): 734 a, b = extract_date(a), extract_date(b) 735 boolean = eval_boolean(expression, a, b) 736 if boolean: 737 return boolean 738 739 return None 740 741 742def simplify_parens(expression): 743 if not isinstance(expression, exp.Paren): 744 return expression 745 746 this = expression.this 747 parent = expression.parent 748 parent_is_predicate = isinstance(parent, exp.Predicate) 749 750 if ( 751 not isinstance(this, exp.Select) 752 and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)) 753 and ( 754 not isinstance(parent, (exp.Condition, exp.Binary)) 755 or isinstance(parent, exp.Paren) 756 or ( 757 not isinstance(this, exp.Binary) 758 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 759 ) 760 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 761 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 762 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 763 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 764 ) 765 ): 766 return this 767 return expression 768 769 770def _is_nonnull_constant(expression: exp.Expression) -> bool: 771 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 772 773 774def _is_constant(expression: exp.Expression) -> bool: 775 return 
isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 776 777 778def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: 779 # COALESCE(x) -> x 780 if ( 781 isinstance(expression, exp.Coalesce) 782 and (not expression.expressions or _is_nonnull_constant(expression.this)) 783 # COALESCE is also used as a Spark partitioning hint 784 and not isinstance(expression.parent, exp.Hint) 785 ): 786 return expression.this 787 788 # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, 789 # because they are not always equivalent. For example, if `x` is `NULL` and it comes 790 # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` 791 if dialect == "redshift": 792 return expression 793 794 if not isinstance(expression, COMPARISONS): 795 return expression 796 797 if isinstance(expression.left, exp.Coalesce): 798 coalesce = expression.left 799 other = expression.right 800 elif isinstance(expression.right, exp.Coalesce): 801 coalesce = expression.right 802 other = expression.left 803 else: 804 return expression 805 806 # This transformation is valid for non-constants, 807 # but it really only does anything if they are both constants. 808 if not _is_constant(other): 809 return expression 810 811 # Find the first constant arg 812 for arg_index, arg in enumerate(coalesce.expressions): 813 if _is_constant(arg): 814 break 815 else: 816 return expression 817 818 coalesce.set("expressions", coalesce.expressions[:arg_index]) 819 820 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 821 # since we already remove COALESCE at the top of this function. 
822 coalesce = coalesce if coalesce.expressions else coalesce.this 823 824 # This expression is more complex than when we started, but it will get simplified further 825 return exp.paren( 826 exp.or_( 827 exp.and_( 828 coalesce.is_(exp.null()).not_(copy=False), 829 expression.copy(), 830 copy=False, 831 ), 832 exp.and_( 833 coalesce.is_(exp.null()), 834 type(expression)(this=arg.copy(), expression=other.copy()), 835 copy=False, 836 ), 837 copy=False, 838 ) 839 ) 840 841 842CONCATS = (exp.Concat, exp.DPipe) 843 844 845def simplify_concat(expression): 846 """Reduces all groups that contain string literals by concatenating them.""" 847 if not isinstance(expression, CONCATS) or ( 848 # We can't reduce a CONCAT_WS call if we don't statically know the separator 849 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 850 ): 851 return expression 852 853 if isinstance(expression, exp.ConcatWs): 854 sep_expr, *expressions = expression.expressions 855 sep = sep_expr.name 856 concat_type = exp.ConcatWs 857 args = {} 858 else: 859 expressions = expression.expressions 860 sep = "" 861 concat_type = exp.Concat 862 args = { 863 "safe": expression.args.get("safe"), 864 "coalesce": expression.args.get("coalesce"), 865 } 866 867 new_args = [] 868 for is_string_group, group in itertools.groupby( 869 expressions or expression.flatten(), lambda e: e.is_string 870 ): 871 if is_string_group: 872 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 873 else: 874 new_args.extend(group) 875 876 if len(new_args) == 1 and new_args[0].is_string: 877 return new_args[0] 878 879 if concat_type is exp.ConcatWs: 880 new_args = [sep_expr] + new_args 881 elif isinstance(expression, exp.DPipe): 882 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 883 884 return concat_type(expressions=new_args, **args) 885 886 887def simplify_conditionals(expression): 888 """Simplifies expressions like IF, CASE if their condition is statically 
known.""" 889 if isinstance(expression, exp.Case): 890 this = expression.this 891 for case in expression.args["ifs"]: 892 cond = case.this 893 if this: 894 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 895 cond = cond.replace(this.pop().eq(cond)) 896 897 if always_true(cond): 898 return case.args["true"] 899 900 if always_false(cond): 901 case.pop() 902 if not expression.args["ifs"]: 903 return expression.args.get("default") or exp.null() 904 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 905 if always_true(expression.this): 906 return expression.args["true"] 907 if always_false(expression.this): 908 return expression.args.get("false") or exp.null() 909 910 return expression 911 912 913def simplify_startswith(expression: exp.Expression) -> exp.Expression: 914 """ 915 Reduces a prefix check to either TRUE or FALSE if both the string and the 916 prefix are statically known. 917 918 Example: 919 >>> from sqlglot import parse_one 920 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 921 'TRUE' 922 """ 923 if ( 924 isinstance(expression, exp.StartsWith) 925 and expression.this.is_string 926 and expression.expression.is_string 927 ): 928 return exp.convert(expression.name.startswith(expression.expression.name)) 929 930 return expression 931 932 933DateRange = t.Tuple[datetime.date, datetime.date] 934 935 936def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 937 """ 938 Get the date range for a DATE_TRUNC equality comparison: 939 940 Example: 941 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 942 Returns: 943 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 944 """ 945 floor = date_floor(date, unit, dialect) 946 947 if date != floor: 948 # This will always be False, except for NULL values. 
949 return None 950 951 return floor, floor + interval(unit) 952 953 954def _datetrunc_eq_expression( 955 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 956) -> exp.Expression: 957 """Get the logical expression for a date range""" 958 return exp.and_( 959 left >= date_literal(drange[0], target_type), 960 left < date_literal(drange[1], target_type), 961 copy=False, 962 ) 963 964 965def _datetrunc_eq( 966 left: exp.Expression, 967 date: datetime.date, 968 unit: str, 969 dialect: Dialect, 970 target_type: t.Optional[exp.DataType], 971) -> t.Optional[exp.Expression]: 972 drange = _datetrunc_range(date, unit, dialect) 973 if not drange: 974 return None 975 976 return _datetrunc_eq_expression(left, drange, target_type) 977 978 979def _datetrunc_neq( 980 left: exp.Expression, 981 date: datetime.date, 982 unit: str, 983 dialect: Dialect, 984 target_type: t.Optional[exp.DataType], 985) -> t.Optional[exp.Expression]: 986 drange = _datetrunc_range(date, unit, dialect) 987 if not drange: 988 return None 989 990 return exp.and_( 991 left < date_literal(drange[0], target_type), 992 left >= date_literal(drange[1], target_type), 993 copy=False, 994 ) 995 996 997DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 998 exp.LT: lambda l, dt, u, d, t: l 999 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 1000 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 1001 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 1002 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 1003 exp.EQ: _datetrunc_eq, 1004 exp.NEQ: _datetrunc_neq, 1005} 1006DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 1007DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 1008 1009 1010def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 1011 return isinstance(left, 
DATETRUNCS) and _is_date_literal(right) 1012 1013 1014@catch(ModuleNotFoundError, UnsupportedUnit) 1015def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1016 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1017 comparison = expression.__class__ 1018 1019 if isinstance(expression, DATETRUNCS): 1020 this = expression.this 1021 trunc_type = extract_type(this) 1022 date = extract_date(this) 1023 if date and expression.unit: 1024 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1025 elif comparison not in DATETRUNC_COMPARISONS: 1026 return expression 1027 1028 if isinstance(expression, exp.Binary): 1029 l, r = expression.left, expression.right 1030 1031 if not _is_datetrunc_predicate(l, r): 1032 return expression 1033 1034 l = t.cast(exp.DateTrunc, l) 1035 trunc_arg = l.this 1036 unit = l.unit.name.lower() 1037 date = extract_date(r) 1038 1039 if not date: 1040 return expression 1041 1042 return ( 1043 DATETRUNC_BINARY_COMPARISONS[comparison]( 1044 trunc_arg, date, unit, dialect, extract_type(r) 1045 ) 1046 or expression 1047 ) 1048 1049 if isinstance(expression, exp.In): 1050 l = expression.this 1051 rs = expression.expressions 1052 1053 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1054 l = t.cast(exp.DateTrunc, l) 1055 unit = l.unit.name.lower() 1056 1057 ranges = [] 1058 for r in rs: 1059 date = extract_date(r) 1060 if not date: 1061 return expression 1062 drange = _datetrunc_range(date, unit, dialect) 1063 if drange: 1064 ranges.append(drange) 1065 1066 if not ranges: 1067 return expression 1068 1069 ranges = merge_ranges(ranges) 1070 target_type = extract_type(*rs) 1071 1072 return exp.or_( 1073 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1074 ) 1075 1076 return expression 1077 1078 1079def sort_comparison(expression: exp.Expression) -> exp.Expression: 1080 if expression.__class__ in 
COMPLEMENT_COMPARISONS: 1081 l, r = expression.this, expression.expression 1082 l_column = isinstance(l, exp.Column) 1083 r_column = isinstance(r, exp.Column) 1084 l_const = _is_constant(l) 1085 r_const = _is_constant(r) 1086 1087 if ( 1088 (l_column and not r_column) 1089 or (r_const and not l_const) 1090 or isinstance(r, exp.SubqueryPredicate) 1091 ): 1092 return expression 1093 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1094 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1095 this=r, expression=l 1096 ) 1097 return expression 1098 1099 1100# CROSS joins result in an empty table if the right table is empty. 1101# So we can only simplify certain types of joins to CROSS. 1102# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1103JOINS = { 1104 ("", ""), 1105 ("", "INNER"), 1106 ("RIGHT", ""), 1107 ("RIGHT", "OUTER"), 1108} 1109 1110 1111def remove_where_true(expression): 1112 for where in expression.find_all(exp.Where): 1113 if always_true(where.this): 1114 where.pop() 1115 for join in expression.find_all(exp.Join): 1116 if ( 1117 always_true(join.args.get("on")) 1118 and not join.args.get("using") 1119 and not join.args.get("method") 1120 and (join.side, join.kind) in JOINS 1121 ): 1122 join.args["on"].pop() 1123 join.set("side", None) 1124 join.set("kind", "CROSS") 1125 1126 1127def always_true(expression): 1128 return (isinstance(expression, exp.Boolean) and expression.this) or ( 1129 isinstance(expression, exp.Literal) and not is_zero(expression) 1130 ) 1131 1132 1133def always_false(expression): 1134 return is_false(expression) or is_null(expression) or is_zero(expression) 1135 1136 1137def is_zero(expression): 1138 return isinstance(expression, exp.Literal) and expression.to_py() == 0 1139 1140 1141def is_complement(a, b): 1142 return isinstance(b, exp.Not) and b.this == a 1143 1144 1145def is_false(a: exp.Expression) -> bool: 1146 return type(a) is exp.Boolean and not a.this 1147 1148 
1149def is_null(a: exp.Expression) -> bool: 1150 return type(a) is exp.Null 1151 1152 1153def eval_boolean(expression, a, b): 1154 if isinstance(expression, (exp.EQ, exp.Is)): 1155 return boolean_literal(a == b) 1156 if isinstance(expression, exp.NEQ): 1157 return boolean_literal(a != b) 1158 if isinstance(expression, exp.GT): 1159 return boolean_literal(a > b) 1160 if isinstance(expression, exp.GTE): 1161 return boolean_literal(a >= b) 1162 if isinstance(expression, exp.LT): 1163 return boolean_literal(a < b) 1164 if isinstance(expression, exp.LTE): 1165 return boolean_literal(a <= b) 1166 return None 1167 1168 1169def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1170 if isinstance(value, datetime.datetime): 1171 return value.date() 1172 if isinstance(value, datetime.date): 1173 return value 1174 try: 1175 return datetime.datetime.fromisoformat(value).date() 1176 except ValueError: 1177 return None 1178 1179 1180def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1181 if isinstance(value, datetime.datetime): 1182 return value 1183 if isinstance(value, datetime.date): 1184 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1185 try: 1186 return datetime.datetime.fromisoformat(value) 1187 except ValueError: 1188 return None 1189 1190 1191def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1192 if not value: 1193 return None 1194 if to.is_type(exp.DataType.Type.DATE): 1195 return cast_as_date(value) 1196 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1197 return cast_as_datetime(value) 1198 return None 1199 1200 1201def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1202 if isinstance(cast, exp.Cast): 1203 to = cast.to 1204 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1205 to = exp.DataType.build(exp.DataType.Type.DATE) 1206 else: 1207 return None 1208 1209 if isinstance(cast.this, exp.Literal): 
1210 value: t.Any = cast.this.name 1211 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1212 value = extract_date(cast.this) 1213 else: 1214 return None 1215 return cast_value(value, to) 1216 1217 1218def _is_date_literal(expression: exp.Expression) -> bool: 1219 return extract_date(expression) is not None 1220 1221 1222def extract_interval(expression): 1223 try: 1224 n = int(expression.this.to_py()) 1225 unit = expression.text("unit").lower() 1226 return interval(unit, n) 1227 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1228 return None 1229 1230 1231def extract_type(*expressions): 1232 target_type = None 1233 for expression in expressions: 1234 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1235 if target_type: 1236 break 1237 1238 return target_type 1239 1240 1241def date_literal(date, target_type=None): 1242 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1243 target_type = ( 1244 exp.DataType.Type.DATETIME 1245 if isinstance(date, datetime.datetime) 1246 else exp.DataType.Type.DATE 1247 ) 1248 1249 return exp.cast(exp.Literal.string(date), target_type) 1250 1251 1252def interval(unit: str, n: int = 1): 1253 from dateutil.relativedelta import relativedelta 1254 1255 if unit == "year": 1256 return relativedelta(years=1 * n) 1257 if unit == "quarter": 1258 return relativedelta(months=3 * n) 1259 if unit == "month": 1260 return relativedelta(months=1 * n) 1261 if unit == "week": 1262 return relativedelta(weeks=1 * n) 1263 if unit == "day": 1264 return relativedelta(days=1 * n) 1265 if unit == "hour": 1266 return relativedelta(hours=1 * n) 1267 if unit == "minute": 1268 return relativedelta(minutes=1 * n) 1269 if unit == "second": 1270 return relativedelta(seconds=1 * n) 1271 1272 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1273 1274 1275def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1276 if unit == "year": 1277 return 
d.replace(month=1, day=1) 1278 if unit == "quarter": 1279 if d.month <= 3: 1280 return d.replace(month=1, day=1) 1281 elif d.month <= 6: 1282 return d.replace(month=4, day=1) 1283 elif d.month <= 9: 1284 return d.replace(month=7, day=1) 1285 else: 1286 return d.replace(month=10, day=1) 1287 if unit == "month": 1288 return d.replace(month=d.month, day=1) 1289 if unit == "week": 1290 # Assuming week starts on Monday (0) and ends on Sunday (6) 1291 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1292 if unit == "day": 1293 return d 1294 1295 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1296 1297 1298def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1299 floor = date_floor(d, unit, dialect) 1300 1301 if floor == d: 1302 return d 1303 1304 return floor + interval(unit) 1305 1306 1307def boolean_literal(condition): 1308 return exp.true() if condition else exp.false() 1309 1310 1311def _flat_simplify(expression, simplifier, root=True): 1312 if root or not expression.same_parent: 1313 operands = [] 1314 queue = deque(expression.flatten(unnest=False)) 1315 size = len(queue) 1316 1317 while queue: 1318 a = queue.popleft() 1319 1320 for b in queue: 1321 result = simplifier(expression, a, b) 1322 1323 if result and result is not expression: 1324 queue.remove(b) 1325 queue.appendleft(result) 1326 break 1327 else: 1328 operands.append(a) 1329 1330 if len(operands) < size: 1331 return functools.reduce( 1332 lambda a, b: expression.__class__(this=a, expression=b), operands 1333 ) 1334 return expression 1335 1336 1337def gen(expression: t.Any, comments: bool = False) -> str: 1338 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1339 1340 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1341 generator is expensive so we have a bare minimum sql generator here. 1342 1343 Args: 1344 expression: the expression to convert into a SQL string. 
1345 comments: whether to include the expression's comments. 1346 """ 1347 return Gen().gen(expression, comments=comments) 1348 1349 1350class Gen: 1351 def __init__(self): 1352 self.stack = [] 1353 self.sqls = [] 1354 1355 def gen(self, expression: exp.Expression, comments: bool = False) -> str: 1356 self.stack = [expression] 1357 self.sqls.clear() 1358 1359 while self.stack: 1360 node = self.stack.pop() 1361 1362 if isinstance(node, exp.Expression): 1363 if comments and node.comments: 1364 self.stack.append(f" /*{','.join(node.comments)}*/") 1365 1366 exp_handler_name = f"{node.key}_sql" 1367 1368 if hasattr(self, exp_handler_name): 1369 getattr(self, exp_handler_name)(node) 1370 elif isinstance(node, exp.Func): 1371 self._function(node) 1372 else: 1373 key = node.key.upper() 1374 self.stack.append(f"{key} " if self._args(node) else key) 1375 elif type(node) is list: 1376 for n in reversed(node): 1377 if n is not None: 1378 self.stack.extend((n, ",")) 1379 if node: 1380 self.stack.pop() 1381 else: 1382 if node is not None: 1383 self.sqls.append(str(node)) 1384 1385 return "".join(self.sqls) 1386 1387 def add_sql(self, e: exp.Add) -> None: 1388 self._binary(e, " + ") 1389 1390 def alias_sql(self, e: exp.Alias) -> None: 1391 self.stack.extend( 1392 ( 1393 e.args.get("alias"), 1394 " AS ", 1395 e.args.get("this"), 1396 ) 1397 ) 1398 1399 def and_sql(self, e: exp.And) -> None: 1400 self._binary(e, " AND ") 1401 1402 def anonymous_sql(self, e: exp.Anonymous) -> None: 1403 this = e.this 1404 if isinstance(this, str): 1405 name = this.upper() 1406 elif isinstance(this, exp.Identifier): 1407 name = this.this 1408 name = f'"{name}"' if this.quoted else name.upper() 1409 else: 1410 raise ValueError( 1411 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
1412 ) 1413 1414 self.stack.extend( 1415 ( 1416 ")", 1417 e.expressions, 1418 "(", 1419 name, 1420 ) 1421 ) 1422 1423 def between_sql(self, e: exp.Between) -> None: 1424 self.stack.extend( 1425 ( 1426 e.args.get("high"), 1427 " AND ", 1428 e.args.get("low"), 1429 " BETWEEN ", 1430 e.this, 1431 ) 1432 ) 1433 1434 def boolean_sql(self, e: exp.Boolean) -> None: 1435 self.stack.append("TRUE" if e.this else "FALSE") 1436 1437 def bracket_sql(self, e: exp.Bracket) -> None: 1438 self.stack.extend( 1439 ( 1440 "]", 1441 e.expressions, 1442 "[", 1443 e.this, 1444 ) 1445 ) 1446 1447 def column_sql(self, e: exp.Column) -> None: 1448 for p in reversed(e.parts): 1449 self.stack.extend((p, ".")) 1450 self.stack.pop() 1451 1452 def datatype_sql(self, e: exp.DataType) -> None: 1453 self._args(e, 1) 1454 self.stack.append(f"{e.this.name} ") 1455 1456 def div_sql(self, e: exp.Div) -> None: 1457 self._binary(e, " / ") 1458 1459 def dot_sql(self, e: exp.Dot) -> None: 1460 self._binary(e, ".") 1461 1462 def eq_sql(self, e: exp.EQ) -> None: 1463 self._binary(e, " = ") 1464 1465 def from_sql(self, e: exp.From) -> None: 1466 self.stack.extend((e.this, "FROM ")) 1467 1468 def gt_sql(self, e: exp.GT) -> None: 1469 self._binary(e, " > ") 1470 1471 def gte_sql(self, e: exp.GTE) -> None: 1472 self._binary(e, " >= ") 1473 1474 def identifier_sql(self, e: exp.Identifier) -> None: 1475 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1476 1477 def ilike_sql(self, e: exp.ILike) -> None: 1478 self._binary(e, " ILIKE ") 1479 1480 def in_sql(self, e: exp.In) -> None: 1481 self.stack.append(")") 1482 self._args(e, 1) 1483 self.stack.extend( 1484 ( 1485 "(", 1486 " IN ", 1487 e.this, 1488 ) 1489 ) 1490 1491 def intdiv_sql(self, e: exp.IntDiv) -> None: 1492 self._binary(e, " DIV ") 1493 1494 def is_sql(self, e: exp.Is) -> None: 1495 self._binary(e, " IS ") 1496 1497 def like_sql(self, e: exp.Like) -> None: 1498 self._binary(e, " Like ") 1499 1500 def literal_sql(self, e: exp.Literal) -> None: 
1501 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1502 1503 def lt_sql(self, e: exp.LT) -> None: 1504 self._binary(e, " < ") 1505 1506 def lte_sql(self, e: exp.LTE) -> None: 1507 self._binary(e, " <= ") 1508 1509 def mod_sql(self, e: exp.Mod) -> None: 1510 self._binary(e, " % ") 1511 1512 def mul_sql(self, e: exp.Mul) -> None: 1513 self._binary(e, " * ") 1514 1515 def neg_sql(self, e: exp.Neg) -> None: 1516 self._unary(e, "-") 1517 1518 def neq_sql(self, e: exp.NEQ) -> None: 1519 self._binary(e, " <> ") 1520 1521 def not_sql(self, e: exp.Not) -> None: 1522 self._unary(e, "NOT ") 1523 1524 def null_sql(self, e: exp.Null) -> None: 1525 self.stack.append("NULL") 1526 1527 def or_sql(self, e: exp.Or) -> None: 1528 self._binary(e, " OR ") 1529 1530 def paren_sql(self, e: exp.Paren) -> None: 1531 self.stack.extend( 1532 ( 1533 ")", 1534 e.this, 1535 "(", 1536 ) 1537 ) 1538 1539 def sub_sql(self, e: exp.Sub) -> None: 1540 self._binary(e, " - ") 1541 1542 def subquery_sql(self, e: exp.Subquery) -> None: 1543 self._args(e, 2) 1544 alias = e.args.get("alias") 1545 if alias: 1546 self.stack.append(alias) 1547 self.stack.extend((")", e.this, "(")) 1548 1549 def table_sql(self, e: exp.Table) -> None: 1550 self._args(e, 4) 1551 alias = e.args.get("alias") 1552 if alias: 1553 self.stack.append(alias) 1554 for p in reversed(e.parts): 1555 self.stack.extend((p, ".")) 1556 self.stack.pop() 1557 1558 def tablealias_sql(self, e: exp.TableAlias) -> None: 1559 columns = e.columns 1560 1561 if columns: 1562 self.stack.extend((")", columns, "(")) 1563 1564 self.stack.extend((e.this, " AS ")) 1565 1566 def var_sql(self, e: exp.Var) -> None: 1567 self.stack.append(e.this) 1568 1569 def _binary(self, e: exp.Binary, op: str) -> None: 1570 self.stack.extend((e.expression, op, e.this)) 1571 1572 def _unary(self, e: exp.Unary, op: str) -> None: 1573 self.stack.extend((e.this, op)) 1574 1575 def _function(self, e: exp.Func) -> None: 1576 self.stack.extend( 1577 ( 1578 ")", 1579 
list(e.args.values()), 1580 "(", 1581 e.sql_name(), 1582 ) 1583 ) 1584 1585 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1586 kvs = [] 1587 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1588 1589 for k in arg_types or arg_types: 1590 v = node.args.get(k) 1591 1592 if v is not None: 1593 kvs.append([f":{k}", v]) 1594 if kvs: 1595 self.stack.append(kvs) 1596 return True 1597 return False
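Raised when a date/time unit is not supported by the date helpers (e.g. `interval`, `date_floor`). Inherits from `Exception`, the common base class for all non-exit exceptions.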
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: the SQL dialect to simplify for. It is resolved with `Dialect.get_or_raise`
            and handed to the rules whose behavior depends on it (COALESCE and DATE_TRUNC).
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Guard against pathologically deep AND/OR chains; only checked at the top of a chain.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node, dialect)
        # Later rules (e.g. simplify_parens) inspect `node.parent`, so keep the new node
        # pointing at the original node's parent.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the whole rule set until the tree stops changing.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression: expression to simplify
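- constant_propagation: whether the constant propagation rule should be used
- dialect: the SQL dialect to simplify for; some rules (e.g. COALESCE and DATE_TRUNC handling) depend on it
- max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped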
Returns:
sqlglot.Expression: simplified expression
def connector_depth(expression: exp.Expression) -> int:
    """
    Compute how deeply Connector nodes (AND, OR, ...) are nested under `expression`.

    Only Connector nodes are counted and descended into. The result is 0 when
    `expression` itself is not a Connector.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    # Each pending entry carries the depth the node would have if it is a Connector.
    pending = [(expression, 1)]

    while pending:
        node, level = pending.pop()
        if isinstance(node, exp.Connector):
            deepest = max(deepest, level)
            pending.extend(((node.left, level + 1), (node.right, level + 1)))

    return deepest
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as ``func(expression, *args, **kwargs)``. If it
    raises one of `exceptions`, the input `expression` is returned unchanged instead.
    Any other exception propagates.
    """

    def decorator(func):
        # functools.wraps keeps the decorated rule's __name__, __doc__, etc., so
        # introspection and generated documentation show the rule rather than `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function (returning the input expression unchanged) if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded
    first. If the BETWEEN sits directly under a NOT, the result is parenthesized so
    the negation still covers both comparisons.
    """
    if not isinstance(expression, exp.Between):
        return expression

    wrap_in_parens = isinstance(expression.parent, exp.Not)
    subject = expression.this

    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if wrap_in_parens else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push a NOT inward, or fold it away when its operand allows it.

    De Morgan's laws:
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also: NOT NULL -> NULL, NOT <comparison> -> complementary comparison,
    NOT <always true> -> FALSE, NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        rhs = operand.expression
        # e.g. a subquery predicate on the right-hand side is flipped to its complement too
        flip_subquery = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if flip_subquery:
            rhs = flip_subquery(this=rhs.this)
        complement_cls = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement_cls(this=operand.this, expression=rhs)

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()
        if isinstance(condition, (exp.And, exp.Or)):
            # Negate both sides and swap the connector
            swapped = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return exp.paren(
                swapped(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Remove parentheses around a directly nested connector of the same kind:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            # Put the unwrapped connector in place of its Paren wrapper
            operand.replace(inner)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold constant operands in AND / OR / XOR chains.

    AND: FALSE or 0 wins; NULL is kept only next to NULL or an always-true operand;
         always-true operands drop out.
    OR:  an always-true operand wins; NULL next to NULL or an always-false operand stays
         NULL; FALSE operands drop out.
    XOR: two identical operands give FALSE.
    Other AND/OR operand pairs are passed to `_simplify_comparison`. The pairwise
    simplifier is applied across the whole flattened chain via `_flat_simplify`.
    """

    def _simplify_connectors(expression, left, right):
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_zero(left) or is_zero(right):
                return exp.false()
            # NULL AND x is NULL only when x is NULL or TRUE. When x is FALSE the result is
            # FALSE, so NULL AND <unknown> must not be folded to NULL (mirrors the OR rule).
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_true(right))
                or (always_true(left) and is_null(right))
            ):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_false(right))
                or (always_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        elif isinstance(expression, exp.Xor):
            if left == right:
                return exp.false()

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse an AND / OR chain that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression
    if not root and expression.same_parent:
        # Only the top of a chain is inspected; inner links are covered by it.
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if not has_complement:
        return expression
    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by their `gen` string. AND/OR operands are also deduplicated
    by that string. XOR operands are not, since A XOR A != A when A is TRUE.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, _) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                # Sort on the generated SQL alone. Sorting whole (sql, node) tuples would
                # compare the AST nodes themselves whenever two keys tie (possible for XOR,
                # which keeps duplicates). That is not a meaningful ordering for nodes.
                ordered = sorted(arr, key=lambda pair: pair[0])
                expression = result_func(*(e for _, e in ordered), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing operands with TRUE/FALSE or with the
    shared operand. Later passes (e.g. connector simplification, dedup) are expected to
    clean up the result.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the operands we inspect are of this type.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like A OR (A AND B), the bare operand A is recorded as the
                # single-element subset {A}.
                subops[op].append({op})
                continue

            # In cases like (A AND B) OR (A AND B AND C), each nested operand's full
            # operand set is recorded under every member of that set.
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                # `complement` now sits where `op` was. Give `other` its own copy so the
                # same node is not attached under two different parents.
                other.replace(complement.copy())

    return expression
absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only `column = literal` equalities that are conjuncts of the AND chain (possibly
    parenthesized) are used as facts. The equality that defines a constant is left
    untouched, as are columns compared with `IS NULL`.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        # Descend only through AND and parentheses. Every other node is still yielded, but
        # its children are pruned. As a result, equalities under NOT, inside function calls,
        # IF, etc. are not mistaken for facts (`NOT a = 1 AND b = a` must not become `b = 1`).
        for expr in walk_in_scope(
            expression, prune=lambda node: not isinstance(node, (exp.And, exp.Paren))
        ):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5
becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    """Call `func`; if it raises one of `exceptions`, return `expression` unchanged.

    `func` and `exceptions` are free variables bound by the enclosing `catch`
    decorator.
    """
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        outcome = expression
    return outcome
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
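There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:

```
      l     r
    x + 1 = 3
    a   b
```

That is, l = `x + 1` and r = `3` are the operands of `=`, while a = `x` and b = `1` are the operands of `+`.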
def simplify_literals(expression, root=True):
    """Simplify literal-bearing arithmetic, double negation and inverse date operations.

    - Non-connector binary expressions are reduced pairwise across their flattened
      chain with `_simplify_binary` (via `_flat_simplify`).
    - `--x` becomes `x`.
    - Expressions whose type is in INVERSE_DATE_OPS are passed to `_simplify_binary`
      together with their `interval()`. They are kept as-is when nothing simplifies.
    """
    is_non_connector_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_non_connector_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        inner = expression.this
        if isinstance(inner, exp.Neg):
            return inner.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """
    Drop a Paren wrapper when doing so cannot change how the expression is read.

    Parentheses are always kept around a SELECT and when the parent is a subquery
    predicate or a bracket. Otherwise they are removed when, for example:
    - the parent is not an operator (Condition/Binary) or is itself a Paren;
    - the wrapped node is not a binary operator (except NOT/IS under a predicate);
    - a predicate is wrapped and the parent is not a predicate;
    - precedence makes them redundant: `a + (b + c)`, `a * (b * c)`, `a + (b * c)`,
      `a - (b * c)`.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select) or isinstance(outer, (exp.SubqueryPredicate, exp.Bracket)):
        return expression

    outer_is_predicate = isinstance(outer, exp.Predicate)

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and outer_is_predicate)
        )
        or (isinstance(inner, exp.Predicate) and not outer_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression:
    """
    Simplify COALESCE calls and constant comparisons against them.

    - `COALESCE(x)`, or a COALESCE whose first argument is a non-null constant, is
      replaced by its first argument (unless the COALESCE sits in a Hint).
    - A comparison of a COALESCE with a constant, e.g. `COALESCE(x, 1) = 2`, becomes
      `(prefix IS NOT NULL AND <comparison>) OR (prefix IS NULL AND 1 = 2)`. Here `prefix`
      is the COALESCE of the arguments before its first constant argument. Later passes
      simplify it further. This is skipped for Redshift.
    """
    if isinstance(expression, exp.Coalesce):
        collapsible = not expression.expressions or _is_nonnull_constant(expression.this)
        # COALESCE is also used as a Spark partitioning hint
        if collapsible and not isinstance(expression.parent, exp.Hint):
            return expression.this

    # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift,
    # because they are not always equivalent. For example, if `x` is `NULL` and it comes
    # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE`
    if dialect == "redshift" or not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce, other = expression.left, expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce, other = expression.right, expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Locate the first constant fallback argument
    found = next(
        (
            (position, candidate)
            for position, candidate in enumerate(coalesce.expressions)
            if _is_constant(candidate)
        ),
        None,
    )
    if found is None:
        return expression
    arg_index, arg = found

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    if not coalesce.expressions:
        coalesce = coalesce.this

    non_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    fallback_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(exp.or_(non_null_branch, fallback_branch, copy=False))
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals in CONCAT / CONCAT_WS / `||` are merged into one
    literal (joined with the separator for CONCAT_WS). If everything merges into a
    single literal, that literal is returned.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        # (or if there is no separator argument at all).
        isinstance(expression, exp.ConcatWs)
        and (not expression.expressions or not expression.expressions[0].is_string)
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
        # Only the operands after the separator are merged. If there are none, keep it that
        # way: falling back to flatten() would yield the separator itself as an operand.
        operands = expressions
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }
        # `||` (DPipe) is a binary chain with no `expressions`, so collect its operands.
        operands = expressions or expression.flatten()

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces each run of adjacent string literals to a single literal by concatenating them.
def simplify_conditionals(expression):
    """
    Simplify IF and CASE expressions whose conditions are statically known.

    For CASE, branches whose condition is always false are dropped, and the first branch
    whose condition is always true replaces the whole expression. If every branch is
    dropped, the ELSE value (or NULL) is returned. A CASE with an operand
    (``CASE x WHEN v ...``) is first rewritten to compare ``x = v`` in each branch.

    For IF (unless its parent is a CASE), an always-true or always-false condition
    selects the matching branch, with NULL standing in for a missing false branch.

    Args:
        expression: the expression to simplify.

    Returns:
        The simplified expression, or ``expression`` (possibly with branches removed).
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot of the branches. `case.pop()` below removes a branch
        # from the "ifs" list (the emptiness check relies on that). Removing items from a
        # list while iterating over it skips the following item, which could make e.g.
        # CASE WHEN FALSE THEN 1 WHEN TRUE THEN 2 WHEN TRUE THEN 3 END fold to 3.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies IF and CASE expressions whose conditions are statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into a boolean literal when both operands are string literals.

    Any other expression, including a StartsWith with a non-literal operand, is
    returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack = expression.this
    prefix = expression.expression
    if not (haystack.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # Run the decorated simplification. If it raises one of the configured
    # `exceptions`, fall back to returning the input expression untouched.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        result = expression
    return result
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    Only node classes listed in ``COMPLEMENT_COMPARISONS`` are considered. A column on
    the left, a constant on the right, or an ``exp.SubqueryPredicate`` on the right
    leaves the comparison alone. Otherwise the operands are swapped when the right side
    is the (only) column, the left side is the (only) constant, or the left side's
    ``gen`` string sorts after the right side's. A swap rebuilds the node using the
    class from ``INVERSE_COMPARISONS`` (falling back to the same class).
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    should_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if not should_swap:
        return expression

    swapped_class = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return swapped_class(this=right, expression=left)
def remove_where_true(expression):
    """
    Drop filters whose condition is always true, modifying ``expression`` in place.

    Every WHERE clause with an always-true condition is removed. A join whose ON
    condition is always true, that has neither USING nor a join method, and whose
    ``(side, kind)`` pair is in ``JOINS`` loses its ON clause and becomes a CROSS join.
    Returns None.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between two already-extracted Python values.

    ``expression`` only selects the operator; EQ and IS both mean equality. The outcome
    is wrapped with ``boolean_literal``. Returns None if ``expression`` is not one of
    the supported comparison nodes.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert ``value`` to a ``datetime.date`` if possible.

    A ``datetime.datetime`` is truncated to its date, and a ``datetime.date`` is
    returned as-is. Anything else is parsed with ``datetime.datetime.fromisoformat``.

    Returns:
        The date, or None when ``value`` is not a parseable ISO-format string
        (including when it is not a string at all).
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input; treat it like an
        # unparseable string rather than letting it escape the simplifier.
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert ``value`` to a ``datetime.datetime`` if possible.

    A ``datetime.datetime`` is returned as-is, and a ``datetime.date`` becomes midnight
    of that day. Anything else is parsed with ``datetime.datetime.fromisoformat``.

    Returns:
        The datetime, or None when ``value`` is not a parseable ISO-format string
        (including when it is not a string at all).
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-str input; treat it like an
        # unparseable string rather than letting it escape the simplifier.
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Interpret ``value`` as a value of the temporal type ``to``.

    Returns:
        A ``datetime.date`` (via ``cast_as_date``) when ``to`` is DATE, or a
        ``datetime.datetime`` (via ``cast_as_datetime``) for any other type in
        ``exp.DataType.TEMPORAL_TYPES``. Returns None when ``value`` is falsy, when
        ``to`` is not temporal, or when the conversion helper returns None.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a cast of a string literal to a temporal type.

    Supports ``exp.Cast`` and ``exp.TsOrDsToDate`` without a ``format`` argument (the
    latter is treated as a cast to DATE). The operand may be a literal or another such
    cast, which is evaluated recursively, e.g.
    ``CAST(CAST('2021-01-01' AS DATE) AS TIMESTAMP)``.

    Returns:
        The resulting date/datetime, or None for unsupported shapes or values that
        ``cast_value`` cannot convert.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """
    Build a cast of a string literal of ``date`` to a temporal type.

    ``target_type`` is used when it is given and is one of the temporal types.
    Otherwise DATETIME is chosen for ``datetime.datetime`` values and DATE for
    everything else.
    """
    if target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        chosen_type = target_type
    elif isinstance(date, datetime.datetime):
        chosen_type = exp.DataType.Type.DATETIME
    else:
        chosen_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), chosen_type)
def interval(unit: str, n: int = 1):
    """
    Return a ``dateutil`` ``relativedelta`` spanning ``n`` of the given ``unit``.

    Supported units are year, quarter (three months), month, week, day, hour, minute
    and second. Any other unit raises ``UnsupportedUnit``.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier applied to n)
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, factor in unit_table:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate ``d`` to the first day of the enclosing ``unit``.

    Supported units are year, quarter, month, week and day; anything else raises
    ``UnsupportedUnit``. For "week", ``dialect.WEEK_OFFSET`` is interpreted as an offset
    from Monday (``date.weekday()`` is 0 for Monday): 0 starts weeks on Monday and -1
    on Sunday.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the dialect's first day of the week. The modulo keeps the
        # step back within 0..6. Without it, a date that already falls on the first day
        # (e.g. a Sunday when WEEK_OFFSET is -1) would be moved a whole week back, and a
        # positive offset could move the date forward.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any, comments: bool = False) -> str:
    """
    Render ``expression`` as a lightweight pseudo-SQL string for sorting and deduping.

    The optimizer sorts and deduplicates expressions frequently, and the real SQL
    generator is too expensive for that. A fresh ``Gen`` instance, a minimal renderer,
    is used on every call instead.

    Args:
        expression: the expression to convert into a SQL string.
        comments: whether to include the expression's comments.

    Returns:
        The pseudo-SQL string.
    """
    renderer = Gen()
    return renderer.gen(expression, comments=comments)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
Arguments:
- expression: the expression to convert into a SQL string.
- comments: whether to include the expression's comments.
class Gen:
    """
    Minimal, non-recursive renderer behind ``gen``.

    ``gen`` keeps an explicit LIFO work stack. Each entry is handled as follows:

    - An ``exp.Expression`` is dispatched to a ``<key>_sql`` method when one exists,
      to ``_function`` for other ``exp.Func`` nodes, and otherwise rendered as its
      upper-cased key followed by its non-null args.
    - A ``list`` is rendered comma-separated, skipping ``None`` items.
    - Any other value is appended via ``str()``; ``None`` is skipped.

    Because the stack is LIFO, handlers push their pieces in reverse output order.
    """

    def __init__(self):
        # Pending work items; the next one is popped from the end.
        self.stack = []
        # Output fragments, joined at the end of `gen`.
        self.sqls = []

    def gen(self, expression: exp.Expression, comments: bool = False) -> str:
        """
        Render ``expression`` to a pseudo-SQL string.

        Args:
            expression: the expression to render.
            comments: whether to include each node's comments as ``/*...*/``.

        Returns:
            The concatenated output fragments.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                if comments and node.comments:
                    # Pushed before the node's pieces, so it is emitted after them.
                    self.stack.append(f" /*{','.join(node.comments)}*/")

                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Discard the separator that would precede the first element. Only
                    # do this when something was pushed: for a list of only None values
                    # it would discard an unrelated entry of an enclosing node.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." that would otherwise precede the first part.
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." that would otherwise precede the first part.
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Renders as: <this><op><expression>
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        # Renders as: <op><this>
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Renders as: NAME(<all arg values, comma-separated>)
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push ``[":<name>", value]`` pairs for the node's non-null args.

        The first ``arg_index`` entries of ``node.arg_types`` are skipped. Returns
        whether anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression, comments: bool = False) -> str:
    """
    Render ``expression`` to a pseudo-SQL string using an explicit LIFO work stack.

    Args:
        expression: the expression to render.
        comments: whether to include each node's comments as ``/*...*/``.

    Returns:
        The concatenated output fragments.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            if comments and node.comments:
                # Pushed before the node's pieces, so it is emitted after them.
                self.stack.append(f" /*{','.join(node.comments)}*/")

            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            if pushed:
                # Discard the separator that would precede the first element. Only do
                # this when something was pushed: for a list of only None values it
                # would discard an unrelated entry of an enclosing node.
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Render an anonymous function call as ``NAME(arg, ...)``.

    A plain-string name is upper-cased. An Identifier name is double-quoted (keeping
    its case) when quoted, and upper-cased otherwise. Any other kind of name raises
    ValueError.
    """
    func_name = e.this
    if isinstance(func_name, exp.Identifier):
        raw = func_name.this
        rendered = f'"{raw}"' if func_name.quoted else raw.upper()
    elif isinstance(func_name, str):
        rendered = func_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{func_name.__class__.__name__}'."
        )

    # LIFO stack: push in reverse of the desired output order.
    self.stack.extend((")", e.expressions, "(", rendered))