sqlglot.optimizer.normalize
1from __future__ import annotations 2 3import logging 4 5from sqlglot import exp 6from sqlglot.errors import OptimizeError 7from sqlglot.helper import while_changing 8from sqlglot.optimizer.scope import find_all_in_scope 9from sqlglot.optimizer.simplify import Simplifier, flatten 10 11logger = logging.getLogger("sqlglot") 12 13 14def normalize( 15 expression: exp.Expression, dnf: bool = False, max_distance: int = 128 16) -> exp.Expression: 17 """ 18 Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 19 20 Example: 21 >>> import sqlglot 22 >>> expression = sqlglot.parse_one("(x AND y) OR z") 23 >>> normalize(expression, dnf=False).sql() 24 '(x OR z) AND (y OR z)' 25 26 Args: 27 expression: expression to normalize 28 dnf: rewrite in disjunctive normal form instead. 29 max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion 30 Returns: 31 sqlglot.Expression: normalized expression 32 """ 33 simplifier = Simplifier(annotate_new_expressions=False) 34 35 for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))): 36 if isinstance(node, exp.Connector): 37 if normalized(node, dnf=dnf): 38 continue 39 root = node is expression 40 original = node.copy() 41 42 node.transform(simplifier.rewrite_between, copy=False) 43 distance = normalization_distance(node, dnf=dnf, max_=max_distance) 44 45 if distance > max_distance: 46 logger.info( 47 f"Skipping normalization because distance {distance} exceeds max {max_distance}" 48 ) 49 return expression 50 51 try: 52 node = node.replace( 53 while_changing( 54 node, 55 lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier), 56 ) 57 ) 58 except OptimizeError as e: 59 logger.info(e) 60 node.replace(original) 61 if root: 62 return original 63 return expression 64 65 if root: 66 expression = node 67 68 return expression 69 70 71def normalized(expression: exp.Expression, dnf: bool = False) -> bool: 72 """ 73 Checks whether a given expression is in a 
normal form of interest. 74 75 Example: 76 >>> from sqlglot import parse_one 77 >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) 78 True 79 >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default 80 True 81 >>> normalized(parse_one("a AND (b OR c)"), dnf=True) 82 False 83 84 Args: 85 expression: The expression to check if it's normalized. 86 dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). 87 Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). 88 """ 89 ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) 90 return not any( 91 connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) 92 ) 93 94 95def normalization_distance( 96 expression: exp.Expression, dnf: bool = False, max_: float = float("inf") 97) -> int: 98 """ 99 The difference in the number of predicates between a given expression and its normalized form. 100 101 This is used as an estimate of the cost of the conversion which is exponential in complexity. 102 103 Example: 104 >>> import sqlglot 105 >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") 106 >>> normalization_distance(expression) 107 4 108 109 Args: 110 expression: The expression to compute the normalization distance for. 111 dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). 112 Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). 113 max_: stop early if count exceeds this. 114 115 Returns: 116 The normalization distance. 117 """ 118 total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1) 119 120 for length in _predicate_lengths(expression, dnf, max_): 121 total += length 122 if total > max_: 123 return total 124 125 return total 126 127 128def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0): 129 """ 130 Returns a list of predicate lengths when expanded to normalized form. 131 132 (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
133 """ 134 if depth > max_: 135 yield depth 136 return 137 138 expression = expression.unnest() 139 140 if not isinstance(expression, exp.Connector): 141 yield 1 142 return 143 144 depth += 1 145 left, right = expression.args.values() 146 147 if isinstance(expression, exp.And if dnf else exp.Or): 148 for a in _predicate_lengths(left, dnf, max_, depth): 149 for b in _predicate_lengths(right, dnf, max_, depth): 150 yield a + b 151 else: 152 yield from _predicate_lengths(left, dnf, max_, depth) 153 yield from _predicate_lengths(right, dnf, max_, depth) 154 155 156def distributive_law(expression, dnf, max_distance, simplifier=None): 157 """ 158 x OR (y AND z) -> (x OR y) AND (x OR z) 159 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) 160 """ 161 if normalized(expression, dnf=dnf): 162 return expression 163 164 distance = normalization_distance(expression, dnf=dnf, max_=max_distance) 165 166 if distance > max_distance: 167 raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") 168 169 exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) 170 to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) 171 172 if isinstance(expression, from_exp): 173 a, b = expression.unnest_operands() 174 175 from_func = exp.and_ if from_exp == exp.And else exp.or_ 176 to_func = exp.and_ if to_exp == exp.And else exp.or_ 177 178 simplifier = simplifier or Simplifier(annotate_new_expressions=False) 179 180 if isinstance(a, to_exp) and isinstance(b, to_exp): 181 if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): 182 return _distribute(a, b, from_func, to_func, simplifier) 183 return _distribute(b, a, from_func, to_func, simplifier) 184 if isinstance(a, to_exp): 185 return _distribute(b, a, from_func, to_func, simplifier) 186 if isinstance(b, to_exp): 187 return _distribute(a, b, from_func, to_func, simplifier) 188 189 return expression 190 191 192def _distribute(a, 
b, from_func, to_func, simplifier): 193 if isinstance(a, exp.Connector): 194 exp.replace_children( 195 a, 196 lambda c: to_func( 197 simplifier.uniq_sort(flatten(from_func(c, b.left))), 198 simplifier.uniq_sort(flatten(from_func(c, b.right))), 199 copy=False, 200 ), 201 ) 202 else: 203 a = to_func( 204 simplifier.uniq_sort(flatten(from_func(a, b.left))), 205 simplifier.uniq_sort(flatten(from_func(a, b.right))), 206 copy=False, 207 ) 208 209 return a
logger =
<Logger sqlglot (WARNING)>
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128) -> sqlglot.expressions.Expression:
def normalize(
    expression: exp.Expression, dnf: bool = False, max_distance: int = 128
) -> exp.Expression:
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    simplifier = Simplifier(annotate_new_expressions=False)

    def _apply_distributive_law(candidate):
        return distributive_law(candidate, dnf, max_distance, simplifier=simplifier)

    # Materialize the walk first: connectors get replaced while we iterate.
    # Pruning at connectors means only the topmost connector of each tree shows up.
    visited = tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector)))

    for node in visited:
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        is_root = node is expression
        # Snapshot used to undo a conversion that is abandoned midway
        backup = node.copy()

        node.transform(simplifier.rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf, max_=max_distance)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            # Keep distributing until the connector reaches a fixed point
            converted = while_changing(node, _apply_distributive_law)
            node = node.replace(converted)
        except OptimizeError as error:
            logger.info(error)
            node.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = node

    return expression
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Checks whether a given expression is in a normal form of interest.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        # DNF is violated by an OR that sits underneath an AND
        forbidden_ancestor, connector_type = exp.And, exp.Or
    else:
        # CNF is violated by an AND that sits underneath an OR
        forbidden_ancestor, connector_type = exp.Or, exp.And

    for connector in find_all_in_scope(expression, connector_type):
        if connector.find_ancestor(forbidden_ancestor):
            return False

    return True
