sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 dialect: DialectType = None, 43 max_depth: t.Optional[int] = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 
47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped 58 Returns: 59 sqlglot.Expression: simplified expression 60 """ 61 62 dialect = Dialect.get_or_raise(dialect) 63 64 def _simplify(expression, root=True): 65 if ( 66 max_depth 67 and isinstance(expression, exp.Connector) 68 and not isinstance(expression.parent, exp.Connector) 69 ): 70 depth = connector_depth(expression) 71 if depth > max_depth: 72 logger.info( 73 f"Skipping simplification because connector depth {depth} exceeds max {max_depth}" 74 ) 75 return expression 76 77 if expression.meta.get(FINAL): 78 return expression 79 80 # group by expressions cannot be simplified, for example 81 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 82 # the projection must exactly match the group by key 83 group = expression.args.get("group") 84 85 if group and hasattr(expression, "selects"): 86 groups = set(group.expressions) 87 group.meta[FINAL] = True 88 89 for e in expression.selects: 90 for node in e.walk(): 91 if node in groups: 92 e.meta[FINAL] = True 93 break 94 95 having = expression.args.get("having") 96 if having: 97 for node in having.walk(): 98 if node in groups: 99 having.meta[FINAL] = True 100 break 101 102 # Pre-order transformations 103 node = expression 104 node = rewrite_between(node) 105 node = uniq_sort(node, root) 106 node = absorb_and_eliminate(node, root) 107 node = simplify_concat(node) 108 node = simplify_conditionals(node) 109 110 if constant_propagation: 111 node = propagate_constants(node, root) 112 113 exp.replace_children(node, lambda e: _simplify(e, False)) 114 115 # Post-order transformations 116 node = simplify_not(node) 117 node = flatten(node) 118 node = simplify_connectors(node, root) 119 node = 
remove_complements(node, root) 120 node = simplify_coalesce(node) 121 node.parent = expression.parent 122 node = simplify_literals(node, root) 123 node = simplify_equality(node) 124 node = simplify_parens(node) 125 node = simplify_datetrunc(node, dialect) 126 node = sort_comparison(node) 127 node = simplify_startswith(node) 128 129 if root: 130 expression.replace(node) 131 return node 132 133 expression = while_changing(expression, _simplify) 134 remove_where_true(expression) 135 return expression 136 137 138def connector_depth(expression: exp.Expression) -> int: 139 """ 140 Determine the maximum depth of a tree of Connectors. 141 142 For example: 143 >>> from sqlglot import parse_one 144 >>> connector_depth(parse_one("a AND b AND c AND d")) 145 3 146 """ 147 stack = deque([(expression, 0)]) 148 max_depth = 0 149 150 while stack: 151 expression, depth = stack.pop() 152 153 if not isinstance(expression, exp.Connector): 154 continue 155 156 depth += 1 157 max_depth = max(depth, max_depth) 158 159 stack.append((expression.left, depth)) 160 stack.append((expression.right, depth)) 161 162 return max_depth 163 164 165def catch(*exceptions): 166 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 167 168 def decorator(func): 169 def wrapped(expression, *args, **kwargs): 170 try: 171 return func(expression, *args, **kwargs) 172 except exceptions: 173 return expression 174 175 return wrapped 176 177 return decorator 178 179 180def rewrite_between(expression: exp.Expression) -> exp.Expression: 181 """Rewrite x between y and z to x >= y AND x <= z. 182 183 This is done because comparison simplification is only done on lt/lte/gt/gte. 
184 """ 185 if isinstance(expression, exp.Between): 186 negate = isinstance(expression.parent, exp.Not) 187 188 expression = exp.and_( 189 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 190 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 191 copy=False, 192 ) 193 194 if negate: 195 expression = exp.paren(expression, copy=False) 196 197 return expression 198 199 200COMPLEMENT_COMPARISONS = { 201 exp.LT: exp.GTE, 202 exp.GT: exp.LTE, 203 exp.LTE: exp.GT, 204 exp.GTE: exp.LT, 205 exp.EQ: exp.NEQ, 206 exp.NEQ: exp.EQ, 207} 208 209COMPLEMENT_SUBQUERY_PREDICATES = { 210 exp.All: exp.Any, 211 exp.Any: exp.All, 212} 213 214 215def simplify_not(expression): 216 """ 217 Demorgan's Law 218 NOT (x OR y) -> NOT x AND NOT y 219 NOT (x AND y) -> NOT x OR NOT y 220 """ 221 if isinstance(expression, exp.Not): 222 this = expression.this 223 if is_null(this): 224 return exp.null() 225 if this.__class__ in COMPLEMENT_COMPARISONS: 226 right = this.expression 227 complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) 228 if complement_subquery_predicate: 229 right = complement_subquery_predicate(this=right.this) 230 231 return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 232 if isinstance(this, exp.Paren): 233 condition = this.unnest() 234 if isinstance(condition, exp.And): 235 return exp.paren( 236 exp.or_( 237 exp.not_(condition.left, copy=False), 238 exp.not_(condition.right, copy=False), 239 copy=False, 240 ) 241 ) 242 if isinstance(condition, exp.Or): 243 return exp.paren( 244 exp.and_( 245 exp.not_(condition.left, copy=False), 246 exp.not_(condition.right, copy=False), 247 copy=False, 248 ) 249 ) 250 if is_null(condition): 251 return exp.null() 252 if always_true(this): 253 return exp.false() 254 if is_false(this): 255 return exp.true() 256 if isinstance(this, exp.Not): 257 # double negation 258 # NOT NOT x -> x 259 return this.this 260 return expression 261 262 263def 
flatten(expression): 264 """ 265 A AND (B AND C) -> A AND B AND C 266 A OR (B OR C) -> A OR B OR C 267 """ 268 if isinstance(expression, exp.Connector): 269 for node in expression.args.values(): 270 child = node.unnest() 271 if isinstance(child, expression.__class__): 272 node.replace(child) 273 return expression 274 275 276def simplify_connectors(expression, root=True): 277 def _simplify_connectors(expression, left, right): 278 if isinstance(expression, exp.And): 279 if is_false(left) or is_false(right): 280 return exp.false() 281 if is_zero(left) or is_zero(right): 282 return exp.false() 283 if is_null(left) or is_null(right): 284 return exp.null() 285 if always_true(left) and always_true(right): 286 return exp.true() 287 if always_true(left): 288 return right 289 if always_true(right): 290 return left 291 return _simplify_comparison(expression, left, right) 292 elif isinstance(expression, exp.Or): 293 if always_true(left) or always_true(right): 294 return exp.true() 295 if ( 296 (is_null(left) and is_null(right)) 297 or (is_null(left) and always_false(right)) 298 or (always_false(left) and is_null(right)) 299 ): 300 return exp.null() 301 if is_false(left): 302 return right 303 if is_false(right): 304 return left 305 return _simplify_comparison(expression, left, right, or_=True) 306 elif isinstance(expression, exp.Xor): 307 if left == right: 308 return exp.false() 309 310 if isinstance(expression, exp.Connector): 311 return _flat_simplify(expression, _simplify_connectors, root) 312 return expression 313 314 315LT_LTE = (exp.LT, exp.LTE) 316GT_GTE = (exp.GT, exp.GTE) 317 318COMPARISONS = ( 319 *LT_LTE, 320 *GT_GTE, 321 exp.EQ, 322 exp.NEQ, 323 exp.Is, 324) 325 326INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 327 exp.LT: exp.GT, 328 exp.GT: exp.LT, 329 exp.LTE: exp.GTE, 330 exp.GTE: exp.LTE, 331} 332 333NONDETERMINISTIC = (exp.Rand, exp.Randn) 334AND_OR = (exp.And, exp.Or) 335 336 337def _simplify_comparison(expression, left, 
right, or_=False): 338 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 339 ll, lr = left.args.values() 340 rl, rr = right.args.values() 341 342 largs = {ll, lr} 343 rargs = {rl, rr} 344 345 matching = largs & rargs 346 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 347 348 if matching and columns: 349 try: 350 l = first(largs - columns) 351 r = first(rargs - columns) 352 except StopIteration: 353 return expression 354 355 if l.is_number and r.is_number: 356 l = l.to_py() 357 r = r.to_py() 358 elif l.is_string and r.is_string: 359 l = l.name 360 r = r.name 361 else: 362 l = extract_date(l) 363 if not l: 364 return None 365 r = extract_date(r) 366 if not r: 367 return None 368 # python won't compare date and datetime, but many engines will upcast 369 l, r = cast_as_datetime(l), cast_as_datetime(r) 370 371 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 372 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 373 return left if (av > bv if or_ else av <= bv) else right 374 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 375 return left if (av < bv if or_ else av >= bv) else right 376 377 # we can't ever shortcut to true because the column could be null 378 if not or_: 379 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 380 if av <= bv: 381 return exp.false() 382 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 383 if av >= bv: 384 return exp.false() 385 elif isinstance(a, exp.EQ): 386 if isinstance(b, exp.LT): 387 return exp.false() if av >= bv else a 388 if isinstance(b, exp.LTE): 389 return exp.false() if av > bv else a 390 if isinstance(b, exp.GT): 391 return exp.false() if av <= bv else a 392 if isinstance(b, exp.GTE): 393 return exp.false() if av < bv else a 394 if isinstance(b, exp.NEQ): 395 return exp.false() if av == bv else a 396 return None 397 398 399def remove_complements(expression, root=True): 400 """ 401 Removing complements. 
402 403 A AND NOT A -> FALSE 404 A OR NOT A -> TRUE 405 """ 406 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 407 ops = set(expression.flatten()) 408 for op in ops: 409 if isinstance(op, exp.Not) and op.this in ops: 410 return exp.false() if isinstance(expression, exp.And) else exp.true() 411 412 return expression 413 414 415def uniq_sort(expression, root=True): 416 """ 417 Uniq and sort a connector. 418 419 C AND A AND B AND B -> A AND B AND C 420 """ 421 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 422 flattened = tuple(expression.flatten()) 423 424 if isinstance(expression, exp.Xor): 425 result_func = exp.xor 426 # Do not deduplicate XOR as A XOR A != A if A == True 427 deduped = None 428 arr = tuple((gen(e), e) for e in flattened) 429 else: 430 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 431 deduped = {gen(e): e for e in flattened} 432 arr = tuple(deduped.items()) 433 434 # check if the operands are already sorted, if not sort them 435 # A AND C AND B -> A AND B AND C 436 for i, (sql, e) in enumerate(arr[1:]): 437 if sql < arr[i][0]: 438 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 439 break 440 else: 441 # we didn't have to sort but maybe we need to dedup 442 if deduped and len(deduped) < len(flattened): 443 expression = result_func(*deduped.values(), copy=False) 444 445 return expression 446 447 448def absorb_and_eliminate(expression, root=True): 449 """ 450 absorption: 451 A AND (A OR B) -> A 452 A OR (A AND B) -> A 453 A AND (NOT A OR B) -> A AND B 454 A OR (NOT A AND B) -> A OR B 455 elimination: 456 (A AND B) OR (A AND NOT B) -> A 457 (A OR B) AND (A OR NOT B) -> A 458 """ 459 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 460 kind = exp.Or if isinstance(expression, exp.And) else exp.And 461 462 ops = tuple(expression.flatten()) 463 464 # Initialize lookup tables: 465 # Set of all operands, used to find 
complements for absorption. 466 op_set = set() 467 # Sub-operands, used to find subsets for absorption. 468 subops = defaultdict(list) 469 # Pairs of complements, used for elimination. 470 pairs = defaultdict(list) 471 472 # Populate the lookup tables 473 for op in ops: 474 op_set.add(op) 475 476 if not isinstance(op, kind): 477 # In cases like: A OR (A AND B) 478 # Subop will be: ^ 479 subops[op].append({op}) 480 continue 481 482 # In cases like: (A AND B) OR (A AND B AND C) 483 # Subops will be: ^ ^ 484 subset = set(op.flatten()) 485 for i in subset: 486 subops[i].append(subset) 487 488 a, b = op.unnest_operands() 489 if isinstance(a, exp.Not): 490 pairs[frozenset((a.this, b))].append((op, b)) 491 if isinstance(b, exp.Not): 492 pairs[frozenset((a, b.this))].append((op, a)) 493 494 for op in ops: 495 if not isinstance(op, kind): 496 continue 497 498 a, b = op.unnest_operands() 499 500 # Absorb 501 if isinstance(a, exp.Not) and a.this in op_set: 502 a.replace(exp.true() if kind == exp.And else exp.false()) 503 continue 504 if isinstance(b, exp.Not) and b.this in op_set: 505 b.replace(exp.true() if kind == exp.And else exp.false()) 506 continue 507 superset = set(op.flatten()) 508 if any(any(subset < superset for subset in subops[i]) for i in superset): 509 op.replace(exp.false() if kind == exp.And else exp.true()) 510 continue 511 512 # Eliminate 513 for other, complement in pairs[frozenset((a, b))]: 514 op.replace(complement) 515 other.replace(complement) 516 517 return expression 518 519 520def propagate_constants(expression, root=True): 521 """ 522 Propagate constants for conjunctions in DNF: 523 524 SELECT * FROM t WHERE a = b AND b = 5 becomes 525 SELECT * FROM t WHERE a = 5 AND b = 5 526 527 Reference: https://www.sqlite.org/optoverview.html 528 """ 529 530 if ( 531 isinstance(expression, exp.And) 532 and (root or not expression.same_parent) 533 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 534 ): 535 constant_mapping = {} 536 for expr in 
walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 537 if isinstance(expr, exp.EQ): 538 l, r = expr.left, expr.right 539 540 # TODO: create a helper that can be used to detect nested literal expressions such 541 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 542 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 543 constant_mapping[l] = (id(l), r) 544 545 if constant_mapping: 546 for column in find_all_in_scope(expression, exp.Column): 547 parent = column.parent 548 column_id, constant = constant_mapping.get(column) or (None, None) 549 if ( 550 column_id is not None 551 and id(column) != column_id 552 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 553 ): 554 column.replace(constant.copy()) 555 556 return expression 557 558 559INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 560 exp.DateAdd: exp.Sub, 561 exp.DateSub: exp.Add, 562 exp.DatetimeAdd: exp.Sub, 563 exp.DatetimeSub: exp.Add, 564} 565 566INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 567 **INVERSE_DATE_OPS, 568 exp.Add: exp.Sub, 569 exp.Sub: exp.Add, 570} 571 572 573def _is_number(expression: exp.Expression) -> bool: 574 return expression.is_number 575 576 577def _is_interval(expression: exp.Expression) -> bool: 578 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 579 580 581@catch(ModuleNotFoundError, UnsupportedUnit) 582def simplify_equality(expression: exp.Expression) -> exp.Expression: 583 """ 584 Use the subtraction and addition properties of equality to simplify expressions: 585 586 x + 1 = 3 becomes x = 2 587 588 There are two binary operations in the above expression: + and = 589 Here's how we reference all the operands in the code below: 590 591 l r 592 x + 1 = 3 593 a b 594 """ 595 if isinstance(expression, COMPARISONS): 596 l, r = expression.left, expression.right 597 598 if l.__class__ not in INVERSE_OPS: 
599 return expression 600 601 if r.is_number: 602 a_predicate = _is_number 603 b_predicate = _is_number 604 elif _is_date_literal(r): 605 a_predicate = _is_date_literal 606 b_predicate = _is_interval 607 else: 608 return expression 609 610 if l.__class__ in INVERSE_DATE_OPS: 611 l = t.cast(exp.IntervalOp, l) 612 a = l.this 613 b = l.interval() 614 else: 615 l = t.cast(exp.Binary, l) 616 a, b = l.left, l.right 617 618 if not a_predicate(a) and b_predicate(b): 619 pass 620 elif not a_predicate(b) and b_predicate(a): 621 a, b = b, a 622 else: 623 return expression 624 625 return expression.__class__( 626 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 627 ) 628 return expression 629 630 631def simplify_literals(expression, root=True): 632 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 633 return _flat_simplify(expression, _simplify_binary, root) 634 635 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 636 return expression.this.this 637 638 if type(expression) in INVERSE_DATE_OPS: 639 return _simplify_binary(expression, expression.this, expression.interval()) or expression 640 641 return expression 642 643 644NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 645 646 647def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 648 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 649 this = _simplify_integer_cast(expr.this) 650 else: 651 this = expr.this 652 653 if isinstance(expr, exp.Cast) and this.is_int: 654 num = this.to_py() 655 656 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 657 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 658 # engine-dependent 659 if ( 660 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 661 ) or ( 662 UTINYINT_MIN <= num <= UTINYINT_MAX 663 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 664 ): 665 return this 666 667 return expr 668 669 670def _simplify_binary(expression, a, b): 671 if isinstance(expression, COMPARISONS): 672 a = _simplify_integer_cast(a) 673 b = _simplify_integer_cast(b) 674 675 if isinstance(expression, exp.Is): 676 if isinstance(b, exp.Not): 677 c = b.this 678 not_ = True 679 else: 680 c = b 681 not_ = False 682 683 if is_null(c): 684 if isinstance(a, exp.Literal): 685 return exp.true() if not_ else exp.false() 686 if is_null(a): 687 return exp.false() if not_ else exp.true() 688 elif isinstance(expression, NULL_OK): 689 return None 690 elif is_null(a) or is_null(b): 691 return exp.null() 692 693 if a.is_number and b.is_number: 694 num_a = a.to_py() 695 num_b = b.to_py() 696 697 if isinstance(expression, exp.Add): 698 return exp.Literal.number(num_a + num_b) 699 if isinstance(expression, exp.Mul): 700 return exp.Literal.number(num_a * num_b) 701 702 # We only simplify Sub, Div if a and b have the same parent because they're not associative 703 if isinstance(expression, exp.Sub): 704 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 705 if isinstance(expression, exp.Div): 706 # engines have differing int div behavior so intdiv is not safe 707 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 708 return None 709 return exp.Literal.number(num_a / num_b) 710 711 boolean = eval_boolean(expression, num_a, num_b) 712 713 if boolean: 714 return boolean 715 elif a.is_string and b.is_string: 716 boolean = eval_boolean(expression, a.this, b.this) 717 718 if boolean: 719 return boolean 720 elif _is_date_literal(a) and isinstance(b, 
exp.Interval): 721 date, b = extract_date(a), extract_interval(b) 722 if date and b: 723 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 724 return date_literal(date + b, extract_type(a)) 725 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 726 return date_literal(date - b, extract_type(a)) 727 elif isinstance(a, exp.Interval) and _is_date_literal(b): 728 a, date = extract_interval(a), extract_date(b) 729 # you cannot subtract a date from an interval 730 if a and b and isinstance(expression, exp.Add): 731 return date_literal(a + date, extract_type(b)) 732 elif _is_date_literal(a) and _is_date_literal(b): 733 if isinstance(expression, exp.Predicate): 734 a, b = extract_date(a), extract_date(b) 735 boolean = eval_boolean(expression, a, b) 736 if boolean: 737 return boolean 738 739 return None 740 741 742def simplify_parens(expression): 743 if not isinstance(expression, exp.Paren): 744 return expression 745 746 this = expression.this 747 parent = expression.parent 748 parent_is_predicate = isinstance(parent, exp.Predicate) 749 750 if ( 751 not isinstance(this, exp.Select) 752 and not isinstance(parent, exp.SubqueryPredicate) 753 and ( 754 not isinstance(parent, (exp.Condition, exp.Binary)) 755 or isinstance(parent, exp.Paren) 756 or ( 757 not isinstance(this, exp.Binary) 758 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 759 ) 760 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 761 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 762 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 763 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 764 ) 765 ): 766 return this 767 return expression 768 769 770def _is_nonnull_constant(expression: exp.Expression) -> bool: 771 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 772 773 774def _is_constant(expression: exp.Expression) -> bool: 775 return 
isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 776 777 778def simplify_coalesce(expression): 779 # COALESCE(x) -> x 780 if ( 781 isinstance(expression, exp.Coalesce) 782 and (not expression.expressions or _is_nonnull_constant(expression.this)) 783 # COALESCE is also used as a Spark partitioning hint 784 and not isinstance(expression.parent, exp.Hint) 785 ): 786 return expression.this 787 788 if not isinstance(expression, COMPARISONS): 789 return expression 790 791 if isinstance(expression.left, exp.Coalesce): 792 coalesce = expression.left 793 other = expression.right 794 elif isinstance(expression.right, exp.Coalesce): 795 coalesce = expression.right 796 other = expression.left 797 else: 798 return expression 799 800 # This transformation is valid for non-constants, 801 # but it really only does anything if they are both constants. 802 if not _is_constant(other): 803 return expression 804 805 # Find the first constant arg 806 for arg_index, arg in enumerate(coalesce.expressions): 807 if _is_constant(arg): 808 break 809 else: 810 return expression 811 812 coalesce.set("expressions", coalesce.expressions[:arg_index]) 813 814 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 815 # since we already remove COALESCE at the top of this function. 
816 coalesce = coalesce if coalesce.expressions else coalesce.this 817 818 # This expression is more complex than when we started, but it will get simplified further 819 return exp.paren( 820 exp.or_( 821 exp.and_( 822 coalesce.is_(exp.null()).not_(copy=False), 823 expression.copy(), 824 copy=False, 825 ), 826 exp.and_( 827 coalesce.is_(exp.null()), 828 type(expression)(this=arg.copy(), expression=other.copy()), 829 copy=False, 830 ), 831 copy=False, 832 ) 833 ) 834 835 836CONCATS = (exp.Concat, exp.DPipe) 837 838 839def simplify_concat(expression): 840 """Reduces all groups that contain string literals by concatenating them.""" 841 if not isinstance(expression, CONCATS) or ( 842 # We can't reduce a CONCAT_WS call if we don't statically know the separator 843 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 844 ): 845 return expression 846 847 if isinstance(expression, exp.ConcatWs): 848 sep_expr, *expressions = expression.expressions 849 sep = sep_expr.name 850 concat_type = exp.ConcatWs 851 args = {} 852 else: 853 expressions = expression.expressions 854 sep = "" 855 concat_type = exp.Concat 856 args = { 857 "safe": expression.args.get("safe"), 858 "coalesce": expression.args.get("coalesce"), 859 } 860 861 new_args = [] 862 for is_string_group, group in itertools.groupby( 863 expressions or expression.flatten(), lambda e: e.is_string 864 ): 865 if is_string_group: 866 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 867 else: 868 new_args.extend(group) 869 870 if len(new_args) == 1 and new_args[0].is_string: 871 return new_args[0] 872 873 if concat_type is exp.ConcatWs: 874 new_args = [sep_expr] + new_args 875 elif isinstance(expression, exp.DPipe): 876 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 877 878 return concat_type(expressions=new_args, **args) 879 880 881def simplify_conditionals(expression): 882 """Simplifies expressions like IF, CASE if their condition is statically 
known.""" 883 if isinstance(expression, exp.Case): 884 this = expression.this 885 for case in expression.args["ifs"]: 886 cond = case.this 887 if this: 888 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 889 cond = cond.replace(this.pop().eq(cond)) 890 891 if always_true(cond): 892 return case.args["true"] 893 894 if always_false(cond): 895 case.pop() 896 if not expression.args["ifs"]: 897 return expression.args.get("default") or exp.null() 898 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 899 if always_true(expression.this): 900 return expression.args["true"] 901 if always_false(expression.this): 902 return expression.args.get("false") or exp.null() 903 904 return expression 905 906 907def simplify_startswith(expression: exp.Expression) -> exp.Expression: 908 """ 909 Reduces a prefix check to either TRUE or FALSE if both the string and the 910 prefix are statically known. 911 912 Example: 913 >>> from sqlglot import parse_one 914 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 915 'TRUE' 916 """ 917 if ( 918 isinstance(expression, exp.StartsWith) 919 and expression.this.is_string 920 and expression.expression.is_string 921 ): 922 return exp.convert(expression.name.startswith(expression.expression.name)) 923 924 return expression 925 926 927DateRange = t.Tuple[datetime.date, datetime.date] 928 929 930def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 931 """ 932 Get the date range for a DATE_TRUNC equality comparison: 933 934 Example: 935 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 936 Returns: 937 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 938 """ 939 floor = date_floor(date, unit, dialect) 940 941 if date != floor: 942 # This will always be False, except for NULL values. 
943 return None 944 945 return floor, floor + interval(unit) 946 947 948def _datetrunc_eq_expression( 949 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 950) -> exp.Expression: 951 """Get the logical expression for a date range""" 952 return exp.and_( 953 left >= date_literal(drange[0], target_type), 954 left < date_literal(drange[1], target_type), 955 copy=False, 956 ) 957 958 959def _datetrunc_eq( 960 left: exp.Expression, 961 date: datetime.date, 962 unit: str, 963 dialect: Dialect, 964 target_type: t.Optional[exp.DataType], 965) -> t.Optional[exp.Expression]: 966 drange = _datetrunc_range(date, unit, dialect) 967 if not drange: 968 return None 969 970 return _datetrunc_eq_expression(left, drange, target_type) 971 972 973def _datetrunc_neq( 974 left: exp.Expression, 975 date: datetime.date, 976 unit: str, 977 dialect: Dialect, 978 target_type: t.Optional[exp.DataType], 979) -> t.Optional[exp.Expression]: 980 drange = _datetrunc_range(date, unit, dialect) 981 if not drange: 982 return None 983 984 return exp.and_( 985 left < date_literal(drange[0], target_type), 986 left >= date_literal(drange[1], target_type), 987 copy=False, 988 ) 989 990 991DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 992 exp.LT: lambda l, dt, u, d, t: l 993 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 994 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 995 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 996 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 997 exp.EQ: _datetrunc_eq, 998 exp.NEQ: _datetrunc_neq, 999} 1000DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 1001DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 1002 1003 1004def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 1005 return isinstance(left, 
DATETRUNCS) and _is_date_literal(right) 1006 1007 1008@catch(ModuleNotFoundError, UnsupportedUnit) 1009def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1010 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1011 comparison = expression.__class__ 1012 1013 if isinstance(expression, DATETRUNCS): 1014 this = expression.this 1015 trunc_type = extract_type(this) 1016 date = extract_date(this) 1017 if date and expression.unit: 1018 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1019 elif comparison not in DATETRUNC_COMPARISONS: 1020 return expression 1021 1022 if isinstance(expression, exp.Binary): 1023 l, r = expression.left, expression.right 1024 1025 if not _is_datetrunc_predicate(l, r): 1026 return expression 1027 1028 l = t.cast(exp.DateTrunc, l) 1029 trunc_arg = l.this 1030 unit = l.unit.name.lower() 1031 date = extract_date(r) 1032 1033 if not date: 1034 return expression 1035 1036 return ( 1037 DATETRUNC_BINARY_COMPARISONS[comparison]( 1038 trunc_arg, date, unit, dialect, extract_type(r) 1039 ) 1040 or expression 1041 ) 1042 1043 if isinstance(expression, exp.In): 1044 l = expression.this 1045 rs = expression.expressions 1046 1047 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1048 l = t.cast(exp.DateTrunc, l) 1049 unit = l.unit.name.lower() 1050 1051 ranges = [] 1052 for r in rs: 1053 date = extract_date(r) 1054 if not date: 1055 return expression 1056 drange = _datetrunc_range(date, unit, dialect) 1057 if drange: 1058 ranges.append(drange) 1059 1060 if not ranges: 1061 return expression 1062 1063 ranges = merge_ranges(ranges) 1064 target_type = extract_type(*rs) 1065 1066 return exp.or_( 1067 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1068 ) 1069 1070 return expression 1071 1072 1073def sort_comparison(expression: exp.Expression) -> exp.Expression: 1074 if expression.__class__ in 
COMPLEMENT_COMPARISONS: 1075 l, r = expression.this, expression.expression 1076 l_column = isinstance(l, exp.Column) 1077 r_column = isinstance(r, exp.Column) 1078 l_const = _is_constant(l) 1079 r_const = _is_constant(r) 1080 1081 if ( 1082 (l_column and not r_column) 1083 or (r_const and not l_const) 1084 or isinstance(r, exp.SubqueryPredicate) 1085 ): 1086 return expression 1087 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1088 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1089 this=r, expression=l 1090 ) 1091 return expression 1092 1093 1094# CROSS joins result in an empty table if the right table is empty. 1095# So we can only simplify certain types of joins to CROSS. 1096# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1097JOINS = { 1098 ("", ""), 1099 ("", "INNER"), 1100 ("RIGHT", ""), 1101 ("RIGHT", "OUTER"), 1102} 1103 1104 1105def remove_where_true(expression): 1106 for where in expression.find_all(exp.Where): 1107 if always_true(where.this): 1108 where.pop() 1109 for join in expression.find_all(exp.Join): 1110 if ( 1111 always_true(join.args.get("on")) 1112 and not join.args.get("using") 1113 and not join.args.get("method") 1114 and (join.side, join.kind) in JOINS 1115 ): 1116 join.args["on"].pop() 1117 join.set("side", None) 1118 join.set("kind", "CROSS") 1119 1120 1121def always_true(expression): 1122 return (isinstance(expression, exp.Boolean) and expression.this) or ( 1123 isinstance(expression, exp.Literal) and not is_zero(expression) 1124 ) 1125 1126 1127def always_false(expression): 1128 return is_false(expression) or is_null(expression) or is_zero(expression) 1129 1130 1131def is_zero(expression): 1132 return isinstance(expression, exp.Literal) and expression.to_py() == 0 1133 1134 1135def is_complement(a, b): 1136 return isinstance(b, exp.Not) and b.this == a 1137 1138 1139def is_false(a: exp.Expression) -> bool: 1140 return type(a) is exp.Boolean and not a.this 1141 1142 
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is an `exp.Null` node (exact type match, subclasses excluded)."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """
    Evaluate the comparison node `expression` on the Python values `a` and `b`.

    Returns:
        TRUE/FALSE (via `boolean_literal`) for EQ/IS, NEQ, GT, GTE, LT and LTE nodes;
        None for any other node type.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce `value` to a `datetime.date`.

    datetimes are truncated to their date part, dates are returned unchanged and anything
    else is parsed with `datetime.datetime.fromisoformat`.

    Returns:
        The date, or None if `value` cannot be parsed (not an ISO-format string, or not a
        string at all).
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: `value` is not a string (e.g. None or a number)
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce `value` to a `datetime.datetime`.

    datetimes are returned unchanged, dates become midnight of that day and anything else
    is parsed with `datetime.datetime.fromisoformat`.

    Returns:
        The datetime, or None if `value` cannot be parsed (not an ISO-format string, or not
        a string at all).
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: `value` is not a string (e.g. None or a number)
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` according to the target type `to`.

    A DATE target yields a `datetime.date`; any other temporal target yields a
    `datetime.datetime`.

    Returns:
        The converted value, or None if `value` is falsy, `to` is not a temporal type, or
        the value cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a temporal cast of a literal.

    Handles `exp.Cast` nodes and `exp.TsOrDsToDate` nodes without a `format` argument (the
    latter is treated as a cast to DATE). The operand may be a literal or another such
    cast, which is evaluated recursively.

    Returns:
        The Python date/datetime value, or None if `cast` cannot be evaluated statically.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if `extract_date` can statically evaluate `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """
    Convert an interval expression into a `dateutil` `relativedelta`.

    The amount is read from `expression.this` (converted with `int`) and the unit from the
    `unit` arg.

    Returns:
        The `relativedelta`, or None if the unit is unsupported, the amount cannot be
        converted with `int`, or `dateutil` is not installed.
    """
    try:
        n = int(expression.this.to_py())
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def extract_type(*expressions):
    """
    Return the first truthy data type found among `expressions`.

    For a Cast the target type (`.to`) is used, otherwise the expression's `.type`. If no
    expression has a type, the last inspected (falsy) value is returned, or None when
    `expressions` is empty.
    """
    target_type = None
    for expression in expressions:
        target_type = expression.to if isinstance(expression, exp.Cast) else expression.type
        if target_type:
            break

    return target_type


def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    If `target_type` is missing or not temporal, DATETIME is used for datetimes and DATE
    otherwise.
    """
    if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES):
        target_type = (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        )

    return exp.cast(exp.Literal.string(date), target_type)


def interval(unit: str, n: int = 1):
    """
    Return a `dateutil.relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.

    Raises:
        ModuleNotFoundError: if `dateutil` is not installed (it is imported lazily here).
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    if unit == "hour":
        return relativedelta(hours=n)
    if unit == "minute":
        return relativedelta(minutes=n)
    if unit == "second":
        return relativedelta(seconds=n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Truncate `d` down to the start of its `unit` period.

    Supported units: year, quarter, month, week, day. For weeks, the first day of the week
    is Monday shifted by `dialect.WEEK_OFFSET` (0 = Monday, -1 = Sunday, 1 = Tuesday, ...).

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=(d.month - 1) // 3 * 3 + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday. Reducing modulo 7 keeps the result within the 0-6 days
        # before `d`: without it, a day that is itself the week start (e.g. a Sunday with
        # WEEK_OFFSET = -1) would be moved back a whole week, and a positive offset could
        # move `d` forward.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary.

    A `d` that already lies on a boundary is returned unchanged.
    """
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the SQL TRUE literal if `condition` is truthy, else FALSE."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Apply a pairwise `simplifier` across the flattened operands of `expression`.

    Each operand `a` is tried against every remaining operand `b`. When
    `simplifier(expression, a, b)` returns a truthy node other than `expression`, both
    operands are consumed and the result is pushed to the front of the queue, so it can be
    combined again. If fewer operands remain than we started with, the chain is rebuilt
    left-deep using `expression`'s class; otherwise `expression` is returned as-is.

    When `root` is False, the node is skipped if `expression.same_parent` is set
    (presumably because an enclosing node of the same chain handles it — confirm).
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # Safe to mutate the deque here: we break out of the loop immediately
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    return Gen().gen(expression)


class Gen:
    """
    Minimal stack-based pseudo-SQL generator used by `gen`.

    `gen` repeatedly pops items off `self.stack`:

    - Expression nodes are dispatched to a `<node.key>_sql` handler if one exists.
      Otherwise, `exp.Func` nodes are rendered as `NAME(args)` and everything else as the
      upper-cased key followed by its `:arg value` pairs.
    - Lists are expanded into their non-None items separated by ",".
    - Any other non-None value is appended to the output with `str`.

    Because the stack is LIFO, handlers push their pieces in reverse order of appearance.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` to its pseudo-SQL string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                # Only non-None items are rendered. Whether to drop the trailing separator
                # depends on what was actually pushed: a list made up solely of None values
                # (e.g. a function whose args are all unset) pushes nothing, and popping
                # then would discard an unrelated token such as a closing ")".
                items = [n for n in node if n is not None]
                for n in reversed(items):
                    self.stack.extend((n, ","))
                if items:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render `NAME(args)`; unquoted names are upper-cased, quoted ones kept verbatim."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Push parts separated by "." and drop the extra leading separator
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case keyword, consistent with the other operators (e.g. " ILIKE ")
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Render `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Render `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Render `NAME(<all arg values>)` using the function's SQL name."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push `[":name", value]` pairs for the set args of `node`.

        The first `arg_index` entries of `node.arg_types` are skipped (they are rendered by
        the calling handler). Returns True if at least one pair was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
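Raised when a date or interval unit is not supported by the simplifier (for example by `interval` or `date_floor`). Inherits from `Exception`, the common base class for all non-exit exceptions.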
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: the SQL dialect; resolved with `Dialect.get_or_raise` and passed to
            `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Depth guard, evaluated only at the top of a connector chain (the parent is not
        # itself a Connector): overly deep chains are returned untouched.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes marked FINAL (e.g. projections that reference GROUP BY keys) are not rewritten
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains a GROUP BY key anywhere in its subtree
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING, which may also reference the grouped expressions
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children (root=False for all of them)
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # NOTE(review): presumably re-links the (possibly new) node to the original parent
        # so the following rules, which inspect `.parent` (e.g. simplify_parens), see the
        # right context — confirm.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # NOTE(review): `while_changing` presumably re-runs `_simplify` until the tree stops
    # changing — confirm against sqlglot.helper.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
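- dialect: the SQL dialect used when simplifying dialect-sensitive expressions such as DATE_TRUNC
- max_depth: Chains of Connectors (AND, OR, etc.) exceeding `max_depth` will be skipped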
Returns:
sqlglot.Expression: simplified expression
def connector_depth(expression: exp.Expression) -> int:
    """
    Return the deepest nesting level of Connector nodes (AND, OR, ...) in a tree.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    pending = deque([(expression, 0)])

    while pending:
        node, level = pending.pop()
        if not isinstance(node, exp.Connector):
            # Non-connector operands do not add depth and are not explored further
            continue

        level += 1
        if level > deepest:
            deepest = level
        pending.extend(((node.left, level), (node.right, level)))

    return deepest
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one
>>> connector_depth(parse_one("a AND b AND c AND d"))
3
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function runs normally. If it raises one of `exceptions`, its first
    argument (the expression being simplified) is returned unchanged instead.

    Args:
        exceptions: the exception types to swallow.

    Returns:
        A decorator that wraps a simplification function accordingly.
    """

    def decorator(func):
        # functools.wraps keeps the simplifier's __name__, __doc__ and __wrapped__, so
        # introspection and documentation tools show the real function instead of this
        # generic wrapper.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised; in that case the input expression is returned unchanged.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, hence the rewrite. When the
    BETWEEN sits directly under a NOT, the resulting conjunction is wrapped in parentheses
    so the negation still covers both bounds.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    target = expression.this
    lower_bound = exp.GTE(this=target.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=target.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    if under_not:
        return exp.paren(conjunction, copy=False)
    return conjunction
Rewrite `x BETWEEN y AND z` to `x >= y AND x <= z`.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Push a NOT inward or fold it away when its operand allows it.

    De Morgan's laws (for a parenthesized AND/OR):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also: NOT NULL -> NULL, NOT of a complementable comparison -> the complementary
    comparison, NOT <always true> -> FALSE, NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    inner = expression.this

    if is_null(inner):
        return exp.null()

    if inner.__class__ in COMPLEMENT_COMPARISONS:
        rhs = inner.expression
        subquery_complement = COMPLEMENT_SUBQUERY_PREDICATES.get(rhs.__class__)
        if subquery_complement:
            rhs = subquery_complement(this=rhs.this)
        return COMPLEMENT_COMPARISONS[inner.__class__](this=inner.this, expression=rhs)

    if isinstance(inner, exp.Paren):
        condition = inner.unnest()
        if isinstance(condition, (exp.And, exp.Or)):
            # Swap the connector and negate both sides
            combine = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if is_null(condition):
            return exp.null()

    if always_true(inner):
        return exp.false()
    if is_false(inner):
        return exp.true()
    if isinstance(inner, exp.Not):
        # Double negation: NOT NOT x -> x
        return inner.this
    return expression
De Morgan's laws:

    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Merge parenthesized connectors of the same kind into their parent:

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold AND / OR / XOR connectors whose operands are constants.

    Every pair of operands of the flattened connector chain is passed to
    `_simplify_connectors` through `_flat_simplify`; a truthy result replaces the pair.
    Pairs that are not constant-foldable fall through to `_simplify_comparison`.
    """

    def _simplify_connectors(expression, left, right):
        if isinstance(expression, exp.And):
            # FALSE (or a zero literal) on either side wins, even over NULL
            if is_false(left) or is_false(right):
                return exp.false()
            if is_zero(left) or is_zero(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            # Always-true operands are the identity of AND
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # An always-true operand wins, even over NULL
            if always_true(left) or always_true(right):
                return exp.true()
            # NULL combined with NULL or with an always-false value stays NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and always_false(right))
                or (always_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE is the identity of OR
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        elif isinstance(expression, exp.Xor):
            # x XOR x -> FALSE; other XOR pairs return None (left unchanged)
            if left == right:
                return exp.false()

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse an AND/OR chain that contains both an operand and its negation.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR) or (not root and expression.same_parent):
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )
    if not has_complement:
        return expression
    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by their `gen` string and, for AND/OR, deduplicated by it. XOR
    chains are sorted but never deduplicated, because A XOR A is not equivalent to A.

    Args:
        expression: the node to normalize; non-connectors are returned unchanged.
        root: when False, the node is skipped if `expression.same_parent` is set
            (presumably an enclosing connector of the same chain handles it — confirm).

    Returns:
        The original connector if it was already sorted and duplicate-free, otherwise a
        newly built one.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, _) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                # Sort on the gen string only. XOR operands can share a key, and a plain
                # tuple sort would then fall back to ordering the Expression objects
                # themselves, which have no meaningful ordering. Sorting by key keeps
                # tied operands in their original (stable) order.
                ordered = sorted(arr, key=lambda pair: pair[0])
                expression = result_func(*(e for _, e in ordered), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND/OR chain, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced with TRUE/FALSE or with the surviving sub-operand; the
    resulting constants are left for the other rules to fold away.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # `kind` is the connector of the nested operands we look at: OR operands inside an
        # AND chain, AND operands inside an OR chain.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like A OR (A AND B), the bare operand A is recorded as the
                # one-element subset {A}.
                subops[op].append({op})
                continue

            # In cases like (A AND B) OR (A AND B AND C), each nested operand's full set of
            # terms is recorded under every term it contains.
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # For two-term operands containing a negation, index them by the pair of
            # positive terms, remembering the term that survives elimination.
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A AND (NOT A OR B): NOT A is replaced with FALSE (the OR identity);
            # A OR (NOT A AND B): NOT A is replaced with TRUE (the AND identity).
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # If another operand's terms are a strict subset of this operand's terms, this
            # operand is redundant: replace it with the identity of the outer connector.
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # (A AND B) paired with (A AND NOT B): both collapse to the shared term A.
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
absorption:

    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B

elimination:

    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    # Only AND chains that are in DNF (per `normalized(..., dnf=True)`) are handled
    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the defining column node, literal) for every
        # `column = literal` equality found; a later equality for the same column
        # overwrites an earlier one. IF nodes are passed as `prune`, presumably so
        # equalities inside an IF are not collected — confirm against walk_in_scope.
        constant_mapping = {}
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Replace every other occurrence of the column with (a copy of) its
                # literal, except the defining occurrence itself (same id) and operands
                # of `... IS NULL`.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes
SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Inner wrapper created by `catch(*exceptions)`; `func` and `exceptions` are closure
    # variables of that decorator.
    # NOTE(review): this listing appears in place of the decorated simplifier's own
    # source, presumably because `catch` does not apply `functools.wraps` — confirm.
    # Run the simplification; if one of the caught exceptions is raised, return the
    # input expression unchanged.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: `+` and `=`.
Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """
    Simplify literal-level constructs.

    - Non-connector binary nodes are folded pairwise with `_simplify_binary`.
    - A double negation `-(-x)` becomes `x`.
    - Node types listed in INVERSE_DATE_OPS are passed to `_simplify_binary` with their
      interval; the original node is kept if that returns nothing.
    """
    is_non_connector_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_non_connector_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    negated = expression.this if isinstance(expression, exp.Neg) else None
    if isinstance(negated, exp.Neg):
        return negated.this

    if type(expression) not in INVERSE_DATE_OPS:
        return expression

    folded = _simplify_binary(expression, expression.this, expression.interval())
    return folded or expression
def simplify_parens(expression):
    """
    Drop a Paren node and return its content when the parentheses are not needed.

    The parentheses are always kept around a SELECT and directly under a subquery
    predicate. Otherwise they are removed if any of the following holds:
    the parent is neither a Condition nor a Binary node (or is itself a Paren); the
    content is not a Binary node (except a NOT/IS inside a Predicate parent); the content
    is a Predicate and the parent is not; or the nesting is Add-in-Add, Mul-in-Mul or
    Mul-in-Add/Sub. Non-Paren input is returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            not isinstance(parent, (exp.Condition, exp.Binary))
            or isinstance(parent, exp.Paren)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # Same-precedence / higher-precedence arithmetic nesting
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    - `COALESCE(x)` (no further args), or a COALESCE whose first arg is a non-null constant
      per `_is_nonnull_constant`, becomes its first arg (unless it is used in a Hint).
    - In a comparison between a COALESCE and a constant `k`, the COALESCE args from the
      first constant arg `c` onward are dropped, and the comparison is rewritten as
      `(x IS NOT NULL AND <comparison>) OR (x IS NULL AND c <op> k)`, where `x` is the
      truncated COALESCE (or its first arg if no other args remain).
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Truncate the COALESCE in place; `expression.copy()` below therefore sees the
    # shortened call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals are merged into one literal (joined with the
    CONCAT_WS separator when applicable). If a single string literal remains, it is
    returned directly. DPipe (`||`) chains are rebuilt as DPipe; CONCAT keeps its
    `safe`/`coalesce` args.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # When there are no `expressions` (presumably the binary DPipe form), the operands
    # come from `expression.flatten()` instead. groupby splits them into runs of string
    # literals / non-literals while preserving order.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        # Put the separator back in front
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    - CASE: for `CASE x WHEN v ...`, each condition is first rewritten to `x = v`. The
      first WHEN whose condition is always true replaces the whole CASE. WHENs whose
      condition is always false are removed, and if none remain, the ELSE value (or NULL)
      is returned.
    - IF (not directly under a CASE): replaced by its TRUE branch, or by its FALSE branch
      (or NULL), when the condition is statically known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: `case.pop()` removes the branch from the live "ifs"
        # list, which would otherwise make the loop skip the branch that follows it.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH check to TRUE or FALSE when both the string and the prefix are
    string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    haystack, prefix = expression.this, expression.expression
    if not (haystack.is_string and prefix.is_string):
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
        def wrapped(expression, *args, **kwargs):
            # Runs the wrapped simplification rule `func`. `func` and `exceptions`
            # come from the enclosing scope (not visible here); `exceptions` is
            # presumably a tuple of exception types — confirm against the enclosing
            # decorator. If the rule raises one of them, it is treated as a no-op and
            # the input expression is returned unchanged.
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Only comparison classes listed in COMPLEMENT_COMPARISONS are considered. The
    comparison is left alone when a column is already on the left (against a
    non-column), a constant is already on the right (against a non-constant), or
    the right side is a subquery predicate. Otherwise the operands are swapped,
    with the operator mirrored via INVERSE_COMPARISONS, when a column is on the
    right, a constant is on the left, or the left operand's `gen` string sorts
    after the right's.
    """
    if expression.__class__ not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (
        (left_is_column and not right_is_column)
        or (right_is_const and not left_is_const)
        or isinstance(right, exp.SubqueryPredicate)
    )
    if already_ordered:
        return expression

    should_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or gen(left) > gen(right)
    )
    if not should_swap:
        return expression

    mirrored = INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)
    return mirrored(this=right, expression=left)
def remove_where_true(expression):
    """Drop always-true filters from `expression`, mutating it in place.

    WHERE clauses whose condition is always true are removed. A join whose ON
    condition is always true is rewritten into a CROSS join (ON removed, side
    cleared), but only if it has no USING and no join method and its
    (side, kind) pair is listed in JOINS.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        on_condition = join.args.get("on")
        if not always_true(on_condition):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        on_condition.pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison represented by `expression` on Python values.

    EQ and IS both mean equality; NEQ, GT, GTE, LT and LTE map to the matching
    Python operator. The outcome is wrapped with `boolean_literal`. Returns None
    when `expression` is not one of these comparison types.
    """
    # Checked in order with isinstance, so subclasses match like the node they extend
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Coerce `value` to a ``datetime.date``.

    datetimes are truncated to their date part and dates are returned unchanged.
    Anything else is parsed with ``datetime.fromisoformat``. Returns None when the
    value cannot be parsed, including non-string inputs: ``fromisoformat`` raises
    TypeError rather than ValueError for those.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce `value` to a ``datetime.datetime``.

    datetimes are returned unchanged, and dates become midnight of that day.
    Anything else is parsed with ``datetime.fromisoformat``. Returns None when the
    value cannot be parsed, including non-string inputs: ``fromisoformat`` raises
    TypeError rather than ValueError for those.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` into the Python temporal value matching the target type `to`.

    A DATE target goes through `cast_as_date`; any other temporal target goes
    through `cast_as_datetime`. Falsy values and non-temporal targets yield None.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a date/datetime cast of a literal.

    Accepts ``CAST(<literal> AS <type>)``, and ``TsOrDsToDate`` without a format,
    which is treated as a cast to DATE. The operand may itself be another such
    cast; it is evaluated recursively. Returns None for any other shape, or when
    `cast_value` cannot convert the value.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """Build ``CAST('<date>' AS <type>)`` from a Python date/datetime.

    `target_type` is used when it is one of the temporal data types. Otherwise
    the cast targets DATETIME for datetime values and DATE for anything else.
    """
    keep_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not keep_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
    """Return a ``relativedelta`` spanning `n` of the given `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute and
    second. The ``dateutil`` import happens before the unit is inspected, so a
    missing dependency surfaces as an ImportError even for unsupported units.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier applied to n), checked in order
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, factor in unit_table:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the given `unit`.

    Args:
        d: the date (or datetime) to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides ``WEEK_OFFSET``, the first day of the week expressed as an
            offset from Monday (0 = Monday, -1 = Sunday); only used for "week".

    Returns:
        The floored value, which is never later than `d`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # Days elapsed since the most recent week start. Python's weekday() is
        # Monday=0 .. Sunday=6. The modulo keeps the result in [0, 6]; without it a
        # non-zero WEEK_OFFSET yields 7 (a Sunday with WEEK_OFFSET=-1 would jump
        # back a whole week) or a negative count (moving the date forward past `d`).
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string for `expression`.

    The result serves as a sort/dedup key during optimization. Running the real
    SQL generator for that would be too expensive, so the lightweight
    stack-based `Gen` renderer is used instead.
    """
    renderer = Gen()
    return renderer.gen(expression)
Simple pseudo-SQL generator for quickly producing sortable, unique strings.
Sorting and deduplicating SQL is a necessary step for optimization. Calling the actual generator is expensive, so we use a bare-minimum SQL generator here.
class Gen:
    """Minimal, stack-based pseudo-SQL generator used by `gen`.

    Rendering never recurses; `gen` keeps an explicit stack of pending items.
    Handlers push a node's fragments in reverse order (last fragment first), so
    popping the stack emits them left to right. Popped items are handled by type:
    - expressions go to a `<key>_sql` handler, or to the generic function/fallback renderer;
    - lists are expanded into comma-separated items;
    - any other non-None value is emitted with `str()`.
    """

    def __init__(self):
        self.stack = []  # pending items, processed LIFO
        self.sqls = []  # emitted fragments, joined at the end of `gen`

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` into a pseudo-SQL string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: upper-cased node key followed by its non-None args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would precede the first item. Only do this
                # if something was pushed: a non-empty list of all-None values (e.g.
                # the args of a function whose args are all None) would otherwise pop
                # an unrelated item, such as the ")" pushed by `_function`.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." that would precede the first part; without this guard an
            # empty `parts` would pop an unrelated stack item
            self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        parts = e.parts
        for p in reversed(parts):
            self.stack.extend((p, "."))
        if parts:
            # Drop the "." that would precede the first part; without this guard an
            # empty `parts` would pop the alias (or the pushed args) instead
            self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed right-to-left so it pops as: this <op> expression
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        # Renders NAME(<all arg values>); None values are skipped by the list handler
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push `node`'s non-None args as a list of [":key", value] pairs.

        The first `arg_index` entries of `node.arg_types` are skipped. Returns
        whether anything was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
1346 def gen(self, expression: exp.Expression) -> str: 1347 self.stack = [expression] 1348 self.sqls.clear() 1349 1350 while self.stack: 1351 node = self.stack.pop() 1352 1353 if isinstance(node, exp.Expression): 1354 exp_handler_name = f"{node.key}_sql" 1355 1356 if hasattr(self, exp_handler_name): 1357 getattr(self, exp_handler_name)(node) 1358 elif isinstance(node, exp.Func): 1359 self._function(node) 1360 else: 1361 key = node.key.upper() 1362 self.stack.append(f"{key} " if self._args(node) else key) 1363 elif type(node) is list: 1364 for n in reversed(node): 1365 if n is not None: 1366 self.stack.extend((n, ",")) 1367 if node: 1368 self.stack.pop() 1369 else: 1370 if node is not None: 1371 self.sqls.append(str(node)) 1372 1373 return "".join(self.sqls)
1390 def anonymous_sql(self, e: exp.Anonymous) -> None: 1391 this = e.this 1392 if isinstance(this, str): 1393 name = this.upper() 1394 elif isinstance(this, exp.Identifier): 1395 name = this.this 1396 name = f'"{name}"' if this.quoted else name.upper() 1397 else: 1398 raise ValueError( 1399 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1400 ) 1401 1402 self.stack.extend( 1403 ( 1404 ")", 1405 e.expressions, 1406 "(", 1407 name, 1408 ) 1409 )