sqlglot.optimizer.normalize
from __future__ import annotations

import logging

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort

# Shared "sqlglot" logger; normalize() reports skipped/aborted rewrites at INFO level.
logger = logging.getLogger("sqlglot")


def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite a sqlglot AST into conjunctive normal form or disjunctive normal form.

    Every outermost connector tree in ``expression`` that is not already normalized is
    rewritten by repeatedly applying ``distributive_law``. Because the expansion can
    grow exponentially, a tree is only converted when its estimated
    ``normalization_distance`` does not exceed ``max_distance``. If a tree exceeds the
    limit (or the rewrite raises ``OptimizeError``), processing stops and the expression
    is returned without visiting the remaining trees.

    The tree is modified in place. When the root itself is a connector, the returned
    node may differ from the one passed in, so always use the return value.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximum estimated distance from CNF/DNF at which
            conversion is still attempted.
    Returns:
        sqlglot.Expression: normalized expression
    """
    # Materialize the walk first because the loop below mutates the tree. Pruning at
    # Connector nodes means nested connectors are handled as part of their outermost tree.
    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue
            root = node is expression
            # Pristine copy used to restore the tree if the rewrite is aborted.
            original = node.copy()

            # Expand BETWEEN in place (copy=False) before estimating the cost.
            node.transform(rewrite_between, copy=False)
            distance = normalization_distance(node, dnf=dnf, max_=max_distance)

            if distance > max_distance:
                # NOTE(review): this returns immediately, so later connector trees are
                # not normalized, and the in-place BETWEEN rewrite above is not undone
                # (unlike the OptimizeError path below). Confirm this is intended.
                logger.info(
                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
                )
                return expression

            try:
                # Re-apply the distributive law (presumably until the tree stops
                # changing, per while_changing) and splice the result in place of node.
                node = node.replace(
                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
                )
            except OptimizeError as e:
                # distributive_law raises when the distance exceeds max_distance
                # mid-rewrite; restore the untouched copy and stop.
                logger.info(e)
                node.replace(original)
                if root:
                    return original
                return expression

            if root:
                # The root node was replaced, so the caller must receive the new one.
                expression = node

    return expression


def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Checks whether a given expression is in a normal form of interest.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).

    Returns:
        False if any AND (CNF) / OR (DNF) found via find_all_in_scope has an OR (CNF) /
        AND (DNF) ancestor, True otherwise.
    """
    # CNF: no AND may sit under an OR. DNF: no OR may sit under an AND.
    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
    return not any(
        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
    )


def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance. Once the running total exceeds ``max_``, that
        partial total is returned immediately.
    """
    # Start from minus the predicate count of the input (connectors + 1 leaves).
    # NOTE(review): find_all presumably also counts connectors nested inside
    # non-connector leaves (e.g. subqueries), which _predicate_lengths treats as a
    # single predicate; that lowers the estimate. Confirm whether intended.
    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)

    for length in _predicate_lengths(expression, dnf, max_):
        total += length
        # Bail out early: callers only compare the result against max_.
        if total > max_:
            return total

    return total


def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
    """
    Yields the length of each clause the expression would have once expanded to
    normal form (this is a generator, not a list).

    (A AND B) OR C -> [2, 2] for CNF because len(A OR C), len(B OR C).

    Once ``depth`` exceeds ``max_``, recursion stops and ``depth`` itself is yielded
    as a stand-in value.
    """
    if depth > max_:
        yield depth
        return

    expression = expression.unnest()

    # Anything that is not a connector counts as a single predicate.
    if not isinstance(expression, exp.Connector):
        yield 1
        return

    depth += 1
    # NOTE(review): assumes the connector's args are exactly its two operands.
    left, right = expression.args.values()

    if isinstance(expression, exp.And if dnf else exp.Or):
        # The connector being distributed (OR for CNF, AND for DNF) multiplies out:
        # every left clause is combined with every right clause.
        for a in _predicate_lengths(left, dnf, max_, depth):
            for b in _predicate_lengths(right, dnf, max_, depth):
                yield a + b
    else:
        # The top-level connector of the target form just concatenates clauses.
        yield from _predicate_lengths(left, dnf, max_, depth)
        yield from _predicate_lengths(right, dnf, max_, depth)


def distributive_law(expression, dnf, max_distance):
    """
    Apply the distributive law to ``expression``, rewriting its children first.

    CNF examples (dnf=False); with dnf=True the roles of AND and OR are swapped:

        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Returns:
        The rewritten expression, which may be a different node than ``expression``.

    Raises:
        OptimizeError: if the normalization distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Rewrite the children bottom-up before looking at this node.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
    # from_exp is the connector being pushed down; to_exp is the one it distributes over.
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both sides qualify: pass the operand with more connectors first.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func)
            return _distribute(b, a, from_func, to_func)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func)

    return expression


def _distribute(a, b, from_func, to_func):
    """
    Distribute ``a`` over the two operands of ``b``.

    ``b`` is expected to be a binary connector exposing ``.left`` and ``.right``.
    If ``a`` is itself a connector, each of its children ``c`` is replaced in place by
    ``to_func(from_func(c, b.left), from_func(c, b.right))`` and ``a`` is returned.
    Otherwise a new ``to_func(from_func(a, b.left), from_func(a, b.right))`` is built.
    Each new clause is passed through ``flatten`` and ``uniq_sort`` from simplify.
    """
    if isinstance(a, exp.Connector):
        exp.replace_children(
            a,
            lambda c: to_func(
                uniq_sort(flatten(from_func(c, b.left))),
                uniq_sort(flatten(from_func(c, b.right))),
                copy=False,
            ),
        )
    else:
        a = to_func(
            uniq_sort(flatten(from_func(a, b.left))),
            uniq_sort(flatten(from_func(a, b.right))),
            copy=False,
        )

    return a