Checks whether a given expression is in a normal form of interest.
Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
- expression: The expression to check if it's normalized.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def
normalization_distance( expression: sqlglot.expressions.Expression, dnf: bool = False, max_: float = inf) -> int:
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance.
    """
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))

    # Begin at minus (connector count + 1) and add every expanded predicate length on top
    distance = -(connector_count + 1)

    for predicate_length in _predicate_lengths(expression, dnf, max_):
        distance += predicate_length
        if distance > max_:
            # Already over budget, so the remaining lengths are not needed
            break

    return distance
The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression: The expression to compute the normalization distance for.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
- max_: stop early if count exceeds this.
Returns:
The normalization distance.
def
distributive_law(expression, dnf, max_distance, simplifier=None):
def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: distribute towards disjunctive normal form instead of conjunctive normal form.
        max_distance: the maximal normalization distance allowed for the rewrite.
        simplifier: the Simplifier used to flatten and deduplicate the generated connectors.
            When omitted, one is created and shared with all recursive calls.

    Returns:
        The rewritten expression, or the given expression if nothing had to change.

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Resolve the simplifier once so nested calls reuse it instead of each building their own
    simplifier = simplifier or Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression,
        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()

    from_func = exp.and_ if from_exp == exp.And else exp.or_
    to_func = exp.and_ if to_exp == exp.And else exp.or_

    if isinstance(a, to_exp) and isinstance(b, to_exp):
        connectors_in_a = sum(1 for _ in a.find_all(exp.Connector))
        connectors_in_b = sum(1 for _ in b.find_all(exp.Connector))
        if connectors_in_a > connectors_in_b:
            return _distribute(a, b, from_func, to_func, simplifier)
        return _distribute(b, a, from_func, to_func, simplifier)
    if isinstance(a, to_exp):
        return _distribute(b, a, from_func, to_func, simplifier)
    if isinstance(b, to_exp):
        return _distribute(a, b, from_func, to_func, simplifier)

    return expression
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)