sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 dialect: DialectType = None, 43): 44 """ 45 Rewrite sqlglot AST to simplify expressions. 46 47 Example: 48 >>> import sqlglot 49 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 50 >>> simplify(expression).sql() 51 'TRUE' 52 53 Args: 54 expression: expression to simplify 55 constant_propagation: whether the constant propagation rule should be used 56 Returns: 57 sqlglot.Expression: simplified expression 58 """ 59 60 dialect = Dialect.get_or_raise(dialect) 61 62 def _simplify(expression): 63 pre_transformation_stack = [expression] 64 post_transformation_stack = [] 65 66 while pre_transformation_stack: 67 node = pre_transformation_stack.pop() 68 69 if node.meta.get(FINAL): 70 continue 71 72 # group by expressions cannot be simplified, for example 73 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 74 # the projection must exactly match the group by key 75 group = node.args.get("group") 76 77 if group and hasattr(node, "selects"): 78 groups = set(group.expressions) 79 group.meta[FINAL] = True 80 81 for s in node.selects: 82 for n in s.walk(): 83 if n in groups: 84 s.meta[FINAL] = True 85 break 86 87 having = node.args.get("having") 88 if having: 89 for n in having.walk(): 90 if n in groups: 91 having.meta[FINAL] = True 92 break 93 94 parent = node.parent 95 root = node is expression 96 97 new_node = rewrite_between(node) 98 new_node = uniq_sort(new_node, root) 99 new_node = absorb_and_eliminate(new_node, root) 100 new_node = simplify_concat(new_node) 101 new_node = simplify_conditionals(new_node) 102 103 if constant_propagation: 104 new_node = propagate_constants(new_node, root) 105 106 if new_node is not node: 107 node.replace(new_node) 108 109 pre_transformation_stack.extend( 110 n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 111 ) 112 post_transformation_stack.append((new_node, parent)) 113 114 while post_transformation_stack: 115 node, parent = post_transformation_stack.pop() 116 root = node is expression 117 118 # Resets parent, arg_key, index pointers– this is needed because some of the 119 # previous transformations mutate the AST, leading to an inconsistent state 120 for k, v in tuple(node.args.items()): 121 node.set(k, v) 122 123 # Post-order transformations 124 new_node = simplify_not(node) 125 new_node = flatten(new_node) 126 new_node = simplify_connectors(new_node, root) 127 new_node = remove_complements(new_node, root) 128 new_node = simplify_coalesce(new_node, dialect) 129 130 new_node.parent = parent 131 132 new_node = simplify_literals(new_node, root) 133 new_node = simplify_equality(new_node) 134 new_node = simplify_parens(new_node) 135 new_node = simplify_datetrunc(new_node, dialect) 136 new_node = sort_comparison(new_node) 137 new_node = simplify_startswith(new_node) 138 139 if new_node is not node: 140 node.replace(new_node) 141 142 return new_node 143 144 expression = while_changing(expression, _simplify) 145 remove_where_true(expression) 146 return expression 147 148 149def catch(*exceptions): 150 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 151 152 def decorator(func): 153 def wrapped(expression, *args, **kwargs): 154 try: 155 return func(expression, *args, **kwargs) 156 except exceptions: 157 return expression 158 159 return wrapped 160 161 return decorator 162 163 164def rewrite_between(expression: exp.Expression) -> exp.Expression: 165 """Rewrite x between y and z to x >= y AND x <= z. 166 167 This is done because comparison simplification is only done on lt/lte/gt/gte. 168 """ 169 if isinstance(expression, exp.Between): 170 negate = isinstance(expression.parent, exp.Not) 171 172 expression = exp.and_( 173 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 174 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 175 copy=False, 176 ) 177 178 if negate: 179 expression = exp.paren(expression, copy=False) 180 181 return expression 182 183 184COMPLEMENT_COMPARISONS = { 185 exp.LT: exp.GTE, 186 exp.GT: exp.LTE, 187 exp.LTE: exp.GT, 188 exp.GTE: exp.LT, 189 exp.EQ: exp.NEQ, 190 exp.NEQ: exp.EQ, 191} 192 193COMPLEMENT_SUBQUERY_PREDICATES = { 194 exp.All: exp.Any, 195 exp.Any: exp.All, 196} 197 198 199def simplify_not(expression): 200 """ 201 Demorgan's Law 202 NOT (x OR y) -> NOT x AND NOT y 203 NOT (x AND y) -> NOT x OR NOT y 204 """ 205 if isinstance(expression, exp.Not): 206 this = expression.this 207 if is_null(this): 208 return exp.null() 209 if this.__class__ in COMPLEMENT_COMPARISONS: 210 right = this.expression 211 complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) 212 if complement_subquery_predicate: 213 right = complement_subquery_predicate(this=right.this) 214 215 return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 216 if isinstance(this, exp.Paren): 217 condition = this.unnest() 218 if isinstance(condition, exp.And): 219 return exp.paren( 220 exp.or_( 221 exp.not_(condition.left, copy=False), 222 exp.not_(condition.right, copy=False), 223 copy=False, 224 ) 225 ) 226 if isinstance(condition, exp.Or): 227 return exp.paren( 228 exp.and_( 229 exp.not_(condition.left, copy=False), 230 exp.not_(condition.right, copy=False), 231 copy=False, 232 ) 233 ) 234 if is_null(condition): 235 return exp.null() 236 if always_true(this): 237 return exp.false() 238 if is_false(this): 239 return exp.true() 240 if isinstance(this, exp.Not): 241 # double negation 242 # NOT NOT x -> x 243 return this.this 244 return expression 245 246 247def flatten(expression): 248 """ 249 A AND (B AND C) -> A AND B AND C 250 A OR (B OR C) -> A OR B OR C 251 """ 252 if isinstance(expression, exp.Connector): 253 for node in expression.args.values(): 254 child = node.unnest() 255 if isinstance(child, expression.__class__): 256 node.replace(child) 257 return expression 258 259 260def simplify_connectors(expression, root=True): 261 def _simplify_connectors(expression, left, right): 262 if isinstance(expression, exp.And): 263 if is_false(left) or is_false(right): 264 return exp.false() 265 if is_zero(left) or is_zero(right): 266 return exp.false() 267 if is_null(left) or is_null(right): 268 return exp.null() 269 if always_true(left) and always_true(right): 270 return exp.true() 271 if always_true(left): 272 return right 273 if always_true(right): 274 return left 275 return _simplify_comparison(expression, left, right) 276 elif isinstance(expression, exp.Or): 277 if always_true(left) or always_true(right): 278 return exp.true() 279 if ( 280 (is_null(left) and is_null(right)) 281 or (is_null(left) and always_false(right)) 282 or (always_false(left) and is_null(right)) 283 ): 284 return exp.null() 285 if is_false(left): 286 return right 287 if is_false(right): 288 return left 289 return _simplify_comparison(expression, left, right, or_=True) 290 elif isinstance(expression, exp.Xor): 291 if left == right: 292 return exp.false() 293 294 if isinstance(expression, exp.Connector): 295 return _flat_simplify(expression, _simplify_connectors, root) 296 return expression 297 298 299LT_LTE = (exp.LT, exp.LTE) 300GT_GTE = (exp.GT, exp.GTE) 301 302COMPARISONS = ( 303 *LT_LTE, 304 *GT_GTE, 305 exp.EQ, 306 exp.NEQ, 307 exp.Is, 308) 309 310INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 311 exp.LT: exp.GT, 312 exp.GT: exp.LT, 313 exp.LTE: exp.GTE, 314 exp.GTE: exp.LTE, 315} 316 317NONDETERMINISTIC = (exp.Rand, exp.Randn) 318AND_OR = (exp.And, exp.Or) 319 320 321def _simplify_comparison(expression, left, right, or_=False): 322 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 323 ll, lr = left.args.values() 324 rl, rr = right.args.values() 325 326 largs = {ll, lr} 327 rargs = {rl, rr} 328 329 matching = largs & rargs 330 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 331 332 if matching and columns: 333 try: 334 l = first(largs - columns) 335 r = first(rargs - columns) 336 except StopIteration: 337 return expression 338 339 if l.is_number and r.is_number: 340 l = l.to_py() 341 r = r.to_py() 342 elif l.is_string and r.is_string: 343 l = l.name 344 r = r.name 345 else: 346 l = extract_date(l) 347 if not l: 348 return None 349 r = extract_date(r) 350 if not r: 351 return None 352 # python won't compare date and datetime, but many engines will upcast 353 l, r = cast_as_datetime(l), cast_as_datetime(r) 354 355 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 356 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 357 return left if (av > bv if or_ else av <= bv) else right 358 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 359 return left if (av < bv if or_ else av >= bv) else right 360 361 # we can't ever shortcut to true because the column could be null 362 if not or_: 363 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 364 if av <= bv: 365 return exp.false() 366 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 367 if av >= bv: 368 return exp.false() 369 elif isinstance(a, exp.EQ): 370 if isinstance(b, exp.LT): 371 return exp.false() if av >= bv else a 372 if isinstance(b, exp.LTE): 373 return exp.false() if av > bv else a 374 if isinstance(b, exp.GT): 375 return exp.false() if av <= bv else a 376 if isinstance(b, exp.GTE): 377 return exp.false() if av < bv else a 378 if isinstance(b, exp.NEQ): 379 return exp.false() if av == bv else a 380 return None 381 382 383def remove_complements(expression, root=True): 384 """ 385 Removing complements. 386 387 A AND NOT A -> FALSE 388 A OR NOT A -> TRUE 389 """ 390 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 391 ops = set(expression.flatten()) 392 for op in ops: 393 if isinstance(op, exp.Not) and op.this in ops: 394 return exp.false() if isinstance(expression, exp.And) else exp.true() 395 396 return expression 397 398 399def uniq_sort(expression, root=True): 400 """ 401 Uniq and sort a connector. 402 403 C AND A AND B AND B -> A AND B AND C 404 """ 405 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 406 flattened = tuple(expression.flatten()) 407 408 if isinstance(expression, exp.Xor): 409 result_func = exp.xor 410 # Do not deduplicate XOR as A XOR A != A if A == True 411 deduped = None 412 arr = tuple((gen(e), e) for e in flattened) 413 else: 414 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 415 deduped = {gen(e): e for e in flattened} 416 arr = tuple(deduped.items()) 417 418 # check if the operands are already sorted, if not sort them 419 # A AND C AND B -> A AND B AND C 420 for i, (sql, e) in enumerate(arr[1:]): 421 if sql < arr[i][0]: 422 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 423 break 424 else: 425 # we didn't have to sort but maybe we need to dedup 426 if deduped and len(deduped) < len(flattened): 427 expression = result_func(*deduped.values(), copy=False) 428 429 return expression 430 431 432def absorb_and_eliminate(expression, root=True): 433 """ 434 absorption: 435 A AND (A OR B) -> A 436 A OR (A AND B) -> A 437 A AND (NOT A OR B) -> A AND B 438 A OR (NOT A AND B) -> A OR B 439 elimination: 440 (A AND B) OR (A AND NOT B) -> A 441 (A OR B) AND (A OR NOT B) -> A 442 """ 443 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 444 kind = exp.Or if isinstance(expression, exp.And) else exp.And 445 446 ops = tuple(expression.flatten()) 447 448 # Initialize lookup tables: 449 # Set of all operands, used to find complements for absorption. 450 op_set = set() 451 # Sub-operands, used to find subsets for absorption. 452 subops = defaultdict(list) 453 # Pairs of complements, used for elimination. 454 pairs = defaultdict(list) 455 456 # Populate the lookup tables 457 for op in ops: 458 op_set.add(op) 459 460 if not isinstance(op, kind): 461 # In cases like: A OR (A AND B) 462 # Subop will be: ^ 463 subops[op].append({op}) 464 continue 465 466 # In cases like: (A AND B) OR (A AND B AND C) 467 # Subops will be: ^ ^ 468 subset = set(op.flatten()) 469 for i in subset: 470 subops[i].append(subset) 471 472 a, b = op.unnest_operands() 473 if isinstance(a, exp.Not): 474 pairs[frozenset((a.this, b))].append((op, b)) 475 if isinstance(b, exp.Not): 476 pairs[frozenset((a, b.this))].append((op, a)) 477 478 for op in ops: 479 if not isinstance(op, kind): 480 continue 481 482 a, b = op.unnest_operands() 483 484 # Absorb 485 if isinstance(a, exp.Not) and a.this in op_set: 486 a.replace(exp.true() if kind == exp.And else exp.false()) 487 continue 488 if isinstance(b, exp.Not) and b.this in op_set: 489 b.replace(exp.true() if kind == exp.And else exp.false()) 490 continue 491 superset = set(op.flatten()) 492 if any(any(subset < superset for subset in subops[i]) for i in superset): 493 op.replace(exp.false() if kind == exp.And else exp.true()) 494 continue 495 496 # Eliminate 497 for other, complement in pairs[frozenset((a, b))]: 498 op.replace(complement) 499 other.replace(complement) 500 501 return expression 502 503 504def propagate_constants(expression, root=True): 505 """ 506 Propagate constants for conjunctions in DNF: 507 508 SELECT * FROM t WHERE a = b AND b = 5 becomes 509 SELECT * FROM t WHERE a = 5 AND b = 5 510 511 Reference: https://www.sqlite.org/optoverview.html 512 """ 513 514 if ( 515 isinstance(expression, exp.And) 516 and (root or not expression.same_parent) 517 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 518 ): 519 constant_mapping = {} 520 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 521 if isinstance(expr, exp.EQ): 522 l, r = expr.left, expr.right 523 524 # TODO: create a helper that can be used to detect nested literal expressions such 525 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 526 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 527 constant_mapping[l] = (id(l), r) 528 529 if constant_mapping: 530 for column in find_all_in_scope(expression, exp.Column): 531 parent = column.parent 532 column_id, constant = constant_mapping.get(column) or (None, None) 533 if ( 534 column_id is not None 535 and id(column) != column_id 536 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 537 ): 538 column.replace(constant.copy()) 539 540 return expression 541 542 543INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 544 exp.DateAdd: exp.Sub, 545 exp.DateSub: exp.Add, 546 exp.DatetimeAdd: exp.Sub, 547 exp.DatetimeSub: exp.Add, 548} 549 550INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 551 **INVERSE_DATE_OPS, 552 exp.Add: exp.Sub, 553 exp.Sub: exp.Add, 554} 555 556 557def _is_number(expression: exp.Expression) -> bool: 558 return expression.is_number 559 560 561def _is_interval(expression: exp.Expression) -> bool: 562 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 563 564 565@catch(ModuleNotFoundError, UnsupportedUnit) 566def simplify_equality(expression: exp.Expression) -> exp.Expression: 567 """ 568 Use the subtraction and addition properties of equality to simplify expressions: 569 570 x + 1 = 3 becomes x = 2 571 572 There are two binary operations in the above expression: + and = 573 Here's how we reference all the operands in the code below: 574 575 l r 576 x + 1 = 3 577 a b 578 """ 579 if isinstance(expression, COMPARISONS): 580 l, r = expression.left, expression.right 581 582 if l.__class__ not in INVERSE_OPS: 583 return expression 584 585 if r.is_number: 586 a_predicate = _is_number 587 b_predicate = _is_number 588 elif _is_date_literal(r): 589 a_predicate = _is_date_literal 590 b_predicate = _is_interval 591 else: 592 return expression 593 594 if l.__class__ in INVERSE_DATE_OPS: 595 l = t.cast(exp.IntervalOp, l) 596 a = l.this 597 b = l.interval() 598 else: 599 l = t.cast(exp.Binary, l) 600 a, b = l.left, l.right 601 602 if not a_predicate(a) and b_predicate(b): 603 pass 604 elif not a_predicate(b) and b_predicate(a): 605 a, b = b, a 606 else: 607 return expression 608 609 return expression.__class__( 610 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 611 ) 612 return expression 613 614 615def simplify_literals(expression, root=True): 616 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 617 return _flat_simplify(expression, _simplify_binary, root) 618 619 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 620 return expression.this.this 621 622 if type(expression) in INVERSE_DATE_OPS: 623 return _simplify_binary(expression, expression.this, expression.interval()) or expression 624 625 return expression 626 627 628NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 629 630 631def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 632 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 633 this = _simplify_integer_cast(expr.this) 634 else: 635 this = expr.this 636 637 if isinstance(expr, exp.Cast) and this.is_int: 638 num = this.to_py() 639 640 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any 641 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 642 # engine-dependent 643 if ( 644 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 645 ) or ( 646 UTINYINT_MIN <= num <= UTINYINT_MAX 647 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 648 ): 649 return this 650 651 return expr 652 653 654def _simplify_binary(expression, a, b): 655 if isinstance(expression, COMPARISONS): 656 a = _simplify_integer_cast(a) 657 b = _simplify_integer_cast(b) 658 659 if isinstance(expression, exp.Is): 660 if isinstance(b, exp.Not): 661 c = b.this 662 not_ = True 663 else: 664 c = b 665 not_ = False 666 667 if is_null(c): 668 if isinstance(a, exp.Literal): 669 return exp.true() if not_ else exp.false() 670 if is_null(a): 671 return exp.false() if not_ else exp.true() 672 elif isinstance(expression, NULL_OK): 673 return None 674 elif is_null(a) or is_null(b): 675 return exp.null() 676 677 if a.is_number and b.is_number: 678 num_a = a.to_py() 679 num_b = b.to_py() 680 681 if isinstance(expression, exp.Add): 682 return exp.Literal.number(num_a + num_b) 683 if isinstance(expression, exp.Mul): 684 return exp.Literal.number(num_a * num_b) 685 686 # We only simplify Sub, Div if a and b have the same parent because they're not associative 687 if isinstance(expression, exp.Sub): 688 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 689 if isinstance(expression, exp.Div): 690 # engines have differing int div behavior so intdiv is not safe 691 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 692 return None 693 return exp.Literal.number(num_a / num_b) 694 695 boolean = eval_boolean(expression, num_a, num_b) 696 697 if boolean: 698 return boolean 699 elif a.is_string and b.is_string: 700 boolean = eval_boolean(expression, a.this, b.this) 701 702 if boolean: 703 return boolean 704 elif _is_date_literal(a) and isinstance(b, exp.Interval): 705 date, b = extract_date(a), extract_interval(b) 706 if date and b: 707 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 708 return date_literal(date + b, extract_type(a)) 709 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 710 return date_literal(date - b, extract_type(a)) 711 elif isinstance(a, exp.Interval) and _is_date_literal(b): 712 a, date = extract_interval(a), extract_date(b) 713 # you cannot subtract a date from an interval 714 if a and b and isinstance(expression, exp.Add): 715 return date_literal(a + date, extract_type(b)) 716 elif _is_date_literal(a) and _is_date_literal(b): 717 if isinstance(expression, exp.Predicate): 718 a, b = extract_date(a), extract_date(b) 719 boolean = eval_boolean(expression, a, b) 720 if boolean: 721 return boolean 722 723 return None 724 725 726def simplify_parens(expression): 727 if not isinstance(expression, exp.Paren): 728 return expression 729 730 this = expression.this 731 parent = expression.parent 732 parent_is_predicate = isinstance(parent, exp.Predicate) 733 734 if ( 735 not isinstance(this, exp.Select) 736 and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)) 737 and ( 738 not isinstance(parent, (exp.Condition, exp.Binary)) 739 or isinstance(parent, exp.Paren) 740 or ( 741 not isinstance(this, exp.Binary) 742 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 743 ) 744 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 745 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 746 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 747 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 748 ) 749 ): 750 return this 751 return expression 752 753 754def _is_nonnull_constant(expression: exp.Expression) -> bool: 755 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 756 757 758def _is_constant(expression: exp.Expression) -> bool: 759 return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 760 761 762def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: 763 # COALESCE(x) -> x 764 if ( 765 isinstance(expression, exp.Coalesce) 766 and (not expression.expressions or _is_nonnull_constant(expression.this)) 767 # COALESCE is also used as a Spark partitioning hint 768 and not isinstance(expression.parent, exp.Hint) 769 ): 770 return expression.this 771 772 # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, 773 # because they are not always equivalent. For example, if `x` is `NULL` and it comes 774 # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` 775 if dialect == "redshift": 776 return expression 777 778 if not isinstance(expression, COMPARISONS): 779 return expression 780 781 if isinstance(expression.left, exp.Coalesce): 782 coalesce = expression.left 783 other = expression.right 784 elif isinstance(expression.right, exp.Coalesce): 785 coalesce = expression.right 786 other = expression.left 787 else: 788 return expression 789 790 # This transformation is valid for non-constants, 791 # but it really only does anything if they are both constants. 792 if not _is_constant(other): 793 return expression 794 795 # Find the first constant arg 796 for arg_index, arg in enumerate(coalesce.expressions): 797 if _is_constant(arg): 798 break 799 else: 800 return expression 801 802 coalesce.set("expressions", coalesce.expressions[:arg_index]) 803 804 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 805 # since we already remove COALESCE at the top of this function. 806 coalesce = coalesce if coalesce.expressions else coalesce.this 807 808 # This expression is more complex than when we started, but it will get simplified further 809 return exp.paren( 810 exp.or_( 811 exp.and_( 812 coalesce.is_(exp.null()).not_(copy=False), 813 expression.copy(), 814 copy=False, 815 ), 816 exp.and_( 817 coalesce.is_(exp.null()), 818 type(expression)(this=arg.copy(), expression=other.copy()), 819 copy=False, 820 ), 821 copy=False, 822 ) 823 ) 824 825 826CONCATS = (exp.Concat, exp.DPipe) 827 828 829def simplify_concat(expression): 830 """Reduces all groups that contain string literals by concatenating them.""" 831 if not isinstance(expression, CONCATS) or ( 832 # We can't reduce a CONCAT_WS call if we don't statically know the separator 833 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 834 ): 835 return expression 836 837 if isinstance(expression, exp.ConcatWs): 838 sep_expr, *expressions = expression.expressions 839 sep = sep_expr.name 840 concat_type = exp.ConcatWs 841 args = {} 842 else: 843 expressions = expression.expressions 844 sep = "" 845 concat_type = exp.Concat 846 args = { 847 "safe": expression.args.get("safe"), 848 "coalesce": expression.args.get("coalesce"), 849 } 850 851 new_args = [] 852 for is_string_group, group in itertools.groupby( 853 expressions or expression.flatten(), lambda e: e.is_string 854 ): 855 if is_string_group: 856 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 857 else: 858 new_args.extend(group) 859 860 if len(new_args) == 1 and new_args[0].is_string: 861 return new_args[0] 862 863 if concat_type is exp.ConcatWs: 864 new_args = [sep_expr] + new_args 865 elif isinstance(expression, exp.DPipe): 866 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 867 868 return concat_type(expressions=new_args, **args) 869 870 871def simplify_conditionals(expression): 872 """Simplifies expressions like IF, CASE if their condition is statically known.""" 873 if isinstance(expression, exp.Case): 874 this = expression.this 875 for case in expression.args["ifs"]: 876 cond = case.this 877 if this: 878 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 879 cond = cond.replace(this.pop().eq(cond)) 880 881 if always_true(cond): 882 return case.args["true"] 883 884 if always_false(cond): 885 case.pop() 886 if not expression.args["ifs"]: 887 return expression.args.get("default") or exp.null() 888 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 889 if always_true(expression.this): 890 return expression.args["true"] 891 if always_false(expression.this): 892 return expression.args.get("false") or exp.null() 893 894 return expression 895 896 897def simplify_startswith(expression: exp.Expression) -> exp.Expression: 898 """ 899 Reduces a prefix check to either TRUE or FALSE if both the string and the 900 prefix are statically known. 901 902 Example: 903 >>> from sqlglot import parse_one 904 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 905 'TRUE' 906 """ 907 if ( 908 isinstance(expression, exp.StartsWith) 909 and expression.this.is_string 910 and expression.expression.is_string 911 ): 912 return exp.convert(expression.name.startswith(expression.expression.name)) 913 914 return expression 915 916 917DateRange = t.Tuple[datetime.date, datetime.date] 918 919 920def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 921 """ 922 Get the date range for a DATE_TRUNC equality comparison: 923 924 Example: 925 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 926 Returns: 927 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 928 """ 929 floor = date_floor(date, unit, dialect) 930 931 if date != floor: 932 # This will always be False, except for NULL values. 933 return None 934 935 return floor, floor + interval(unit) 936 937 938def _datetrunc_eq_expression( 939 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 940) -> exp.Expression: 941 """Get the logical expression for a date range""" 942 return exp.and_( 943 left >= date_literal(drange[0], target_type), 944 left < date_literal(drange[1], target_type), 945 copy=False, 946 ) 947 948 949def _datetrunc_eq( 950 left: exp.Expression, 951 date: datetime.date, 952 unit: str, 953 dialect: Dialect, 954 target_type: t.Optional[exp.DataType], 955) -> t.Optional[exp.Expression]: 956 drange = _datetrunc_range(date, unit, dialect) 957 if not drange: 958 return None 959 960 return _datetrunc_eq_expression(left, drange, target_type) 961 962 963def _datetrunc_neq( 964 left: exp.Expression, 965 date: datetime.date, 966 unit: str, 967 dialect: Dialect, 968 target_type: t.Optional[exp.DataType], 969) -> t.Optional[exp.Expression]: 970 drange = _datetrunc_range(date, unit, dialect) 971 if not drange: 972 return None 973 974 return exp.and_( 975 left < date_literal(drange[0], target_type), 976 left >= date_literal(drange[1], target_type), 977 copy=False, 978 ) 979 980 981DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 982 exp.LT: lambda l, dt, u, d, t: l 983 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 984 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 985 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 986 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 987 exp.EQ: _datetrunc_eq, 988 exp.NEQ: _datetrunc_neq, 989} 990DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 991DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 992 993 994def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 995 return isinstance(left, DATETRUNCS) and _is_date_literal(right) 996 997 998@catch(ModuleNotFoundError, UnsupportedUnit) 999def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1000 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1001 comparison = expression.__class__ 1002 1003 if isinstance(expression, DATETRUNCS): 1004 this = expression.this 1005 trunc_type = extract_type(this) 1006 date = extract_date(this) 1007 if date and expression.unit: 1008 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1009 elif comparison not in DATETRUNC_COMPARISONS: 1010 return expression 1011 1012 if isinstance(expression, exp.Binary): 1013 l, r = expression.left, expression.right 1014 1015 if not _is_datetrunc_predicate(l, r): 1016 return expression 1017 1018 l = t.cast(exp.DateTrunc, l) 1019 trunc_arg = l.this 1020 unit = l.unit.name.lower() 1021 date = extract_date(r) 1022 1023 if not date: 1024 return expression 1025 1026 return ( 1027 DATETRUNC_BINARY_COMPARISONS[comparison]( 1028 trunc_arg, date, unit, dialect, extract_type(r) 1029 ) 1030 or expression 1031 ) 1032 1033 if isinstance(expression, exp.In): 1034 l = expression.this 1035 rs = expression.expressions 1036 1037 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1038 l = t.cast(exp.DateTrunc, l) 1039 unit = l.unit.name.lower() 1040 1041 ranges = [] 1042 for r in rs: 1043 date = extract_date(r) 1044 if not date: 1045 return expression 1046 drange = _datetrunc_range(date, unit, dialect) 1047 if drange: 1048 ranges.append(drange) 1049 1050 if not ranges: 1051 return expression 1052 1053 ranges = merge_ranges(ranges) 1054 target_type = extract_type(*rs) 1055 1056 return exp.or_( 1057 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1058 ) 1059 1060 return expression 1061 1062 1063def sort_comparison(expression: exp.Expression) -> exp.Expression: 1064 if expression.__class__ in COMPLEMENT_COMPARISONS: 1065 l, r = expression.this, expression.expression 1066 l_column = isinstance(l, exp.Column) 1067 r_column = isinstance(r, exp.Column) 1068 l_const = _is_constant(l) 1069 r_const = _is_constant(r) 1070 1071 if ( 1072 (l_column and not r_column) 1073 or (r_const and not l_const) 1074 or isinstance(r, exp.SubqueryPredicate) 1075 ): 1076 return expression 1077 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1078 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1079 this=r, expression=l 1080 ) 1081 return expression 1082 1083 1084# CROSS joins result in an empty table if the right table is empty. 1085# So we can only simplify certain types of joins to CROSS. 1086# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1087JOINS = { 1088 ("", ""), 1089 ("", "INNER"), 1090 ("RIGHT", ""), 1091 ("RIGHT", "OUTER"), 1092} 1093 1094 1095def remove_where_true(expression): 1096 for where in expression.find_all(exp.Where): 1097 if always_true(where.this): 1098 where.pop() 1099 for join in expression.find_all(exp.Join): 1100 if ( 1101 always_true(join.args.get("on")) 1102 and not join.args.get("using") 1103 and not join.args.get("method") 1104 and (join.side, join.kind) in JOINS 1105 ): 1106 join.args["on"].pop() 1107 join.set("side", None) 1108 join.set("kind", "CROSS") 1109 1110 1111def always_true(expression): 1112 return (isinstance(expression, exp.Boolean) and expression.this) or ( 1113 isinstance(expression, exp.Literal) and not is_zero(expression) 1114 ) 1115 1116 1117def always_false(expression): 1118 return is_false(expression) or is_null(expression) or is_zero(expression) 1119 1120 1121def is_zero(expression): 1122 return isinstance(expression, exp.Literal) and expression.to_py() == 0 1123 1124 1125def is_complement(a, b): 1126 return isinstance(b, exp.Not) and b.this == a 1127 1128 1129def is_false(a: exp.Expression) -> bool: 1130 return type(a) is exp.Boolean and not a.this 1131 1132 1133def is_null(a: exp.Expression) -> bool: 1134 return type(a) is exp.Null 1135 1136 1137def eval_boolean(expression, a, b): 1138 if isinstance(expression, (exp.EQ, exp.Is)): 1139 return boolean_literal(a == b) 1140 if isinstance(expression, exp.NEQ): 1141 return boolean_literal(a != b) 1142 if isinstance(expression, exp.GT): 1143 return boolean_literal(a > b) 1144 if isinstance(expression, exp.GTE): 1145 return boolean_literal(a >= b) 1146 if isinstance(expression, exp.LT): 1147 return boolean_literal(a < b) 1148 if isinstance(expression, exp.LTE): 1149 return boolean_literal(a <= b) 1150 return None 1151 1152 1153def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1154 if isinstance(value, datetime.datetime): 1155 return value.date() 1156 if isinstance(value, datetime.date): 1157 return value 1158 try: 1159 return datetime.datetime.fromisoformat(value).date() 1160 except ValueError: 1161 return None 1162 1163 1164def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1165 if isinstance(value, datetime.datetime): 1166 return value 1167 if isinstance(value, datetime.date): 1168 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1169 try: 1170 return datetime.datetime.fromisoformat(value) 1171 except ValueError: 1172 return None 1173 1174 1175def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1176 if not value: 1177 return None 1178 if to.is_type(exp.DataType.Type.DATE): 1179 return cast_as_date(value) 1180 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1181 return cast_as_datetime(value) 1182 return None 1183 1184 1185def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1186 if isinstance(cast, exp.Cast): 1187 to = cast.to 1188 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1189 to = exp.DataType.build(exp.DataType.Type.DATE) 1190 else: 1191 return None 1192 1193 if isinstance(cast.this, exp.Literal): 1194 value: t.Any = cast.this.name 1195 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1196 value = extract_date(cast.this) 1197 else: 1198 return None 1199 return cast_value(value, to) 1200 1201 1202def _is_date_literal(expression: exp.Expression) -> bool: 1203 return extract_date(expression) is not None 1204 1205 1206def extract_interval(expression): 1207 try: 1208 n = int(expression.this.to_py()) 1209 unit = expression.text("unit").lower() 1210 return interval(unit, n) 1211 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1212 return None 1213 1214 1215def extract_type(*expressions): 1216 target_type = None 1217 for expression in expressions: 1218 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1219 if target_type: 1220 break 1221 1222 return target_type 1223 1224 1225def date_literal(date, target_type=None): 1226 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1227 target_type = ( 1228 exp.DataType.Type.DATETIME 1229 if isinstance(date, datetime.datetime) 1230 else exp.DataType.Type.DATE 1231 ) 1232 1233 return exp.cast(exp.Literal.string(date), target_type) 1234 1235 1236def interval(unit: str, n: int = 1): 1237 from dateutil.relativedelta import relativedelta 1238 1239 if unit == "year": 1240 return relativedelta(years=1 * n) 1241 if unit == "quarter": 1242 return relativedelta(months=3 * n) 1243 if unit == "month": 1244 return relativedelta(months=1 * n) 1245 if unit == "week": 1246 return relativedelta(weeks=1 * n) 1247 if unit == "day": 1248 return relativedelta(days=1 * n) 1249 if unit == "hour": 1250 return relativedelta(hours=1 * n) 1251 if unit == "minute": 1252 return relativedelta(minutes=1 * n) 1253 if unit == "second": 1254 return relativedelta(seconds=1 * n) 1255 1256 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1257 1258 1259def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1260 if unit == "year": 1261 return d.replace(month=1, day=1) 1262 if unit == "quarter": 1263 if d.month <= 3: 1264 return d.replace(month=1, day=1) 1265 elif d.month <= 6: 1266 return d.replace(month=4, day=1) 1267 elif d.month <= 9: 1268 return d.replace(month=7, day=1) 1269 else: 1270 return d.replace(month=10, day=1) 1271 if unit == "month": 1272 return d.replace(month=d.month, day=1) 1273 if unit == "week": 1274 # Assuming week starts on Monday (0) and ends on Sunday (6) 1275 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1276 if unit == "day": 1277 return d 1278 1279 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1280 1281 1282def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1283 floor = date_floor(d, unit, dialect) 1284 1285 if floor == d: 1286 return d 1287 1288 return floor + interval(unit) 1289 1290 1291def boolean_literal(condition): 1292 return exp.true() if condition else exp.false() 1293 1294 1295def _flat_simplify(expression, simplifier, root=True): 1296 if root or not expression.same_parent: 1297 operands = [] 1298 queue = deque(expression.flatten(unnest=False)) 1299 size = len(queue) 1300 1301 while queue: 1302 a = queue.popleft() 1303 1304 for b in queue: 1305 result = simplifier(expression, a, b) 1306 1307 if result and result is not expression: 1308 queue.remove(b) 1309 queue.appendleft(result) 1310 break 1311 else: 1312 operands.append(a) 1313 1314 if len(operands) < size: 1315 return functools.reduce( 1316 lambda a, b: expression.__class__(this=a, expression=b), operands 1317 ) 1318 return expression 1319 1320 1321def gen(expression: t.Any, comments: bool = False) -> str: 1322 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1323 1324 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1325 generator is expensive so we have a bare minimum sql generator here. 1326 1327 Args: 1328 expression: the expression to convert into a SQL string. 1329 comments: whether to include the expression's comments. 1330 """ 1331 return Gen().gen(expression, comments=comments) 1332 1333 1334class Gen: 1335 def __init__(self): 1336 self.stack = [] 1337 self.sqls = [] 1338 1339 def gen(self, expression: exp.Expression, comments: bool = False) -> str: 1340 self.stack = [expression] 1341 self.sqls.clear() 1342 1343 while self.stack: 1344 node = self.stack.pop() 1345 1346 if isinstance(node, exp.Expression): 1347 if comments and node.comments: 1348 self.stack.append(f" /*{','.join(node.comments)}*/") 1349 1350 exp_handler_name = f"{node.key}_sql" 1351 1352 if hasattr(self, exp_handler_name): 1353 getattr(self, exp_handler_name)(node) 1354 elif isinstance(node, exp.Func): 1355 self._function(node) 1356 else: 1357 key = node.key.upper() 1358 self.stack.append(f"{key} " if self._args(node) else key) 1359 elif type(node) is list: 1360 for n in reversed(node): 1361 if n is not None: 1362 self.stack.extend((n, ",")) 1363 if node: 1364 self.stack.pop() 1365 else: 1366 if node is not None: 1367 self.sqls.append(str(node)) 1368 1369 return "".join(self.sqls) 1370 1371 def add_sql(self, e: exp.Add) -> None: 1372 self._binary(e, " + ") 1373 1374 def alias_sql(self, e: exp.Alias) -> None: 1375 self.stack.extend( 1376 ( 1377 e.args.get("alias"), 1378 " AS ", 1379 e.args.get("this"), 1380 ) 1381 ) 1382 1383 def and_sql(self, e: exp.And) -> None: 1384 self._binary(e, " AND ") 1385 1386 def anonymous_sql(self, e: exp.Anonymous) -> None: 1387 this = e.this 1388 if isinstance(this, str): 1389 name = this.upper() 1390 elif isinstance(this, exp.Identifier): 1391 name = this.this 1392 name = f'"{name}"' if this.quoted else name.upper() 1393 else: 1394 raise ValueError( 1395 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1396 ) 1397 1398 self.stack.extend( 1399 ( 1400 ")", 1401 e.expressions, 1402 "(", 1403 name, 1404 ) 1405 ) 1406 1407 def between_sql(self, e: exp.Between) -> None: 1408 self.stack.extend( 1409 ( 1410 e.args.get("high"), 1411 " AND ", 1412 e.args.get("low"), 1413 " BETWEEN ", 1414 e.this, 1415 ) 1416 ) 1417 1418 def boolean_sql(self, e: exp.Boolean) -> None: 1419 self.stack.append("TRUE" if e.this else "FALSE") 1420 1421 def bracket_sql(self, e: exp.Bracket) -> None: 1422 self.stack.extend( 1423 ( 1424 "]", 1425 e.expressions, 1426 "[", 1427 e.this, 1428 ) 1429 ) 1430 1431 def column_sql(self, e: exp.Column) -> None: 1432 for p in reversed(e.parts): 1433 self.stack.extend((p, ".")) 1434 self.stack.pop() 1435 1436 def datatype_sql(self, e: exp.DataType) -> None: 1437 self._args(e, 1) 1438 self.stack.append(f"{e.this.name} ") 1439 1440 def div_sql(self, e: exp.Div) -> None: 1441 self._binary(e, " / ") 1442 1443 def dot_sql(self, e: exp.Dot) -> None: 1444 self._binary(e, ".") 1445 1446 def eq_sql(self, e: exp.EQ) -> None: 1447 self._binary(e, " = ") 1448 1449 def from_sql(self, e: exp.From) -> None: 1450 self.stack.extend((e.this, "FROM ")) 1451 1452 def gt_sql(self, e: exp.GT) -> None: 1453 self._binary(e, " > ") 1454 1455 def gte_sql(self, e: exp.GTE) -> None: 1456 self._binary(e, " >= ") 1457 1458 def identifier_sql(self, e: exp.Identifier) -> None: 1459 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1460 1461 def ilike_sql(self, e: exp.ILike) -> None: 1462 self._binary(e, " ILIKE ") 1463 1464 def in_sql(self, e: exp.In) -> None: 1465 self.stack.append(")") 1466 self._args(e, 1) 1467 self.stack.extend( 1468 ( 1469 "(", 1470 " IN ", 1471 e.this, 1472 ) 1473 ) 1474 1475 def intdiv_sql(self, e: exp.IntDiv) -> None: 1476 self._binary(e, " DIV ") 1477 1478 def is_sql(self, e: exp.Is) -> None: 1479 self._binary(e, " IS ") 1480 1481 def like_sql(self, e: exp.Like) -> None: 1482 self._binary(e, " Like ") 1483 1484 def literal_sql(self, e: exp.Literal) -> None: 1485 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1486 1487 def lt_sql(self, e: exp.LT) -> None: 1488 self._binary(e, " < ") 1489 1490 def lte_sql(self, e: exp.LTE) -> None: 1491 self._binary(e, " <= ") 1492 1493 def mod_sql(self, e: exp.Mod) -> None: 1494 self._binary(e, " % ") 1495 1496 def mul_sql(self, e: exp.Mul) -> None: 1497 self._binary(e, " * ") 1498 1499 def neg_sql(self, e: exp.Neg) -> None: 1500 self._unary(e, "-") 1501 1502 def neq_sql(self, e: exp.NEQ) -> None: 1503 self._binary(e, " <> ") 1504 1505 def not_sql(self, e: exp.Not) -> None: 1506 self._unary(e, "NOT ") 1507 1508 def null_sql(self, e: exp.Null) -> None: 1509 self.stack.append("NULL") 1510 1511 def or_sql(self, e: exp.Or) -> None: 1512 self._binary(e, " OR ") 1513 1514 def paren_sql(self, e: exp.Paren) -> None: 1515 self.stack.extend( 1516 ( 1517 ")", 1518 e.this, 1519 "(", 1520 ) 1521 ) 1522 1523 def sub_sql(self, e: exp.Sub) -> None: 1524 self._binary(e, " - ") 1525 1526 def subquery_sql(self, e: exp.Subquery) -> None: 1527 self._args(e, 2) 1528 alias = e.args.get("alias") 1529 if alias: 1530 self.stack.append(alias) 1531 self.stack.extend((")", e.this, "(")) 1532 1533 def table_sql(self, e: exp.Table) -> None: 1534 self._args(e, 4) 1535 alias = e.args.get("alias") 1536 if alias: 1537 self.stack.append(alias) 1538 for p in reversed(e.parts): 1539 self.stack.extend((p, ".")) 1540 self.stack.pop() 1541 1542 def tablealias_sql(self, e: exp.TableAlias) -> None: 1543 columns = e.columns 1544 1545 if columns: 1546 self.stack.extend((")", columns, "(")) 1547 1548 self.stack.extend((e.this, " AS ")) 1549 1550 def var_sql(self, e: exp.Var) -> None: 1551 self.stack.append(e.this) 1552 1553 def _binary(self, e: exp.Binary, op: str) -> None: 1554 self.stack.extend((e.expression, op, e.this)) 1555 1556 def _unary(self, e: exp.Unary, op: str) -> None: 1557 self.stack.extend((e.this, op)) 1558 1559 def _function(self, e: exp.Func) -> None: 1560 self.stack.extend( 1561 ( 1562 ")", 1563 list(e.args.values()), 1564 "(", 1565 e.sql_name(), 1566 ) 1567 ) 1568 1569 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1570 kvs = [] 1571 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1572 1573 for k in arg_types or arg_types: 1574 v = node.args.get(k) 1575 1576 if v is not None: 1577 kvs.append([f":{k}", v]) 1578 if kvs: 1579 self.stack.append(kvs) 1580 return True 1581 return False
Common base class for all non-exit exceptions.
40def simplify( 41 expression: exp.Expression, 42 constant_propagation: bool = False, 43 dialect: DialectType = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 Returns: 58 sqlglot.Expression: simplified expression 59 """ 60 61 dialect = Dialect.get_or_raise(dialect) 62 63 def _simplify(expression): 64 pre_transformation_stack = [expression] 65 post_transformation_stack = [] 66 67 while pre_transformation_stack: 68 node = pre_transformation_stack.pop() 69 70 if node.meta.get(FINAL): 71 continue 72 73 # group by expressions cannot be simplified, for example 74 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 75 # the projection must exactly match the group by key 76 group = node.args.get("group") 77 78 if group and hasattr(node, "selects"): 79 groups = set(group.expressions) 80 group.meta[FINAL] = True 81 82 for s in node.selects: 83 for n in s.walk(): 84 if n in groups: 85 s.meta[FINAL] = True 86 break 87 88 having = node.args.get("having") 89 if having: 90 for n in having.walk(): 91 if n in groups: 92 having.meta[FINAL] = True 93 break 94 95 parent = node.parent 96 root = node is expression 97 98 new_node = rewrite_between(node) 99 new_node = uniq_sort(new_node, root) 100 new_node = absorb_and_eliminate(new_node, root) 101 new_node = simplify_concat(new_node) 102 new_node = simplify_conditionals(new_node) 103 104 if constant_propagation: 105 new_node = propagate_constants(new_node, root) 106 107 if new_node is not node: 108 node.replace(new_node) 109 110 pre_transformation_stack.extend( 111 n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL) 112 ) 113 post_transformation_stack.append((new_node, parent)) 114 115 while post_transformation_stack: 116 node, parent = post_transformation_stack.pop() 117 root = node is expression 118 119 # Resets parent, arg_key, index pointers– this is needed because some of the 120 # previous transformations mutate the AST, leading to an inconsistent state 121 for k, v in tuple(node.args.items()): 122 node.set(k, v) 123 124 # Post-order transformations 125 new_node = simplify_not(node) 126 new_node = flatten(new_node) 127 new_node = simplify_connectors(new_node, root) 128 new_node = remove_complements(new_node, root) 129 new_node = simplify_coalesce(new_node, dialect) 130 131 new_node.parent = parent 132 133 new_node = simplify_literals(new_node, root) 134 new_node = simplify_equality(new_node) 135 new_node = simplify_parens(new_node) 136 new_node = simplify_datetrunc(new_node, dialect) 137 new_node = sort_comparison(new_node) 138 new_node = simplify_startswith(new_node) 139 140 if new_node is not node: 141 node.replace(new_node) 142 143 return new_node 144 145 expression = while_changing(expression, _simplify) 146 remove_where_true(expression) 147 return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
150def catch(*exceptions): 151 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 152 153 def decorator(func): 154 def wrapped(expression, *args, **kwargs): 155 try: 156 return func(expression, *args, **kwargs) 157 except exceptions: 158 return expression 159 160 return wrapped 161 162 return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
165def rewrite_between(expression: exp.Expression) -> exp.Expression: 166 """Rewrite x between y and z to x >= y AND x <= z. 167 168 This is done because comparison simplification is only done on lt/lte/gt/gte. 169 """ 170 if isinstance(expression, exp.Between): 171 negate = isinstance(expression.parent, exp.Not) 172 173 expression = exp.and_( 174 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 175 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 176 copy=False, 177 ) 178 179 if negate: 180 expression = exp.paren(expression, copy=False) 181 182 return expression
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
200def simplify_not(expression): 201 """ 202 Demorgan's Law 203 NOT (x OR y) -> NOT x AND NOT y 204 NOT (x AND y) -> NOT x OR NOT y 205 """ 206 if isinstance(expression, exp.Not): 207 this = expression.this 208 if is_null(this): 209 return exp.null() 210 if this.__class__ in COMPLEMENT_COMPARISONS: 211 right = this.expression 212 complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) 213 if complement_subquery_predicate: 214 right = complement_subquery_predicate(this=right.this) 215 216 return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) 217 if isinstance(this, exp.Paren): 218 condition = this.unnest() 219 if isinstance(condition, exp.And): 220 return exp.paren( 221 exp.or_( 222 exp.not_(condition.left, copy=False), 223 exp.not_(condition.right, copy=False), 224 copy=False, 225 ) 226 ) 227 if isinstance(condition, exp.Or): 228 return exp.paren( 229 exp.and_( 230 exp.not_(condition.left, copy=False), 231 exp.not_(condition.right, copy=False), 232 copy=False, 233 ) 234 ) 235 if is_null(condition): 236 return exp.null() 237 if always_true(this): 238 return exp.false() 239 if is_false(this): 240 return exp.true() 241 if isinstance(this, exp.Not): 242 # double negation 243 # NOT NOT x -> x 244 return this.this 245 return expression
Demorgan's Law NOT (x OR y) -> NOT x AND NOT y NOT (x AND y) -> NOT x OR NOT y
248def flatten(expression): 249 """ 250 A AND (B AND C) -> A AND B AND C 251 A OR (B OR C) -> A OR B OR C 252 """ 253 if isinstance(expression, exp.Connector): 254 for node in expression.args.values(): 255 child = node.unnest() 256 if isinstance(child, expression.__class__): 257 node.replace(child) 258 return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
261def simplify_connectors(expression, root=True): 262 def _simplify_connectors(expression, left, right): 263 if isinstance(expression, exp.And): 264 if is_false(left) or is_false(right): 265 return exp.false() 266 if is_zero(left) or is_zero(right): 267 return exp.false() 268 if is_null(left) or is_null(right): 269 return exp.null() 270 if always_true(left) and always_true(right): 271 return exp.true() 272 if always_true(left): 273 return right 274 if always_true(right): 275 return left 276 return _simplify_comparison(expression, left, right) 277 elif isinstance(expression, exp.Or): 278 if always_true(left) or always_true(right): 279 return exp.true() 280 if ( 281 (is_null(left) and is_null(right)) 282 or (is_null(left) and always_false(right)) 283 or (always_false(left) and is_null(right)) 284 ): 285 return exp.null() 286 if is_false(left): 287 return right 288 if is_false(right): 289 return left 290 return _simplify_comparison(expression, left, right, or_=True) 291 elif isinstance(expression, exp.Xor): 292 if left == right: 293 return exp.false() 294 295 if isinstance(expression, exp.Connector): 296 return _flat_simplify(expression, _simplify_connectors, root) 297 return expression
384def remove_complements(expression, root=True): 385 """ 386 Removing complements. 387 388 A AND NOT A -> FALSE 389 A OR NOT A -> TRUE 390 """ 391 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 392 ops = set(expression.flatten()) 393 for op in ops: 394 if isinstance(op, exp.Not) and op.this in ops: 395 return exp.false() if isinstance(expression, exp.And) else exp.true() 396 397 return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
400def uniq_sort(expression, root=True): 401 """ 402 Uniq and sort a connector. 403 404 C AND A AND B AND B -> A AND B AND C 405 """ 406 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 407 flattened = tuple(expression.flatten()) 408 409 if isinstance(expression, exp.Xor): 410 result_func = exp.xor 411 # Do not deduplicate XOR as A XOR A != A if A == True 412 deduped = None 413 arr = tuple((gen(e), e) for e in flattened) 414 else: 415 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 416 deduped = {gen(e): e for e in flattened} 417 arr = tuple(deduped.items()) 418 419 # check if the operands are already sorted, if not sort them 420 # A AND C AND B -> A AND B AND C 421 for i, (sql, e) in enumerate(arr[1:]): 422 if sql < arr[i][0]: 423 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 424 break 425 else: 426 # we didn't have to sort but maybe we need to dedup 427 if deduped and len(deduped) < len(flattened): 428 expression = result_func(*deduped.values(), copy=False) 429 430 return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
433def absorb_and_eliminate(expression, root=True): 434 """ 435 absorption: 436 A AND (A OR B) -> A 437 A OR (A AND B) -> A 438 A AND (NOT A OR B) -> A AND B 439 A OR (NOT A AND B) -> A OR B 440 elimination: 441 (A AND B) OR (A AND NOT B) -> A 442 (A OR B) AND (A OR NOT B) -> A 443 """ 444 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 445 kind = exp.Or if isinstance(expression, exp.And) else exp.And 446 447 ops = tuple(expression.flatten()) 448 449 # Initialize lookup tables: 450 # Set of all operands, used to find complements for absorption. 451 op_set = set() 452 # Sub-operands, used to find subsets for absorption. 453 subops = defaultdict(list) 454 # Pairs of complements, used for elimination. 455 pairs = defaultdict(list) 456 457 # Populate the lookup tables 458 for op in ops: 459 op_set.add(op) 460 461 if not isinstance(op, kind): 462 # In cases like: A OR (A AND B) 463 # Subop will be: ^ 464 subops[op].append({op}) 465 continue 466 467 # In cases like: (A AND B) OR (A AND B AND C) 468 # Subops will be: ^ ^ 469 subset = set(op.flatten()) 470 for i in subset: 471 subops[i].append(subset) 472 473 a, b = op.unnest_operands() 474 if isinstance(a, exp.Not): 475 pairs[frozenset((a.this, b))].append((op, b)) 476 if isinstance(b, exp.Not): 477 pairs[frozenset((a, b.this))].append((op, a)) 478 479 for op in ops: 480 if not isinstance(op, kind): 481 continue 482 483 a, b = op.unnest_operands() 484 485 # Absorb 486 if isinstance(a, exp.Not) and a.this in op_set: 487 a.replace(exp.true() if kind == exp.And else exp.false()) 488 continue 489 if isinstance(b, exp.Not) and b.this in op_set: 490 b.replace(exp.true() if kind == exp.And else exp.false()) 491 continue 492 superset = set(op.flatten()) 493 if any(any(subset < superset for subset in subops[i]) for i in superset): 494 op.replace(exp.false() if kind == exp.And else exp.true()) 495 continue 496 497 # Eliminate 498 for other, complement in pairs[frozenset((a, b))]: 499 op.replace(complement) 500 other.replace(complement) 501 502 return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
505def propagate_constants(expression, root=True): 506 """ 507 Propagate constants for conjunctions in DNF: 508 509 SELECT * FROM t WHERE a = b AND b = 5 becomes 510 SELECT * FROM t WHERE a = 5 AND b = 5 511 512 Reference: https://www.sqlite.org/optoverview.html 513 """ 514 515 if ( 516 isinstance(expression, exp.And) 517 and (root or not expression.same_parent) 518 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 519 ): 520 constant_mapping = {} 521 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 522 if isinstance(expr, exp.EQ): 523 l, r = expr.left, expr.right 524 525 # TODO: create a helper that can be used to detect nested literal expressions such 526 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 527 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 528 constant_mapping[l] = (id(l), r) 529 530 if constant_mapping: 531 for column in find_all_in_scope(expression, exp.Column): 532 parent = column.parent 533 column_id, constant = constant_mapping.get(column) or (None, None) 534 if ( 535 column_id is not None 536 and id(column) != column_id 537 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 538 ): 539 column.replace(constant.copy()) 540 541 return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
154 def wrapped(expression, *args, **kwargs): 155 try: 156 return func(expression, *args, **kwargs) 157 except exceptions: 158 return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
l r
x + 1 = 3
a b
616def simplify_literals(expression, root=True): 617 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 618 return _flat_simplify(expression, _simplify_binary, root) 619 620 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 621 return expression.this.this 622 623 if type(expression) in INVERSE_DATE_OPS: 624 return _simplify_binary(expression, expression.this, expression.interval()) or expression 625 626 return expression
727def simplify_parens(expression): 728 if not isinstance(expression, exp.Paren): 729 return expression 730 731 this = expression.this 732 parent = expression.parent 733 parent_is_predicate = isinstance(parent, exp.Predicate) 734 735 if ( 736 not isinstance(this, exp.Select) 737 and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)) 738 and ( 739 not isinstance(parent, (exp.Condition, exp.Binary)) 740 or isinstance(parent, exp.Paren) 741 or ( 742 not isinstance(this, exp.Binary) 743 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 744 ) 745 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 746 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 747 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 748 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 749 ) 750 ): 751 return this 752 return expression
763def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: 764 # COALESCE(x) -> x 765 if ( 766 isinstance(expression, exp.Coalesce) 767 and (not expression.expressions or _is_nonnull_constant(expression.this)) 768 # COALESCE is also used as a Spark partitioning hint 769 and not isinstance(expression.parent, exp.Hint) 770 ): 771 return expression.this 772 773 # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, 774 # because they are not always equivalent. For example, if `x` is `NULL` and it comes 775 # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` 776 if dialect == "redshift": 777 return expression 778 779 if not isinstance(expression, COMPARISONS): 780 return expression 781 782 if isinstance(expression.left, exp.Coalesce): 783 coalesce = expression.left 784 other = expression.right 785 elif isinstance(expression.right, exp.Coalesce): 786 coalesce = expression.right 787 other = expression.left 788 else: 789 return expression 790 791 # This transformation is valid for non-constants, 792 # but it really only does anything if they are both constants. 793 if not _is_constant(other): 794 return expression 795 796 # Find the first constant arg 797 for arg_index, arg in enumerate(coalesce.expressions): 798 if _is_constant(arg): 799 break 800 else: 801 return expression 802 803 coalesce.set("expressions", coalesce.expressions[:arg_index]) 804 805 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 806 # since we already remove COALESCE at the top of this function. 807 coalesce = coalesce if coalesce.expressions else coalesce.this 808 809 # This expression is more complex than when we started, but it will get simplified further 810 return exp.paren( 811 exp.or_( 812 exp.and_( 813 coalesce.is_(exp.null()).not_(copy=False), 814 expression.copy(), 815 copy=False, 816 ), 817 exp.and_( 818 coalesce.is_(exp.null()), 819 type(expression)(this=arg.copy(), expression=other.copy()), 820 copy=False, 821 ), 822 copy=False, 823 ) 824 )
830def simplify_concat(expression): 831 """Reduces all groups that contain string literals by concatenating them.""" 832 if not isinstance(expression, CONCATS) or ( 833 # We can't reduce a CONCAT_WS call if we don't statically know the separator 834 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 835 ): 836 return expression 837 838 if isinstance(expression, exp.ConcatWs): 839 sep_expr, *expressions = expression.expressions 840 sep = sep_expr.name 841 concat_type = exp.ConcatWs 842 args = {} 843 else: 844 expressions = expression.expressions 845 sep = "" 846 concat_type = exp.Concat 847 args = { 848 "safe": expression.args.get("safe"), 849 "coalesce": expression.args.get("coalesce"), 850 } 851 852 new_args = [] 853 for is_string_group, group in itertools.groupby( 854 expressions or expression.flatten(), lambda e: e.is_string 855 ): 856 if is_string_group: 857 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 858 else: 859 new_args.extend(group) 860 861 if len(new_args) == 1 and new_args[0].is_string: 862 return new_args[0] 863 864 if concat_type is exp.ConcatWs: 865 new_args = [sep_expr] + new_args 866 elif isinstance(expression, exp.DPipe): 867 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 868 869 return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
872def simplify_conditionals(expression): 873 """Simplifies expressions like IF, CASE if their condition is statically known.""" 874 if isinstance(expression, exp.Case): 875 this = expression.this 876 for case in expression.args["ifs"]: 877 cond = case.this 878 if this: 879 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 880 cond = cond.replace(this.pop().eq(cond)) 881 882 if always_true(cond): 883 return case.args["true"] 884 885 if always_false(cond): 886 case.pop() 887 if not expression.args["ifs"]: 888 return expression.args.get("default") or exp.null() 889 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 890 if always_true(expression.this): 891 return expression.args["true"] 892 if always_false(expression.this): 893 return expression.args.get("false") or exp.null() 894 895 return expression
Simplifies expressions like IF, CASE if their condition is statically known.
898def simplify_startswith(expression: exp.Expression) -> exp.Expression: 899 """ 900 Reduces a prefix check to either TRUE or FALSE if both the string and the 901 prefix are statically known. 902 903 Example: 904 >>> from sqlglot import parse_one 905 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 906 'TRUE' 907 """ 908 if ( 909 isinstance(expression, exp.StartsWith) 910 and expression.this.is_string 911 and expression.expression.is_string 912 ): 913 return exp.convert(expression.name.startswith(expression.expression.name)) 914 915 return expression
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
154 def wrapped(expression, *args, **kwargs): 155 try: 156 return func(expression, *args, **kwargs) 157 except exceptions: 158 return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
1064def sort_comparison(expression: exp.Expression) -> exp.Expression: 1065 if expression.__class__ in COMPLEMENT_COMPARISONS: 1066 l, r = expression.this, expression.expression 1067 l_column = isinstance(l, exp.Column) 1068 r_column = isinstance(r, exp.Column) 1069 l_const = _is_constant(l) 1070 r_const = _is_constant(r) 1071 1072 if ( 1073 (l_column and not r_column) 1074 or (r_const and not l_const) 1075 or isinstance(r, exp.SubqueryPredicate) 1076 ): 1077 return expression 1078 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1079 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1080 this=r, expression=l 1081 ) 1082 return expression
1096def remove_where_true(expression): 1097 for where in expression.find_all(exp.Where): 1098 if always_true(where.this): 1099 where.pop() 1100 for join in expression.find_all(exp.Join): 1101 if ( 1102 always_true(join.args.get("on")) 1103 and not join.args.get("using") 1104 and not join.args.get("method") 1105 and (join.side, join.kind) in JOINS 1106 ): 1107 join.args["on"].pop() 1108 join.set("side", None) 1109 join.set("kind", "CROSS")
1138def eval_boolean(expression, a, b): 1139 if isinstance(expression, (exp.EQ, exp.Is)): 1140 return boolean_literal(a == b) 1141 if isinstance(expression, exp.NEQ): 1142 return boolean_literal(a != b) 1143 if isinstance(expression, exp.GT): 1144 return boolean_literal(a > b) 1145 if isinstance(expression, exp.GTE): 1146 return boolean_literal(a >= b) 1147 if isinstance(expression, exp.LT): 1148 return boolean_literal(a < b) 1149 if isinstance(expression, exp.LTE): 1150 return boolean_literal(a <= b) 1151 return None
1154def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1155 if isinstance(value, datetime.datetime): 1156 return value.date() 1157 if isinstance(value, datetime.date): 1158 return value 1159 try: 1160 return datetime.datetime.fromisoformat(value).date() 1161 except ValueError: 1162 return None
1165def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1166 if isinstance(value, datetime.datetime): 1167 return value 1168 if isinstance(value, datetime.date): 1169 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1170 try: 1171 return datetime.datetime.fromisoformat(value) 1172 except ValueError: 1173 return None
1176def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1177 if not value: 1178 return None 1179 if to.is_type(exp.DataType.Type.DATE): 1180 return cast_as_date(value) 1181 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1182 return cast_as_datetime(value) 1183 return None
1186def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1187 if isinstance(cast, exp.Cast): 1188 to = cast.to 1189 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1190 to = exp.DataType.build(exp.DataType.Type.DATE) 1191 else: 1192 return None 1193 1194 if isinstance(cast.this, exp.Literal): 1195 value: t.Any = cast.this.name 1196 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1197 value = extract_date(cast.this) 1198 else: 1199 return None 1200 return cast_value(value, to)
1226def date_literal(date, target_type=None): 1227 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1228 target_type = ( 1229 exp.DataType.Type.DATETIME 1230 if isinstance(date, datetime.datetime) 1231 else exp.DataType.Type.DATE 1232 ) 1233 1234 return exp.cast(exp.Literal.string(date), target_type)
1237def interval(unit: str, n: int = 1): 1238 from dateutil.relativedelta import relativedelta 1239 1240 if unit == "year": 1241 return relativedelta(years=1 * n) 1242 if unit == "quarter": 1243 return relativedelta(months=3 * n) 1244 if unit == "month": 1245 return relativedelta(months=1 * n) 1246 if unit == "week": 1247 return relativedelta(weeks=1 * n) 1248 if unit == "day": 1249 return relativedelta(days=1 * n) 1250 if unit == "hour": 1251 return relativedelta(hours=1 * n) 1252 if unit == "minute": 1253 return relativedelta(minutes=1 * n) 1254 if unit == "second": 1255 return relativedelta(seconds=1 * n) 1256 1257 raise UnsupportedUnit(f"Unsupported unit: {unit}")
1260def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1261 if unit == "year": 1262 return d.replace(month=1, day=1) 1263 if unit == "quarter": 1264 if d.month <= 3: 1265 return d.replace(month=1, day=1) 1266 elif d.month <= 6: 1267 return d.replace(month=4, day=1) 1268 elif d.month <= 9: 1269 return d.replace(month=7, day=1) 1270 else: 1271 return d.replace(month=10, day=1) 1272 if unit == "month": 1273 return d.replace(month=d.month, day=1) 1274 if unit == "week": 1275 # Assuming week starts on Monday (0) and ends on Sunday (6) 1276 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1277 if unit == "day": 1278 return d 1279 1280 raise UnsupportedUnit(f"Unsupported unit: {unit}")
1322def gen(expression: t.Any, comments: bool = False) -> str: 1323 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1324 1325 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1326 generator is expensive so we have a bare minimum sql generator here. 1327 1328 Args: 1329 expression: the expression to convert into a SQL string. 1330 comments: whether to include the expression's comments. 1331 """ 1332 return Gen().gen(expression, comments=comments)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
Arguments:
- expression: the expression to convert into a SQL string.
- comments: whether to include the expression's comments.
1335class Gen: 1336 def __init__(self): 1337 self.stack = [] 1338 self.sqls = [] 1339 1340 def gen(self, expression: exp.Expression, comments: bool = False) -> str: 1341 self.stack = [expression] 1342 self.sqls.clear() 1343 1344 while self.stack: 1345 node = self.stack.pop() 1346 1347 if isinstance(node, exp.Expression): 1348 if comments and node.comments: 1349 self.stack.append(f" /*{','.join(node.comments)}*/") 1350 1351 exp_handler_name = f"{node.key}_sql" 1352 1353 if hasattr(self, exp_handler_name): 1354 getattr(self, exp_handler_name)(node) 1355 elif isinstance(node, exp.Func): 1356 self._function(node) 1357 else: 1358 key = node.key.upper() 1359 self.stack.append(f"{key} " if self._args(node) else key) 1360 elif type(node) is list: 1361 for n in reversed(node): 1362 if n is not None: 1363 self.stack.extend((n, ",")) 1364 if node: 1365 self.stack.pop() 1366 else: 1367 if node is not None: 1368 self.sqls.append(str(node)) 1369 1370 return "".join(self.sqls) 1371 1372 def add_sql(self, e: exp.Add) -> None: 1373 self._binary(e, " + ") 1374 1375 def alias_sql(self, e: exp.Alias) -> None: 1376 self.stack.extend( 1377 ( 1378 e.args.get("alias"), 1379 " AS ", 1380 e.args.get("this"), 1381 ) 1382 ) 1383 1384 def and_sql(self, e: exp.And) -> None: 1385 self._binary(e, " AND ") 1386 1387 def anonymous_sql(self, e: exp.Anonymous) -> None: 1388 this = e.this 1389 if isinstance(this, str): 1390 name = this.upper() 1391 elif isinstance(this, exp.Identifier): 1392 name = this.this 1393 name = f'"{name}"' if this.quoted else name.upper() 1394 else: 1395 raise ValueError( 1396 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1397 ) 1398 1399 self.stack.extend( 1400 ( 1401 ")", 1402 e.expressions, 1403 "(", 1404 name, 1405 ) 1406 ) 1407 1408 def between_sql(self, e: exp.Between) -> None: 1409 self.stack.extend( 1410 ( 1411 e.args.get("high"), 1412 " AND ", 1413 e.args.get("low"), 1414 " BETWEEN ", 1415 e.this, 1416 ) 1417 ) 1418 1419 def boolean_sql(self, e: exp.Boolean) -> None: 1420 self.stack.append("TRUE" if e.this else "FALSE") 1421 1422 def bracket_sql(self, e: exp.Bracket) -> None: 1423 self.stack.extend( 1424 ( 1425 "]", 1426 e.expressions, 1427 "[", 1428 e.this, 1429 ) 1430 ) 1431 1432 def column_sql(self, e: exp.Column) -> None: 1433 for p in reversed(e.parts): 1434 self.stack.extend((p, ".")) 1435 self.stack.pop() 1436 1437 def datatype_sql(self, e: exp.DataType) -> None: 1438 self._args(e, 1) 1439 self.stack.append(f"{e.this.name} ") 1440 1441 def div_sql(self, e: exp.Div) -> None: 1442 self._binary(e, " / ") 1443 1444 def dot_sql(self, e: exp.Dot) -> None: 1445 self._binary(e, ".") 1446 1447 def eq_sql(self, e: exp.EQ) -> None: 1448 self._binary(e, " = ") 1449 1450 def from_sql(self, e: exp.From) -> None: 1451 self.stack.extend((e.this, "FROM ")) 1452 1453 def gt_sql(self, e: exp.GT) -> None: 1454 self._binary(e, " > ") 1455 1456 def gte_sql(self, e: exp.GTE) -> None: 1457 self._binary(e, " >= ") 1458 1459 def identifier_sql(self, e: exp.Identifier) -> None: 1460 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1461 1462 def ilike_sql(self, e: exp.ILike) -> None: 1463 self._binary(e, " ILIKE ") 1464 1465 def in_sql(self, e: exp.In) -> None: 1466 self.stack.append(")") 1467 self._args(e, 1) 1468 self.stack.extend( 1469 ( 1470 "(", 1471 " IN ", 1472 e.this, 1473 ) 1474 ) 1475 1476 def intdiv_sql(self, e: exp.IntDiv) -> None: 1477 self._binary(e, " DIV ") 1478 1479 def is_sql(self, e: exp.Is) -> None: 1480 self._binary(e, " IS ") 1481 1482 def like_sql(self, e: exp.Like) -> None: 1483 self._binary(e, " Like ") 1484 1485 def literal_sql(self, e: exp.Literal) -> None: 1486 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1487 1488 def lt_sql(self, e: exp.LT) -> None: 1489 self._binary(e, " < ") 1490 1491 def lte_sql(self, e: exp.LTE) -> None: 1492 self._binary(e, " <= ") 1493 1494 def mod_sql(self, e: exp.Mod) -> None: 1495 self._binary(e, " % ") 1496 1497 def mul_sql(self, e: exp.Mul) -> None: 1498 self._binary(e, " * ") 1499 1500 def neg_sql(self, e: exp.Neg) -> None: 1501 self._unary(e, "-") 1502 1503 def neq_sql(self, e: exp.NEQ) -> None: 1504 self._binary(e, " <> ") 1505 1506 def not_sql(self, e: exp.Not) -> None: 1507 self._unary(e, "NOT ") 1508 1509 def null_sql(self, e: exp.Null) -> None: 1510 self.stack.append("NULL") 1511 1512 def or_sql(self, e: exp.Or) -> None: 1513 self._binary(e, " OR ") 1514 1515 def paren_sql(self, e: exp.Paren) -> None: 1516 self.stack.extend( 1517 ( 1518 ")", 1519 e.this, 1520 "(", 1521 ) 1522 ) 1523 1524 def sub_sql(self, e: exp.Sub) -> None: 1525 self._binary(e, " - ") 1526 1527 def subquery_sql(self, e: exp.Subquery) -> None: 1528 self._args(e, 2) 1529 alias = e.args.get("alias") 1530 if alias: 1531 self.stack.append(alias) 1532 self.stack.extend((")", e.this, "(")) 1533 1534 def table_sql(self, e: exp.Table) -> None: 1535 self._args(e, 4) 1536 alias = e.args.get("alias") 1537 if alias: 1538 self.stack.append(alias) 1539 for p in reversed(e.parts): 1540 self.stack.extend((p, ".")) 1541 self.stack.pop() 1542 1543 def tablealias_sql(self, e: exp.TableAlias) -> None: 1544 columns = e.columns 1545 1546 if columns: 1547 self.stack.extend((")", columns, "(")) 1548 1549 self.stack.extend((e.this, " AS ")) 1550 1551 def var_sql(self, e: exp.Var) -> None: 1552 self.stack.append(e.this) 1553 1554 def _binary(self, e: exp.Binary, op: str) -> None: 1555 self.stack.extend((e.expression, op, e.this)) 1556 1557 def _unary(self, e: exp.Unary, op: str) -> None: 1558 self.stack.extend((e.this, op)) 1559 1560 def _function(self, e: exp.Func) -> None: 1561 self.stack.extend( 1562 ( 1563 ")", 1564 list(e.args.values()), 1565 "(", 1566 e.sql_name(), 1567 ) 1568 ) 1569 1570 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1571 kvs = [] 1572 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1573 1574 for k in arg_types or arg_types: 1575 v = node.args.get(k) 1576 1577 if v is not None: 1578 kvs.append([f":{k}", v]) 1579 if kvs: 1580 self.stack.append(kvs) 1581 return True 1582 return False
1340 def gen(self, expression: exp.Expression, comments: bool = False) -> str: 1341 self.stack = [expression] 1342 self.sqls.clear() 1343 1344 while self.stack: 1345 node = self.stack.pop() 1346 1347 if isinstance(node, exp.Expression): 1348 if comments and node.comments: 1349 self.stack.append(f" /*{','.join(node.comments)}*/") 1350 1351 exp_handler_name = f"{node.key}_sql" 1352 1353 if hasattr(self, exp_handler_name): 1354 getattr(self, exp_handler_name)(node) 1355 elif isinstance(node, exp.Func): 1356 self._function(node) 1357 else: 1358 key = node.key.upper() 1359 self.stack.append(f"{key} " if self._args(node) else key) 1360 elif type(node) is list: 1361 for n in reversed(node): 1362 if n is not None: 1363 self.stack.extend((n, ",")) 1364 if node: 1365 self.stack.pop() 1366 else: 1367 if node is not None: 1368 self.sqls.append(str(node)) 1369 1370 return "".join(self.sqls)
1387 def anonymous_sql(self, e: exp.Anonymous) -> None: 1388 this = e.this 1389 if isinstance(this, str): 1390 name = this.upper() 1391 elif isinstance(this, exp.Identifier): 1392 name = this.this 1393 name = f'"{name}"' if this.quoted else name.upper() 1394 else: 1395 raise ValueError( 1396 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1397 ) 1398 1399 self.stack.extend( 1400 ( 1401 ")", 1402 e.expressions, 1403 "(", 1404 name, 1405 ) 1406 )