logger =
<Logger sqlglot (WARNING)>
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite a sqlglot AST into conjunctive normal form or disjunctive normal form.

    Every outermost connector tree in ``expression`` that is not already normalized is
    rewritten by repeatedly applying ``distributive_law``. Because the expansion can
    grow exponentially, a tree is only converted when its estimated
    ``normalization_distance`` does not exceed ``max_distance``. If a tree exceeds the
    limit (or the rewrite raises ``OptimizeError``), processing stops and the expression
    is returned without visiting the remaining trees.

    The tree is modified in place. When the root itself is a connector, the returned
    node may differ from the one passed in, so always use the return value.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximum estimated distance from CNF/DNF at which
            conversion is still attempted.
    Returns:
        sqlglot.Expression: normalized expression
    """
    # Materialize the walk first because the loop below mutates the tree. Pruning at
    # Connector nodes means nested connectors are handled as part of their outermost tree.
    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue
            root = node is expression
            # Pristine copy used to restore the tree if the rewrite is aborted.
            original = node.copy()

            # Expand BETWEEN in place (copy=False) before estimating the cost.
            node.transform(rewrite_between, copy=False)
            distance = normalization_distance(node, dnf=dnf, max_=max_distance)

            if distance > max_distance:
                # NOTE(review): this returns immediately, so later connector trees are
                # not normalized, and the in-place BETWEEN rewrite above is not undone
                # (unlike the OptimizeError path below). Confirm this is intended.
                logger.info(
                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
                )
                return expression

            try:
                # Re-apply the distributive law (presumably until the tree stops
                # changing, per while_changing) and splice the result in place of node.
                node = node.replace(
                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
                )
            except OptimizeError as e:
                # distributive_law raises when the distance exceeds max_distance
                # mid-rewrite; restore the untouched copy and stop.
                logger.info(e)
                node.replace(original)
                if root:
                    return original
                return expression

            if root:
                # The root node was replaced, so the caller must receive the new one.
                expression = node

    return expression
Rewrite a sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximum estimated distance from CNF/DNF at which conversion is still attempted
Returns:
sqlglot.Expression: normalized expression
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Checks whether a given expression is in a normal form of interest.

    For CNF no AND may appear beneath an OR; for DNF no OR may appear beneath an AND.
    Only connectors reachable through ``find_all_in_scope`` are inspected.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check for Disjunctive Normal Form (DNF) instead of the
            default Conjunctive Normal Form (CNF).

    Returns:
        True if the expression is in the requested normal form, False otherwise.
    """
    if dnf:
        outer_connector, forbidden_ancestor = exp.Or, exp.And
    else:
        outer_connector, forbidden_ancestor = exp.And, exp.Or

    # A single outer connector nested under the opposite kind breaks the normal form.
    for connector in find_all_in_scope(expression, outer_connector):
        if connector.find_ancestor(forbidden_ancestor):
            return False
    return True
Checks whether a given expression is in a normal form of interest.
Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
- expression: The expression to check if it's normalized.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def
normalization_distance( expression: sqlglot.expressions.Expression, dnf: bool = False, max_: float = inf) -> int:
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    Estimate how many predicates normalizing ``expression`` would add.

    The value is the number of predicates in the fully expanded normal form minus the
    number of predicates in the input. CNF/DNF conversion can blow up exponentially,
    so callers use this estimate to decide whether a conversion is affordable.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to measure against Disjunctive Normal Form (DNF) instead of the
            default Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance. Once the running total exceeds ``max_``, that
        partial total is returned immediately.
    """
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    # Subtract the input's predicate count: connectors + 1 leaves.
    running_total = -1 - connector_count

    for clause_length in _predicate_lengths(expression, dnf, max_):
        running_total += clause_length
        if running_total > max_:
            # Over budget: no need to finish the (possibly huge) expansion.
            break

    return running_total
The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression: The expression to compute the normalization distance for.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
- max_: stop early if count exceeds this.
Returns:
The normalization distance.
def
distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance):
    """
    Push the distributed connector through the other one, children first.

    CNF examples (dnf=False); with dnf=True the roles of AND and OR are swapped:

        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are rewritten first.
        dnf: target disjunctive normal form instead of conjunctive normal form.
        max_distance: largest normalization distance tolerated before giving up.

    Returns:
        The rewritten expression, which may be a different node than ``expression``.

    Raises:
        OptimizeError: if the normalization distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Rewrite the children bottom-up before considering this node.
    exp.replace_children(expression, lambda child: distributive_law(child, dnf, max_distance))

    # source_* is the connector being pushed down (OR for CNF, AND for DNF);
    # target_* is the connector it is distributed over.
    if dnf:
        source_type, source_func = exp.And, exp.and_
        target_type, target_func = exp.Or, exp.or_
    else:
        source_type, source_func = exp.Or, exp.or_
        target_type, target_func = exp.And, exp.and_

    if not isinstance(expression, source_type):
        return expression

    left, right = expression.unnest_operands()
    left_matches = isinstance(left, target_type)
    right_matches = isinstance(right, target_type)

    if left_matches and right_matches:
        left_connectors = sum(1 for _ in left.find_all(exp.Connector))
        right_connectors = sum(1 for _ in right.find_all(exp.Connector))
        # The operand with more connectors is passed first to _distribute.
        if left_connectors > right_connectors:
            return _distribute(left, right, source_func, target_func)
        return _distribute(right, left, source_func, target_func)

    if left_matches:
        return _distribute(right, left, source_func, target_func)
    if right_matches:
        return _distribute(left, right, source_func, target_func)

    return expression
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)