sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import logging 5import functools 6import itertools 7import typing as t 8from collections import deque, defaultdict 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23logger = logging.getLogger("sqlglot") 24 25# Final means that an expression should not be simplified 26FINAL = "final" 27 28# Value ranges for byte-sized signed/unsigned integers 29TINYINT_MIN = -128 30TINYINT_MAX = 127 31UTINYINT_MIN = 0 32UTINYINT_MAX = 255 33 34 35class UnsupportedUnit(Exception): 36 pass 37 38 39def simplify( 40 expression: exp.Expression, 41 constant_propagation: bool = False, 42 dialect: DialectType = None, 43 max_depth: t.Optional[int] = None, 44): 45 """ 46 Rewrite sqlglot AST to simplify expressions. 
47 48 Example: 49 >>> import sqlglot 50 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 51 >>> simplify(expression).sql() 52 'TRUE' 53 54 Args: 55 expression: expression to simplify 56 constant_propagation: whether the constant propagation rule should be used 57 max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped 58 Returns: 59 sqlglot.Expression: simplified expression 60 """ 61 62 dialect = Dialect.get_or_raise(dialect) 63 64 def _simplify(expression, root=True): 65 if ( 66 max_depth 67 and isinstance(expression, exp.Connector) 68 and not isinstance(expression.parent, exp.Connector) 69 ): 70 depth = connector_depth(expression) 71 if depth > max_depth: 72 logger.info( 73 f"Skipping simplification because connector depth {depth} exceeds max {max_depth}" 74 ) 75 return expression 76 77 if expression.meta.get(FINAL): 78 return expression 79 80 # group by expressions cannot be simplified, for example 81 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 82 # the projection must exactly match the group by key 83 group = expression.args.get("group") 84 85 if group and hasattr(expression, "selects"): 86 groups = set(group.expressions) 87 group.meta[FINAL] = True 88 89 for e in expression.selects: 90 for node in e.walk(): 91 if node in groups: 92 e.meta[FINAL] = True 93 break 94 95 having = expression.args.get("having") 96 if having: 97 for node in having.walk(): 98 if node in groups: 99 having.meta[FINAL] = True 100 break 101 102 # Pre-order transformations 103 node = expression 104 node = rewrite_between(node) 105 node = uniq_sort(node, root) 106 node = absorb_and_eliminate(node, root) 107 node = simplify_concat(node) 108 node = simplify_conditionals(node) 109 110 if constant_propagation: 111 node = propagate_constants(node, root) 112 113 exp.replace_children(node, lambda e: _simplify(e, False)) 114 115 # Post-order transformations 116 node = simplify_not(node) 117 node = flatten(node) 118 node = simplify_connectors(node, root) 119 node = 
remove_complements(node, root) 120 node = simplify_coalesce(node) 121 node.parent = expression.parent 122 node = simplify_literals(node, root) 123 node = simplify_equality(node) 124 node = simplify_parens(node) 125 node = simplify_datetrunc(node, dialect) 126 node = sort_comparison(node) 127 node = simplify_startswith(node) 128 129 if root: 130 expression.replace(node) 131 return node 132 133 expression = while_changing(expression, _simplify) 134 remove_where_true(expression) 135 return expression 136 137 138def connector_depth(expression: exp.Expression) -> int: 139 """ 140 Determine the maximum depth of a tree of Connectors. 141 142 For example: 143 >>> from sqlglot import parse_one 144 >>> connector_depth(parse_one("a AND b AND c AND d")) 145 3 146 """ 147 stack = deque([(expression, 0)]) 148 max_depth = 0 149 150 while stack: 151 expression, depth = stack.pop() 152 153 if not isinstance(expression, exp.Connector): 154 continue 155 156 depth += 1 157 max_depth = max(depth, max_depth) 158 159 stack.append((expression.left, depth)) 160 stack.append((expression.right, depth)) 161 162 return max_depth 163 164 165def catch(*exceptions): 166 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 167 168 def decorator(func): 169 def wrapped(expression, *args, **kwargs): 170 try: 171 return func(expression, *args, **kwargs) 172 except exceptions: 173 return expression 174 175 return wrapped 176 177 return decorator 178 179 180def rewrite_between(expression: exp.Expression) -> exp.Expression: 181 """Rewrite x between y and z to x >= y AND x <= z. 182 183 This is done because comparison simplification is only done on lt/lte/gt/gte. 
184 """ 185 if isinstance(expression, exp.Between): 186 negate = isinstance(expression.parent, exp.Not) 187 188 expression = exp.and_( 189 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 190 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 191 copy=False, 192 ) 193 194 if negate: 195 expression = exp.paren(expression, copy=False) 196 197 return expression 198 199 200COMPLEMENT_COMPARISONS = { 201 exp.LT: exp.GTE, 202 exp.GT: exp.LTE, 203 exp.LTE: exp.GT, 204 exp.GTE: exp.LT, 205 exp.EQ: exp.NEQ, 206 exp.NEQ: exp.EQ, 207} 208 209 210def simplify_not(expression): 211 """ 212 Demorgan's Law 213 NOT (x OR y) -> NOT x AND NOT y 214 NOT (x AND y) -> NOT x OR NOT y 215 """ 216 if isinstance(expression, exp.Not): 217 this = expression.this 218 if is_null(this): 219 return exp.null() 220 if this.__class__ in COMPLEMENT_COMPARISONS: 221 return COMPLEMENT_COMPARISONS[this.__class__]( 222 this=this.this, expression=this.expression 223 ) 224 if isinstance(this, exp.Paren): 225 condition = this.unnest() 226 if isinstance(condition, exp.And): 227 return exp.paren( 228 exp.or_( 229 exp.not_(condition.left, copy=False), 230 exp.not_(condition.right, copy=False), 231 copy=False, 232 ) 233 ) 234 if isinstance(condition, exp.Or): 235 return exp.paren( 236 exp.and_( 237 exp.not_(condition.left, copy=False), 238 exp.not_(condition.right, copy=False), 239 copy=False, 240 ) 241 ) 242 if is_null(condition): 243 return exp.null() 244 if always_true(this): 245 return exp.false() 246 if is_false(this): 247 return exp.true() 248 if isinstance(this, exp.Not): 249 # double negation 250 # NOT NOT x -> x 251 return this.this 252 return expression 253 254 255def flatten(expression): 256 """ 257 A AND (B AND C) -> A AND B AND C 258 A OR (B OR C) -> A OR B OR C 259 """ 260 if isinstance(expression, exp.Connector): 261 for node in expression.args.values(): 262 child = node.unnest() 263 if isinstance(child, expression.__class__): 264 node.replace(child) 
265 return expression 266 267 268def simplify_connectors(expression, root=True): 269 def _simplify_connectors(expression, left, right): 270 if left == right: 271 if isinstance(expression, exp.Xor): 272 return exp.false() 273 return left 274 if isinstance(expression, exp.And): 275 if is_false(left) or is_false(right): 276 return exp.false() 277 if is_null(left) or is_null(right): 278 return exp.null() 279 if always_true(left) and always_true(right): 280 return exp.true() 281 if always_true(left): 282 return right 283 if always_true(right): 284 return left 285 return _simplify_comparison(expression, left, right) 286 elif isinstance(expression, exp.Or): 287 if always_true(left) or always_true(right): 288 return exp.true() 289 if is_false(left) and is_false(right): 290 return exp.false() 291 if ( 292 (is_null(left) and is_null(right)) 293 or (is_null(left) and is_false(right)) 294 or (is_false(left) and is_null(right)) 295 ): 296 return exp.null() 297 if is_false(left): 298 return right 299 if is_false(right): 300 return left 301 return _simplify_comparison(expression, left, right, or_=True) 302 303 if isinstance(expression, exp.Connector): 304 return _flat_simplify(expression, _simplify_connectors, root) 305 return expression 306 307 308LT_LTE = (exp.LT, exp.LTE) 309GT_GTE = (exp.GT, exp.GTE) 310 311COMPARISONS = ( 312 *LT_LTE, 313 *GT_GTE, 314 exp.EQ, 315 exp.NEQ, 316 exp.Is, 317) 318 319INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 320 exp.LT: exp.GT, 321 exp.GT: exp.LT, 322 exp.LTE: exp.GTE, 323 exp.GTE: exp.LTE, 324} 325 326NONDETERMINISTIC = (exp.Rand, exp.Randn) 327AND_OR = (exp.And, exp.Or) 328 329 330def _simplify_comparison(expression, left, right, or_=False): 331 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 332 ll, lr = left.args.values() 333 rl, rr = right.args.values() 334 335 largs = {ll, lr} 336 rargs = {rl, rr} 337 338 matching = largs & rargs 339 columns = {m for m in matching if not 
_is_constant(m) and not m.find(*NONDETERMINISTIC)} 340 341 if matching and columns: 342 try: 343 l = first(largs - columns) 344 r = first(rargs - columns) 345 except StopIteration: 346 return expression 347 348 if l.is_number and r.is_number: 349 l = l.to_py() 350 r = r.to_py() 351 elif l.is_string and r.is_string: 352 l = l.name 353 r = r.name 354 else: 355 l = extract_date(l) 356 if not l: 357 return None 358 r = extract_date(r) 359 if not r: 360 return None 361 # python won't compare date and datetime, but many engines will upcast 362 l, r = cast_as_datetime(l), cast_as_datetime(r) 363 364 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 365 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 366 return left if (av > bv if or_ else av <= bv) else right 367 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 368 return left if (av < bv if or_ else av >= bv) else right 369 370 # we can't ever shortcut to true because the column could be null 371 if not or_: 372 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 373 if av <= bv: 374 return exp.false() 375 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 376 if av >= bv: 377 return exp.false() 378 elif isinstance(a, exp.EQ): 379 if isinstance(b, exp.LT): 380 return exp.false() if av >= bv else a 381 if isinstance(b, exp.LTE): 382 return exp.false() if av > bv else a 383 if isinstance(b, exp.GT): 384 return exp.false() if av <= bv else a 385 if isinstance(b, exp.GTE): 386 return exp.false() if av < bv else a 387 if isinstance(b, exp.NEQ): 388 return exp.false() if av == bv else a 389 return None 390 391 392def remove_complements(expression, root=True): 393 """ 394 Removing complements. 
395 396 A AND NOT A -> FALSE 397 A OR NOT A -> TRUE 398 """ 399 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 400 ops = set(expression.flatten()) 401 for op in ops: 402 if isinstance(op, exp.Not) and op.this in ops: 403 return exp.false() if isinstance(expression, exp.And) else exp.true() 404 405 return expression 406 407 408def uniq_sort(expression, root=True): 409 """ 410 Uniq and sort a connector. 411 412 C AND A AND B AND B -> A AND B AND C 413 """ 414 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 415 flattened = tuple(expression.flatten()) 416 417 if isinstance(expression, exp.Xor): 418 result_func = exp.xor 419 # Do not deduplicate XOR as A XOR A != A if A == True 420 deduped = None 421 arr = tuple((gen(e), e) for e in flattened) 422 else: 423 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 424 deduped = {gen(e): e for e in flattened} 425 arr = tuple(deduped.items()) 426 427 # check if the operands are already sorted, if not sort them 428 # A AND C AND B -> A AND B AND C 429 for i, (sql, e) in enumerate(arr[1:]): 430 if sql < arr[i][0]: 431 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 432 break 433 else: 434 # we didn't have to sort but maybe we need to dedup 435 if deduped and len(deduped) < len(flattened): 436 expression = result_func(*deduped.values(), copy=False) 437 438 return expression 439 440 441def absorb_and_eliminate(expression, root=True): 442 """ 443 absorption: 444 A AND (A OR B) -> A 445 A OR (A AND B) -> A 446 A AND (NOT A OR B) -> A AND B 447 A OR (NOT A AND B) -> A OR B 448 elimination: 449 (A AND B) OR (A AND NOT B) -> A 450 (A OR B) AND (A OR NOT B) -> A 451 """ 452 if isinstance(expression, AND_OR) and (root or not expression.same_parent): 453 kind = exp.Or if isinstance(expression, exp.And) else exp.And 454 455 ops = tuple(expression.flatten()) 456 457 # Initialize lookup tables: 458 # Set of all operands, used to find 
complements for absorption. 459 op_set = set() 460 # Sub-operands, used to find subsets for absorption. 461 subops = defaultdict(list) 462 # Pairs of complements, used for elimination. 463 pairs = defaultdict(list) 464 465 # Populate the lookup tables 466 for op in ops: 467 op_set.add(op) 468 469 if not isinstance(op, kind): 470 # In cases like: A OR (A AND B) 471 # Subop will be: ^ 472 subops[op].append({op}) 473 continue 474 475 # In cases like: (A AND B) OR (A AND B AND C) 476 # Subops will be: ^ ^ 477 subset = set(op.flatten()) 478 for i in subset: 479 subops[i].append(subset) 480 481 a, b = op.unnest_operands() 482 if isinstance(a, exp.Not): 483 pairs[frozenset((a.this, b))].append((op, b)) 484 if isinstance(b, exp.Not): 485 pairs[frozenset((a, b.this))].append((op, a)) 486 487 for op in ops: 488 if not isinstance(op, kind): 489 continue 490 491 a, b = op.unnest_operands() 492 493 # Absorb 494 if isinstance(a, exp.Not) and a.this in op_set: 495 a.replace(exp.true() if kind == exp.And else exp.false()) 496 continue 497 if isinstance(b, exp.Not) and b.this in op_set: 498 b.replace(exp.true() if kind == exp.And else exp.false()) 499 continue 500 superset = set(op.flatten()) 501 if any(any(subset < superset for subset in subops[i]) for i in superset): 502 op.replace(exp.false() if kind == exp.And else exp.true()) 503 continue 504 505 # Eliminate 506 for other, complement in pairs[frozenset((a, b))]: 507 op.replace(complement) 508 other.replace(complement) 509 510 return expression 511 512 513def propagate_constants(expression, root=True): 514 """ 515 Propagate constants for conjunctions in DNF: 516 517 SELECT * FROM t WHERE a = b AND b = 5 becomes 518 SELECT * FROM t WHERE a = 5 AND b = 5 519 520 Reference: https://www.sqlite.org/optoverview.html 521 """ 522 523 if ( 524 isinstance(expression, exp.And) 525 and (root or not expression.same_parent) 526 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 527 ): 528 constant_mapping = {} 529 for expr in 
walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 530 if isinstance(expr, exp.EQ): 531 l, r = expr.left, expr.right 532 533 # TODO: create a helper that can be used to detect nested literal expressions such 534 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 535 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 536 constant_mapping[l] = (id(l), r) 537 538 if constant_mapping: 539 for column in find_all_in_scope(expression, exp.Column): 540 parent = column.parent 541 column_id, constant = constant_mapping.get(column) or (None, None) 542 if ( 543 column_id is not None 544 and id(column) != column_id 545 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 546 ): 547 column.replace(constant.copy()) 548 549 return expression 550 551 552INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 553 exp.DateAdd: exp.Sub, 554 exp.DateSub: exp.Add, 555 exp.DatetimeAdd: exp.Sub, 556 exp.DatetimeSub: exp.Add, 557} 558 559INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 560 **INVERSE_DATE_OPS, 561 exp.Add: exp.Sub, 562 exp.Sub: exp.Add, 563} 564 565 566def _is_number(expression: exp.Expression) -> bool: 567 return expression.is_number 568 569 570def _is_interval(expression: exp.Expression) -> bool: 571 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 572 573 574@catch(ModuleNotFoundError, UnsupportedUnit) 575def simplify_equality(expression: exp.Expression) -> exp.Expression: 576 """ 577 Use the subtraction and addition properties of equality to simplify expressions: 578 579 x + 1 = 3 becomes x = 2 580 581 There are two binary operations in the above expression: + and = 582 Here's how we reference all the operands in the code below: 583 584 l r 585 x + 1 = 3 586 a b 587 """ 588 if isinstance(expression, COMPARISONS): 589 l, r = expression.left, expression.right 590 591 if l.__class__ not in INVERSE_OPS: 
592 return expression 593 594 if r.is_number: 595 a_predicate = _is_number 596 b_predicate = _is_number 597 elif _is_date_literal(r): 598 a_predicate = _is_date_literal 599 b_predicate = _is_interval 600 else: 601 return expression 602 603 if l.__class__ in INVERSE_DATE_OPS: 604 l = t.cast(exp.IntervalOp, l) 605 a = l.this 606 b = l.interval() 607 else: 608 l = t.cast(exp.Binary, l) 609 a, b = l.left, l.right 610 611 if not a_predicate(a) and b_predicate(b): 612 pass 613 elif not a_predicate(b) and b_predicate(a): 614 a, b = b, a 615 else: 616 return expression 617 618 return expression.__class__( 619 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 620 ) 621 return expression 622 623 624def simplify_literals(expression, root=True): 625 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 626 return _flat_simplify(expression, _simplify_binary, root) 627 628 if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): 629 return expression.this.this 630 631 if type(expression) in INVERSE_DATE_OPS: 632 return _simplify_binary(expression, expression.this, expression.interval()) or expression 633 634 return expression 635 636 637NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 638 639 640def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 641 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 642 this = _simplify_integer_cast(expr.this) 643 else: 644 this = expr.this 645 646 if isinstance(expr, exp.Cast) and this.is_int: 647 num = this.to_py() 648 649 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 650 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 651 # engine-dependent 652 if ( 653 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 654 ) or ( 655 UTINYINT_MIN <= num <= UTINYINT_MAX 656 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 657 ): 658 return this 659 660 return expr 661 662 663def _simplify_binary(expression, a, b): 664 if isinstance(expression, COMPARISONS): 665 a = _simplify_integer_cast(a) 666 b = _simplify_integer_cast(b) 667 668 if isinstance(expression, exp.Is): 669 if isinstance(b, exp.Not): 670 c = b.this 671 not_ = True 672 else: 673 c = b 674 not_ = False 675 676 if is_null(c): 677 if isinstance(a, exp.Literal): 678 return exp.true() if not_ else exp.false() 679 if is_null(a): 680 return exp.false() if not_ else exp.true() 681 elif isinstance(expression, NULL_OK): 682 return None 683 elif is_null(a) or is_null(b): 684 return exp.null() 685 686 if a.is_number and b.is_number: 687 num_a = a.to_py() 688 num_b = b.to_py() 689 690 if isinstance(expression, exp.Add): 691 return exp.Literal.number(num_a + num_b) 692 if isinstance(expression, exp.Mul): 693 return exp.Literal.number(num_a * num_b) 694 695 # We only simplify Sub, Div if a and b have the same parent because they're not associative 696 if isinstance(expression, exp.Sub): 697 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 698 if isinstance(expression, exp.Div): 699 # engines have differing int div behavior so intdiv is not safe 700 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 701 return None 702 return exp.Literal.number(num_a / num_b) 703 704 boolean = eval_boolean(expression, num_a, num_b) 705 706 if boolean: 707 return boolean 708 elif a.is_string and b.is_string: 709 boolean = eval_boolean(expression, a.this, b.this) 710 711 if boolean: 712 return boolean 713 elif _is_date_literal(a) and isinstance(b, 
exp.Interval): 714 date, b = extract_date(a), extract_interval(b) 715 if date and b: 716 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 717 return date_literal(date + b, extract_type(a)) 718 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 719 return date_literal(date - b, extract_type(a)) 720 elif isinstance(a, exp.Interval) and _is_date_literal(b): 721 a, date = extract_interval(a), extract_date(b) 722 # you cannot subtract a date from an interval 723 if a and b and isinstance(expression, exp.Add): 724 return date_literal(a + date, extract_type(b)) 725 elif _is_date_literal(a) and _is_date_literal(b): 726 if isinstance(expression, exp.Predicate): 727 a, b = extract_date(a), extract_date(b) 728 boolean = eval_boolean(expression, a, b) 729 if boolean: 730 return boolean 731 732 return None 733 734 735def simplify_parens(expression): 736 if not isinstance(expression, exp.Paren): 737 return expression 738 739 this = expression.this 740 parent = expression.parent 741 parent_is_predicate = isinstance(parent, exp.Predicate) 742 743 if ( 744 not isinstance(this, exp.Select) 745 and not isinstance(parent, exp.SubqueryPredicate) 746 and ( 747 not isinstance(parent, (exp.Condition, exp.Binary)) 748 or isinstance(parent, exp.Paren) 749 or ( 750 not isinstance(this, exp.Binary) 751 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 752 ) 753 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 754 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 755 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 756 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 757 ) 758 ): 759 return this 760 return expression 761 762 763def _is_nonnull_constant(expression: exp.Expression) -> bool: 764 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 765 766 767def _is_constant(expression: exp.Expression) -> bool: 768 return 
isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 769 770 771def simplify_coalesce(expression): 772 # COALESCE(x) -> x 773 if ( 774 isinstance(expression, exp.Coalesce) 775 and (not expression.expressions or _is_nonnull_constant(expression.this)) 776 # COALESCE is also used as a Spark partitioning hint 777 and not isinstance(expression.parent, exp.Hint) 778 ): 779 return expression.this 780 781 if not isinstance(expression, COMPARISONS): 782 return expression 783 784 if isinstance(expression.left, exp.Coalesce): 785 coalesce = expression.left 786 other = expression.right 787 elif isinstance(expression.right, exp.Coalesce): 788 coalesce = expression.right 789 other = expression.left 790 else: 791 return expression 792 793 # This transformation is valid for non-constants, 794 # but it really only does anything if they are both constants. 795 if not _is_constant(other): 796 return expression 797 798 # Find the first constant arg 799 for arg_index, arg in enumerate(coalesce.expressions): 800 if _is_constant(arg): 801 break 802 else: 803 return expression 804 805 coalesce.set("expressions", coalesce.expressions[:arg_index]) 806 807 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 808 # since we already remove COALESCE at the top of this function. 
809 coalesce = coalesce if coalesce.expressions else coalesce.this 810 811 # This expression is more complex than when we started, but it will get simplified further 812 return exp.paren( 813 exp.or_( 814 exp.and_( 815 coalesce.is_(exp.null()).not_(copy=False), 816 expression.copy(), 817 copy=False, 818 ), 819 exp.and_( 820 coalesce.is_(exp.null()), 821 type(expression)(this=arg.copy(), expression=other.copy()), 822 copy=False, 823 ), 824 copy=False, 825 ) 826 ) 827 828 829CONCATS = (exp.Concat, exp.DPipe) 830 831 832def simplify_concat(expression): 833 """Reduces all groups that contain string literals by concatenating them.""" 834 if not isinstance(expression, CONCATS) or ( 835 # We can't reduce a CONCAT_WS call if we don't statically know the separator 836 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 837 ): 838 return expression 839 840 if isinstance(expression, exp.ConcatWs): 841 sep_expr, *expressions = expression.expressions 842 sep = sep_expr.name 843 concat_type = exp.ConcatWs 844 args = {} 845 else: 846 expressions = expression.expressions 847 sep = "" 848 concat_type = exp.Concat 849 args = { 850 "safe": expression.args.get("safe"), 851 "coalesce": expression.args.get("coalesce"), 852 } 853 854 new_args = [] 855 for is_string_group, group in itertools.groupby( 856 expressions or expression.flatten(), lambda e: e.is_string 857 ): 858 if is_string_group: 859 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 860 else: 861 new_args.extend(group) 862 863 if len(new_args) == 1 and new_args[0].is_string: 864 return new_args[0] 865 866 if concat_type is exp.ConcatWs: 867 new_args = [sep_expr] + new_args 868 elif isinstance(expression, exp.DPipe): 869 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 870 871 return concat_type(expressions=new_args, **args) 872 873 874def simplify_conditionals(expression): 875 """Simplifies expressions like IF, CASE if their condition is statically 
known.""" 876 if isinstance(expression, exp.Case): 877 this = expression.this 878 for case in expression.args["ifs"]: 879 cond = case.this 880 if this: 881 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 882 cond = cond.replace(this.pop().eq(cond)) 883 884 if always_true(cond): 885 return case.args["true"] 886 887 if always_false(cond): 888 case.pop() 889 if not expression.args["ifs"]: 890 return expression.args.get("default") or exp.null() 891 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 892 if always_true(expression.this): 893 return expression.args["true"] 894 if always_false(expression.this): 895 return expression.args.get("false") or exp.null() 896 897 return expression 898 899 900def simplify_startswith(expression: exp.Expression) -> exp.Expression: 901 """ 902 Reduces a prefix check to either TRUE or FALSE if both the string and the 903 prefix are statically known. 904 905 Example: 906 >>> from sqlglot import parse_one 907 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 908 'TRUE' 909 """ 910 if ( 911 isinstance(expression, exp.StartsWith) 912 and expression.this.is_string 913 and expression.expression.is_string 914 ): 915 return exp.convert(expression.name.startswith(expression.expression.name)) 916 917 return expression 918 919 920DateRange = t.Tuple[datetime.date, datetime.date] 921 922 923def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 924 """ 925 Get the date range for a DATE_TRUNC equality comparison: 926 927 Example: 928 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 929 Returns: 930 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 931 """ 932 floor = date_floor(date, unit, dialect) 933 934 if date != floor: 935 # This will always be False, except for NULL values. 
936 return None 937 938 return floor, floor + interval(unit) 939 940 941def _datetrunc_eq_expression( 942 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 943) -> exp.Expression: 944 """Get the logical expression for a date range""" 945 return exp.and_( 946 left >= date_literal(drange[0], target_type), 947 left < date_literal(drange[1], target_type), 948 copy=False, 949 ) 950 951 952def _datetrunc_eq( 953 left: exp.Expression, 954 date: datetime.date, 955 unit: str, 956 dialect: Dialect, 957 target_type: t.Optional[exp.DataType], 958) -> t.Optional[exp.Expression]: 959 drange = _datetrunc_range(date, unit, dialect) 960 if not drange: 961 return None 962 963 return _datetrunc_eq_expression(left, drange, target_type) 964 965 966def _datetrunc_neq( 967 left: exp.Expression, 968 date: datetime.date, 969 unit: str, 970 dialect: Dialect, 971 target_type: t.Optional[exp.DataType], 972) -> t.Optional[exp.Expression]: 973 drange = _datetrunc_range(date, unit, dialect) 974 if not drange: 975 return None 976 977 return exp.and_( 978 left < date_literal(drange[0], target_type), 979 left >= date_literal(drange[1], target_type), 980 copy=False, 981 ) 982 983 984DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 985 exp.LT: lambda l, dt, u, d, t: l 986 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 987 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 988 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 989 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 990 exp.EQ: _datetrunc_eq, 991 exp.NEQ: _datetrunc_neq, 992} 993DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 994DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 995 996 997def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 998 return isinstance(left, DATETRUNCS) 
and _is_date_literal(right) 999 1000 1001@catch(ModuleNotFoundError, UnsupportedUnit) 1002def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 1003 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 1004 comparison = expression.__class__ 1005 1006 if isinstance(expression, DATETRUNCS): 1007 this = expression.this 1008 trunc_type = extract_type(this) 1009 date = extract_date(this) 1010 if date and expression.unit: 1011 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 1012 elif comparison not in DATETRUNC_COMPARISONS: 1013 return expression 1014 1015 if isinstance(expression, exp.Binary): 1016 l, r = expression.left, expression.right 1017 1018 if not _is_datetrunc_predicate(l, r): 1019 return expression 1020 1021 l = t.cast(exp.DateTrunc, l) 1022 trunc_arg = l.this 1023 unit = l.unit.name.lower() 1024 date = extract_date(r) 1025 1026 if not date: 1027 return expression 1028 1029 return ( 1030 DATETRUNC_BINARY_COMPARISONS[comparison]( 1031 trunc_arg, date, unit, dialect, extract_type(r) 1032 ) 1033 or expression 1034 ) 1035 1036 if isinstance(expression, exp.In): 1037 l = expression.this 1038 rs = expression.expressions 1039 1040 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 1041 l = t.cast(exp.DateTrunc, l) 1042 unit = l.unit.name.lower() 1043 1044 ranges = [] 1045 for r in rs: 1046 date = extract_date(r) 1047 if not date: 1048 return expression 1049 drange = _datetrunc_range(date, unit, dialect) 1050 if drange: 1051 ranges.append(drange) 1052 1053 if not ranges: 1054 return expression 1055 1056 ranges = merge_ranges(ranges) 1057 target_type = extract_type(*rs) 1058 1059 return exp.or_( 1060 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 1061 ) 1062 1063 return expression 1064 1065 1066def sort_comparison(expression: exp.Expression) -> exp.Expression: 1067 if expression.__class__ in COMPLEMENT_COMPARISONS: 
1068 l, r = expression.this, expression.expression 1069 l_column = isinstance(l, exp.Column) 1070 r_column = isinstance(r, exp.Column) 1071 l_const = _is_constant(l) 1072 r_const = _is_constant(r) 1073 1074 if (l_column and not r_column) or (r_const and not l_const): 1075 return expression 1076 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1077 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1078 this=r, expression=l 1079 ) 1080 return expression 1081 1082 1083# CROSS joins result in an empty table if the right table is empty. 1084# So we can only simplify certain types of joins to CROSS. 1085# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1086JOINS = { 1087 ("", ""), 1088 ("", "INNER"), 1089 ("RIGHT", ""), 1090 ("RIGHT", "OUTER"), 1091} 1092 1093 1094def remove_where_true(expression): 1095 for where in expression.find_all(exp.Where): 1096 if always_true(where.this): 1097 where.pop() 1098 for join in expression.find_all(exp.Join): 1099 if ( 1100 always_true(join.args.get("on")) 1101 and not join.args.get("using") 1102 and not join.args.get("method") 1103 and (join.side, join.kind) in JOINS 1104 ): 1105 join.args["on"].pop() 1106 join.set("side", None) 1107 join.set("kind", "CROSS") 1108 1109 1110def always_true(expression): 1111 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( 1112 expression, exp.Literal 1113 ) 1114 1115 1116def always_false(expression): 1117 return is_false(expression) or is_null(expression) 1118 1119 1120def is_complement(a, b): 1121 return isinstance(b, exp.Not) and b.this == a 1122 1123 1124def is_false(a: exp.Expression) -> bool: 1125 return type(a) is exp.Boolean and not a.this 1126 1127 1128def is_null(a: exp.Expression) -> bool: 1129 return type(a) is exp.Null 1130 1131 1132def eval_boolean(expression, a, b): 1133 if isinstance(expression, (exp.EQ, exp.Is)): 1134 return boolean_literal(a == b) 1135 if isinstance(expression, exp.NEQ): 
1136 return boolean_literal(a != b) 1137 if isinstance(expression, exp.GT): 1138 return boolean_literal(a > b) 1139 if isinstance(expression, exp.GTE): 1140 return boolean_literal(a >= b) 1141 if isinstance(expression, exp.LT): 1142 return boolean_literal(a < b) 1143 if isinstance(expression, exp.LTE): 1144 return boolean_literal(a <= b) 1145 return None 1146 1147 1148def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1149 if isinstance(value, datetime.datetime): 1150 return value.date() 1151 if isinstance(value, datetime.date): 1152 return value 1153 try: 1154 return datetime.datetime.fromisoformat(value).date() 1155 except ValueError: 1156 return None 1157 1158 1159def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1160 if isinstance(value, datetime.datetime): 1161 return value 1162 if isinstance(value, datetime.date): 1163 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1164 try: 1165 return datetime.datetime.fromisoformat(value) 1166 except ValueError: 1167 return None 1168 1169 1170def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1171 if not value: 1172 return None 1173 if to.is_type(exp.DataType.Type.DATE): 1174 return cast_as_date(value) 1175 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1176 return cast_as_datetime(value) 1177 return None 1178 1179 1180def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1181 if isinstance(cast, exp.Cast): 1182 to = cast.to 1183 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1184 to = exp.DataType.build(exp.DataType.Type.DATE) 1185 else: 1186 return None 1187 1188 if isinstance(cast.this, exp.Literal): 1189 value: t.Any = cast.this.name 1190 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1191 value = extract_date(cast.this) 1192 else: 1193 return None 1194 return cast_value(value, to) 1195 1196 1197def _is_date_literal(expression: 
exp.Expression) -> bool: 1198 return extract_date(expression) is not None 1199 1200 1201def extract_interval(expression): 1202 try: 1203 n = int(expression.this.to_py()) 1204 unit = expression.text("unit").lower() 1205 return interval(unit, n) 1206 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1207 return None 1208 1209 1210def extract_type(*expressions): 1211 target_type = None 1212 for expression in expressions: 1213 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1214 if target_type: 1215 break 1216 1217 return target_type 1218 1219 1220def date_literal(date, target_type=None): 1221 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1222 target_type = ( 1223 exp.DataType.Type.DATETIME 1224 if isinstance(date, datetime.datetime) 1225 else exp.DataType.Type.DATE 1226 ) 1227 1228 return exp.cast(exp.Literal.string(date), target_type) 1229 1230 1231def interval(unit: str, n: int = 1): 1232 from dateutil.relativedelta import relativedelta 1233 1234 if unit == "year": 1235 return relativedelta(years=1 * n) 1236 if unit == "quarter": 1237 return relativedelta(months=3 * n) 1238 if unit == "month": 1239 return relativedelta(months=1 * n) 1240 if unit == "week": 1241 return relativedelta(weeks=1 * n) 1242 if unit == "day": 1243 return relativedelta(days=1 * n) 1244 if unit == "hour": 1245 return relativedelta(hours=1 * n) 1246 if unit == "minute": 1247 return relativedelta(minutes=1 * n) 1248 if unit == "second": 1249 return relativedelta(seconds=1 * n) 1250 1251 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1252 1253 1254def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1255 if unit == "year": 1256 return d.replace(month=1, day=1) 1257 if unit == "quarter": 1258 if d.month <= 3: 1259 return d.replace(month=1, day=1) 1260 elif d.month <= 6: 1261 return d.replace(month=4, day=1) 1262 elif d.month <= 9: 1263 return d.replace(month=7, day=1) 1264 else: 1265 return 
d.replace(month=10, day=1) 1266 if unit == "month": 1267 return d.replace(month=d.month, day=1) 1268 if unit == "week": 1269 # Assuming week starts on Monday (0) and ends on Sunday (6) 1270 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1271 if unit == "day": 1272 return d 1273 1274 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1275 1276 1277def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1278 floor = date_floor(d, unit, dialect) 1279 1280 if floor == d: 1281 return d 1282 1283 return floor + interval(unit) 1284 1285 1286def boolean_literal(condition): 1287 return exp.true() if condition else exp.false() 1288 1289 1290def _flat_simplify(expression, simplifier, root=True): 1291 if root or not expression.same_parent: 1292 operands = [] 1293 queue = deque(expression.flatten(unnest=False)) 1294 size = len(queue) 1295 1296 while queue: 1297 a = queue.popleft() 1298 1299 for b in queue: 1300 result = simplifier(expression, a, b) 1301 1302 if result and result is not expression: 1303 queue.remove(b) 1304 queue.appendleft(result) 1305 break 1306 else: 1307 operands.append(a) 1308 1309 if len(operands) < size: 1310 return functools.reduce( 1311 lambda a, b: expression.__class__(this=a, expression=b), operands 1312 ) 1313 return expression 1314 1315 1316def gen(expression: t.Any) -> str: 1317 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1318 1319 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1320 generator is expensive so we have a bare minimum sql generator here. 
1321 """ 1322 return Gen().gen(expression) 1323 1324 1325class Gen: 1326 def __init__(self): 1327 self.stack = [] 1328 self.sqls = [] 1329 1330 def gen(self, expression: exp.Expression) -> str: 1331 self.stack = [expression] 1332 self.sqls.clear() 1333 1334 while self.stack: 1335 node = self.stack.pop() 1336 1337 if isinstance(node, exp.Expression): 1338 exp_handler_name = f"{node.key}_sql" 1339 1340 if hasattr(self, exp_handler_name): 1341 getattr(self, exp_handler_name)(node) 1342 elif isinstance(node, exp.Func): 1343 self._function(node) 1344 else: 1345 key = node.key.upper() 1346 self.stack.append(f"{key} " if self._args(node) else key) 1347 elif type(node) is list: 1348 for n in reversed(node): 1349 if n is not None: 1350 self.stack.extend((n, ",")) 1351 if node: 1352 self.stack.pop() 1353 else: 1354 if node is not None: 1355 self.sqls.append(str(node)) 1356 1357 return "".join(self.sqls) 1358 1359 def add_sql(self, e: exp.Add) -> None: 1360 self._binary(e, " + ") 1361 1362 def alias_sql(self, e: exp.Alias) -> None: 1363 self.stack.extend( 1364 ( 1365 e.args.get("alias"), 1366 " AS ", 1367 e.args.get("this"), 1368 ) 1369 ) 1370 1371 def and_sql(self, e: exp.And) -> None: 1372 self._binary(e, " AND ") 1373 1374 def anonymous_sql(self, e: exp.Anonymous) -> None: 1375 this = e.this 1376 if isinstance(this, str): 1377 name = this.upper() 1378 elif isinstance(this, exp.Identifier): 1379 name = this.this 1380 name = f'"{name}"' if this.quoted else name.upper() 1381 else: 1382 raise ValueError( 1383 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
1384 ) 1385 1386 self.stack.extend( 1387 ( 1388 ")", 1389 e.expressions, 1390 "(", 1391 name, 1392 ) 1393 ) 1394 1395 def between_sql(self, e: exp.Between) -> None: 1396 self.stack.extend( 1397 ( 1398 e.args.get("high"), 1399 " AND ", 1400 e.args.get("low"), 1401 " BETWEEN ", 1402 e.this, 1403 ) 1404 ) 1405 1406 def boolean_sql(self, e: exp.Boolean) -> None: 1407 self.stack.append("TRUE" if e.this else "FALSE") 1408 1409 def bracket_sql(self, e: exp.Bracket) -> None: 1410 self.stack.extend( 1411 ( 1412 "]", 1413 e.expressions, 1414 "[", 1415 e.this, 1416 ) 1417 ) 1418 1419 def column_sql(self, e: exp.Column) -> None: 1420 for p in reversed(e.parts): 1421 self.stack.extend((p, ".")) 1422 self.stack.pop() 1423 1424 def datatype_sql(self, e: exp.DataType) -> None: 1425 self._args(e, 1) 1426 self.stack.append(f"{e.this.name} ") 1427 1428 def div_sql(self, e: exp.Div) -> None: 1429 self._binary(e, " / ") 1430 1431 def dot_sql(self, e: exp.Dot) -> None: 1432 self._binary(e, ".") 1433 1434 def eq_sql(self, e: exp.EQ) -> None: 1435 self._binary(e, " = ") 1436 1437 def from_sql(self, e: exp.From) -> None: 1438 self.stack.extend((e.this, "FROM ")) 1439 1440 def gt_sql(self, e: exp.GT) -> None: 1441 self._binary(e, " > ") 1442 1443 def gte_sql(self, e: exp.GTE) -> None: 1444 self._binary(e, " >= ") 1445 1446 def identifier_sql(self, e: exp.Identifier) -> None: 1447 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1448 1449 def ilike_sql(self, e: exp.ILike) -> None: 1450 self._binary(e, " ILIKE ") 1451 1452 def in_sql(self, e: exp.In) -> None: 1453 self.stack.append(")") 1454 self._args(e, 1) 1455 self.stack.extend( 1456 ( 1457 "(", 1458 " IN ", 1459 e.this, 1460 ) 1461 ) 1462 1463 def intdiv_sql(self, e: exp.IntDiv) -> None: 1464 self._binary(e, " DIV ") 1465 1466 def is_sql(self, e: exp.Is) -> None: 1467 self._binary(e, " IS ") 1468 1469 def like_sql(self, e: exp.Like) -> None: 1470 self._binary(e, " Like ") 1471 1472 def literal_sql(self, e: exp.Literal) -> None: 
1473 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1474 1475 def lt_sql(self, e: exp.LT) -> None: 1476 self._binary(e, " < ") 1477 1478 def lte_sql(self, e: exp.LTE) -> None: 1479 self._binary(e, " <= ") 1480 1481 def mod_sql(self, e: exp.Mod) -> None: 1482 self._binary(e, " % ") 1483 1484 def mul_sql(self, e: exp.Mul) -> None: 1485 self._binary(e, " * ") 1486 1487 def neg_sql(self, e: exp.Neg) -> None: 1488 self._unary(e, "-") 1489 1490 def neq_sql(self, e: exp.NEQ) -> None: 1491 self._binary(e, " <> ") 1492 1493 def not_sql(self, e: exp.Not) -> None: 1494 self._unary(e, "NOT ") 1495 1496 def null_sql(self, e: exp.Null) -> None: 1497 self.stack.append("NULL") 1498 1499 def or_sql(self, e: exp.Or) -> None: 1500 self._binary(e, " OR ") 1501 1502 def paren_sql(self, e: exp.Paren) -> None: 1503 self.stack.extend( 1504 ( 1505 ")", 1506 e.this, 1507 "(", 1508 ) 1509 ) 1510 1511 def sub_sql(self, e: exp.Sub) -> None: 1512 self._binary(e, " - ") 1513 1514 def subquery_sql(self, e: exp.Subquery) -> None: 1515 self._args(e, 2) 1516 alias = e.args.get("alias") 1517 if alias: 1518 self.stack.append(alias) 1519 self.stack.extend((")", e.this, "(")) 1520 1521 def table_sql(self, e: exp.Table) -> None: 1522 self._args(e, 4) 1523 alias = e.args.get("alias") 1524 if alias: 1525 self.stack.append(alias) 1526 for p in reversed(e.parts): 1527 self.stack.extend((p, ".")) 1528 self.stack.pop() 1529 1530 def tablealias_sql(self, e: exp.TableAlias) -> None: 1531 columns = e.columns 1532 1533 if columns: 1534 self.stack.extend((")", columns, "(")) 1535 1536 self.stack.extend((e.this, " AS ")) 1537 1538 def var_sql(self, e: exp.Var) -> None: 1539 self.stack.append(e.this) 1540 1541 def _binary(self, e: exp.Binary, op: str) -> None: 1542 self.stack.extend((e.expression, op, e.this)) 1543 1544 def _unary(self, e: exp.Unary, op: str) -> None: 1545 self.stack.extend((e.this, op)) 1546 1547 def _function(self, e: exp.Func) -> None: 1548 self.stack.extend( 1549 ( 1550 ")", 1551 
list(e.args.values()), 1552 "(", 1553 e.sql_name(), 1554 ) 1555 ) 1556 1557 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1558 kvs = [] 1559 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1560 1561 for k in arg_types or arg_types: 1562 v = node.args.get(k) 1563 1564 if v is not None: 1565 kvs.append([f":{k}", v]) 1566 if kvs: 1567 self.stack.append(kvs) 1568 return True 1569 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression,
    constant_propagation: bool = False,
    dialect: DialectType = None,
    max_depth: t.Optional[int] = None,
):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied repeatedly (via `while_changing`) and, once that
    loop finishes, always-true WHERE / JOIN conditions are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression: expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed on to `simplify_datetrunc`
        max_depth: Chains of Connectors (AND, OR, etc) exceeding `max_depth` will be skipped
    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Only the top-most connector of a chain is measured; nested connectors
        # are reached through their parent.
        if (
            max_depth
            and isinstance(expression, exp.Connector)
            and not isinstance(expression.parent, exp.Connector)
        ):
            depth = connector_depth(expression)
            if depth > max_depth:
                logger.info(
                    f"Skipping simplification because connector depth {depth} exceeds max {max_depth}"
                )
                return expression

        # Nodes flagged as FINAL are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group keys.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for HAVING: freeze it as a whole if it mentions a group key.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Re-attach the (possibly new) node to the original parent before the
        # remaining rules run.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression: expression to simplify
- constant_propagation: whether the constant propagation rule should be used
- max_depth: Chains of Connectors (AND, OR, etc.) exceeding `max_depth` will be skipped
Returns:
sqlglot.Expression: simplified expression
def connector_depth(expression: exp.Expression) -> int:
    """
    Return how many Connector nodes are nested on the deepest path starting at `expression`.

    For example:
        >>> from sqlglot import parse_one
        >>> connector_depth(parse_one("a AND b AND c AND d"))
        3
    """
    deepest = 0
    # Explicit LIFO work list of (node, number of connectors above it).
    pending = [(expression, 0)]

    while pending:
        node, level = pending.pop()

        if isinstance(node, exp.Connector):
            level += 1
            if level > deepest:
                deepest = level
            pending.extend(((node.left, level), (node.right, level)))

    return deepest
Determine the maximum depth of a tree of Connectors.
For example:
>>> from sqlglot import parse_one >>> connector_depth(parse_one("a AND b AND c AND d")) 3
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of `exceptions`, the first positional
    argument (the expression being simplified) is returned unchanged. Any other
    exception propagates to the caller.

    Args:
        *exceptions: exception classes that should be swallowed.

    Returns:
        A decorator that wraps a function whose first argument is the expression.
    """

    def decorator(func):
        # Preserve the wrapped function's name and docstring so that
        # introspection and generated documentation describe `func` rather
        # than this generic wrapper.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only handles lt/lte/gt/gte, so BETWEEN is
    rewritten into that form. Anything that is not a BETWEEN is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Checked on the original node, before the replacement is built.
    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    if under_not:
        # Parenthesize so that the surrounding NOT applies to the whole conjunction.
        return exp.paren(rewritten, copy=False)

    return rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a negation, including De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        # Negated comparison: swap in the complementary comparison class.
        return COMPLEMENT_COMPARISONS[operand_type](
            this=operand.this, expression=operand.expression
        )

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: the connector flips and both sides are negated.
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                flipped(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Pull a parenthesized connector of the same kind up into its parent:
        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold pairs of connector operands whose combined value is statically known.

    Non-connector expressions are returned unchanged.
    """

    def _reduce_pair(connector, left, right):
        if left == right:
            return exp.false() if isinstance(connector, exp.Xor) else left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()

            left_null, right_null = is_null(left), is_null(right)
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()

            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        # Any other connector with two different operands: nothing to fold.
        return None

    if not isinstance(expression, exp.Connector):
        return expression

    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse an AND/OR chain that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, AND_OR):
        return expression

    if not root and expression.same_parent:
        return expression

    operands = set(expression.flatten())
    has_complement = any(
        isinstance(operand, exp.Not) and operand.this in operands for operand in operands
    )

    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered by the string produced by `gen`. AND / OR operands
    with the same string are deduplicated; XOR operands are only sorted.

    Args:
        expression: the expression to normalize; anything that is not a
            Connector is returned unchanged.
        root: whether `expression` is the root of the current simplification.
            Non-root connectors are only processed when `same_parent` is false.

    Returns:
        The rebuilt connector when its operands had to be reordered or
        deduplicated, otherwise the input expression.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, _) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                # Sort on the SQL string only. Sorting the (sql, expression)
                # tuples directly would compare the expression objects whenever
                # two strings are equal (possible for XOR, which keeps
                # duplicates), and expressions are not meant to be ordered.
                ordered = sorted(arr, key=lambda pair: pair[0])
                expression = result_func(*(e for _, e in ordered), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to an AND / OR chain.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced in place (with TRUE / FALSE or with the surviving
    operand) and the input expression itself is returned.
    """
    if isinstance(expression, AND_OR) and (root or not expression.same_parent):
        # The connector type opposite to `expression`: operands of this type
        # are the nested groups that can be absorbed or eliminated.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        ops = tuple(expression.flatten())

        # Initialize lookup tables:
        # Set of all operands, used to find complements for absorption.
        op_set = set()
        # Sub-operands, used to find subsets for absorption.
        subops = defaultdict(list)
        # Pairs of complements, used for elimination.
        pairs = defaultdict(list)

        # Populate the lookup tables
        for op in ops:
            op_set.add(op)

            if not isinstance(op, kind):
                # In cases like: A OR (A AND B)
                # the sub-operand is the leading A itself.
                subops[op].append({op})
                continue

            # In cases like: (A AND B) OR (A AND B AND C)
            # the sub-operands are the members of each nested group.
            subset = set(op.flatten())
            for i in subset:
                subops[i].append(subset)

            # Key each group containing a negation by its un-negated operands
            # and remember the operand that survives elimination.
            a, b = op.unnest_operands()
            if isinstance(a, exp.Not):
                pairs[frozenset((a.this, b))].append((op, b))
            if isinstance(b, exp.Not):
                pairs[frozenset((a, b.this))].append((op, a))

        for op in ops:
            if not isinstance(op, kind):
                continue

            a, b = op.unnest_operands()

            # Absorb
            # A negated member whose positive form is a sibling operand is
            # replaced by a boolean constant.
            if isinstance(a, exp.Not) and a.this in op_set:
                a.replace(exp.true() if kind == exp.And else exp.false())
                continue
            if isinstance(b, exp.Not) and b.this in op_set:
                b.replace(exp.true() if kind == exp.And else exp.false())
                continue
            # A group that strictly contains another operand's member set is
            # replaced by a boolean constant.
            superset = set(op.flatten())
            if any(any(subset < superset for subset in subops[i]) for i in superset):
                op.replace(exp.false() if kind == exp.And else exp.true())
                continue

            # Eliminate
            # Both this group and its complementary group become the shared operand.
            for other, complement in pairs[frozenset((a, b))]:
                op.replace(complement)
                other.replace(complement)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Substitute columns with the literal they are equated to, for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node in the defining equality, literal)
    known = {}
    for candidate in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            known[lhs] = (id(lhs), rhs)

    if not known:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        entry = known.get(column)
        if not entry:
            continue

        defining_id, constant = entry

        # The column inside the defining equality itself is left alone.
        if id(column) == defining_id:
            continue

        # `col IS NULL` checks are left alone as well.
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
170 def wrapped(expression, *args, **kwargs): 171 try: 172 return func(expression, *args, **kwargs) 173 except exceptions: 174 return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
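There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below (l is the left-hand side `x + 1` and r the right-hand side `3` of the `=`; a and b are the operands `x` and `1` of l):
      l     r
    x + 1 = 3
    a   b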
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binaries, double negation and date arithmetic.

    Returns the simplified expression, or the input when no rule applies.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        negated = expression.this
        if isinstance(negated, exp.Neg):
            # - - x -> x
            return negated.this

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded if folded else expression

    return expression
def simplify_parens(expression):
    """Drop a pair of parentheses when one of the conditions below allows it.

    Returns the wrapped expression when the parentheses are removed,
    otherwise the input expression.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Parenthesized SELECTs and operands of subquery predicates always
        # keep their parentheses.
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # The parent is neither a condition nor a binary operation.
            not isinstance(parent, (exp.Condition, exp.Binary))
            # Directly nested parentheses.
            or isinstance(parent, exp.Paren)
            # The content is not a binary operation (NOT / IS inside a
            # predicate are excluded from this case).
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # A predicate whose parent is not itself a predicate.
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # (a + b) + c and (a * b) * c
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            # (a * b) + c and (a * b) - c
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    A COALESCE with no fallback arguments, or whose first argument is a
    non-null constant, is replaced by that first argument. A comparison of a
    COALESCE against a constant is split into a "not null" branch and an
    "is null" branch, using the first constant fallback in the latter.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE side of the comparison; `other` is the opposite side.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant fallback and everything after it. This mutates the
    # COALESCE in place, so `expression.copy()` below sees the shortened call.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # Not null: the original comparison, with the shortened COALESCE.
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # Null: compare the constant fallback against `other`.
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    operands = expressions or expression.flatten()

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if not is_string_group:
            new_args.extend(group)
            continue

        joined = sep.join(string.name for string in group)
        new_args.append(exp.Literal.string(joined))

    # Everything folded into a single string literal.
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE:
      * `CASE x WHEN v ...` is first rewritten to `CASE WHEN x = v ...`.
      * Branches whose condition is statically false are removed.
      * A branch whose condition is statically true replaces the whole CASE,
        but only if no earlier branch with an undecided condition remains,
        because such a branch could still match first at runtime.
      * If every branch was removed, the default (or NULL) is returned.

    For an IF that is not part of a CASE, the true / false branch is returned
    when the condition is statically known.

    Returns:
        The simplified expression, or the (possibly pruned) input expression.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # True once a branch with a condition we cannot evaluate has been kept.
        undecided = False

        # Iterate over a snapshot: `case.pop()` removes the branch from the live
        # list, which would make a live iteration skip the following branch.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond) and not undecided:
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Evaluate a prefix check to TRUE or FALSE when both the string and the
    prefix are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
170 def wrapped(expression, *args, **kwargs): 171 try: 172 return func(expression, *args, **kwargs) 173 except exceptions: 174 return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns go to the left and constants to the right; otherwise operands are
    ordered by their `gen` string. When operands are swapped, the comparison
    class is replaced by its entry in INVERSE_COMPARISONS, if it has one.
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    already_ordered = (left_is_column and not right_is_column) or (
        right_is_const and not left_is_const
    )
    if already_ordered:
        return expression

    must_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if must_swap:
        swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
        return swapped_type(this=right, expression=left)

    return expression
def remove_where_true(expression):
    """
    Strip conditions that are always true from WHERE clauses and joins, in place.

    Every WHERE whose condition is `always_true` is popped. A join whose ON condition
    is `always_true`, that has no USING and no method, and whose (side, kind) pair is
    in `JOINS`, loses its ON clause and is turned into a CROSS join.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join_node in expression.find_all(exp.Join):
        join_args = join_node.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join_node.side, join_node.kind) not in JOINS:
            continue

        join_args["on"].pop()
        join_node.set("side", None)
        join_node.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node against two already-known Python values.

    Returns `boolean_literal(...)` of the Python comparison that corresponds to the
    node type (EQ / IS, NEQ, GT, GTE, LT, LTE), or None for any other node type.
    """
    # Checked in order; each comparison is only evaluated once its node type matched
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Coerce a value into a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as they are, and
    anything else is parsed as an ISO formatted string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so narrow it down here
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce a value into a `datetime.datetime`.

    Datetimes are returned as they are, dates become midnight of that day, and
    anything else is parsed as an ISO formatted string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        if isinstance(value, datetime.datetime):
            return value
        # A plain date carries no time component, so it maps to 00:00:00
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert a raw value into a Python date or datetime according to a target type.

    Args:
        value: the value to convert. Falsy values are not converted.
        to: the data type the value is being cast to.

    Returns:
        The result of `cast_as_date` when `to` is a DATE, the result of
        `cast_as_datetime` when `to` is any other type in
        `exp.DataType.TEMPORAL_TYPES`, and None in every other case.
    """
    if not value:
        return None
    # DATE is checked first so that it yields a date rather than a datetime
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Evaluate a cast of a literal into a Python date or datetime.

    Args:
        cast: the expression to evaluate. Only `exp.Cast` nodes and `exp.TsOrDsToDate`
            nodes without a "format" argument are handled.

    Returns:
        The value produced by `cast_value` for the innermost literal and the target
        type of `cast`, or None when `cast` is not a supported node or its operand
        is neither a literal nor another supported cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without a format, TsOrDsToDate is handled like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are resolved from the inside out
        value = extract_date(operand)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """
    Build a CAST of the string form of `date`.

    `target_type` is used when it is given and is one of
    `exp.DataType.TEMPORAL_TYPES`. Otherwise the type is DATETIME for
    `datetime.datetime` values and DATE for everything else.
    """
    if not (target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)):
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` that spans `n` of the given unit.

    Supported units are year, quarter, month, week, day, hour, minute and second.

    Raises:
        UnsupportedUnit: if `unit` is none of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round a date down to the start of the given unit.

    Args:
        d: the date to round down.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only read for "week", where `dialect.WEEK_OFFSET` shifts the day
            a week starts on. Presumably 0 means Monday and -1 means Sunday --
            confirm against the Dialect definition.

    Returns:
        The first day of the unit that contains `d` (`d` itself for "day").

    Raises:
        UnsupportedUnit: if `unit` is none of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter that contains d: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # date.weekday() numbers Monday as 0 and Sunday as 6. The modulo keeps the
        # distance to the start of the week within 0..6, so a date that already is
        # the first day of a shifted week maps to itself, and the result never lies
        # after d.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Render an expression with the lightweight `Gen` pseudo SQL generator.

    The resulting strings are meant for sorting and deduplicating expressions during
    optimization, where calling the actual generator would be too expensive.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo-SQL generator for quickly generating sortable and unique strings.
Sorting and deduplicating SQL is a necessary step for optimization. Calling the actual generator is expensive, so we have a bare-minimum SQL generator here.
class Gen:
    """
    Minimal, non-recursive pseudo SQL generator.

    Work items live on an explicit stack. Because items are popped from the end,
    every handler pushes its pieces in reverse output order. Expressions are
    dispatched to a `<key>_sql` method when one exists, functions fall back to
    `_function`, and every other expression is rendered as its upper-cased key
    followed by its arguments.
    """

    def __init__(self):
        # Pending work items: expressions, lists of items, or plain values
        self.stack = []
        # Output fragments, in final order
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` and return the concatenated fragments."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would otherwise lead the list. Only do so
                # when something was pushed: a list holding nothing but None values
                # must not remove an unrelated item from the stack.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the "." that would otherwise precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the "." that would otherwise precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `left op right` (in reverse, so that the left side is popped first)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op operand` (in reverse, so that the operator is popped first)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `NAME(arg, ...)` built from `e.sql_name()` and all values in `e.args`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the set arguments of `node` as a list of `[":name", value]` pairs.

        Args:
            node: the expression whose arguments are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one argument was pushed, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
def gen(self, expression: exp.Expression) -> str:
    """
    Render `expression` into a pseudo SQL string.

    Work items are processed from an explicit stack instead of recursing:
    expressions are dispatched to a `<key>_sql` handler when one exists (functions
    fall back to `_function`, anything else to its upper-cased key plus arguments),
    lists are expanded into comma separated items, and every other non-None value
    is appended to the output as a string.

    Returns:
        The concatenation of all emitted fragments.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False
            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True
            # Drop the separator that would otherwise lead the list. Only do so when
            # something was pushed: a list holding nothing but None values must not
            # remove an unrelated item from the stack.
            if pushed:
                self.stack.pop()
        else:
            if node is not None:
                self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """
    Push the pieces of an anonymous function call, `NAME(expressions)`.

    A plain string name is upper-cased. An identifier keeps its text in double
    quotes when it is quoted and is upper-cased otherwise.

    Raises:
        ValueError: if `e.this` is neither a str nor an Identifier.
    """
    target = e.this

    if isinstance(target, str):
        func_name = target.upper()
    elif isinstance(target, exp.Identifier):
        raw_name = target.this
        if target.quoted:
            func_name = f'"{raw_name}"'
        else:
            func_name = raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{target.__class__.__name__}'."
        )

    # Pushed back to front, because the stack is consumed from its end
    pieces = [")", e.expressions, "(", func_name]
    self.stack.extend(pieces)