sqlglot.optimizer.normalize
from __future__ import annotations

import logging

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import Simplifier, flatten

logger = logging.getLogger("sqlglot")


def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    simplifier = Simplifier(annotate_new_expressions=False)

    # The walk is materialized up front because nodes are replaced inside the loop.
    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue
            root = node is expression
            # Untouched copy used to roll back if the conversion raises.
            original = node.copy()

            # Applied in place (copy=False) before the distance is measured.
            node.transform(simplifier.rewrite_between, copy=False)
            distance = normalization_distance(node, dnf=dnf, max_=max_distance)

            if distance > max_distance:
                # NOTE: the in-place rewrite_between result is kept on this path;
                # only the OptimizeError path below restores `original`.
                logger.info(
                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
                )
                return expression

            try:
                # Re-apply the distributive law until while_changing stops,
                # then splice the result into the tree in place of `node`.
                node = node.replace(
                    while_changing(
                        node,
                        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
                    )
                )
            except OptimizeError as e:
                logger.info(e)
                node.replace(original)
                if root:
                    return original
                return expression

            if root:
                expression = node

    return expression


def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Checks whether a given expression is in a normal form of interest.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    # For CNF no AND may have an OR ancestor; for DNF no OR may have an AND ancestor.
    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
    return not any(
        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
    )


def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance.
    """
    # Start from minus (number of connectors + 1), then add the expanded lengths.
    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)

    for length in _predicate_lengths(expression, dnf, max_):
        total += length
        # Stop consuming the generator as soon as the budget is exceeded.
        if total > max_:
            return total

    return total


def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
    """
    Yields the predicate lengths of the expression when expanded to normalized form.

    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).

    Once `depth` exceeds `max_`, the current depth is yielded and the branch is
    not explored any further.
    """
    if depth > max_:
        yield depth
        return

    expression = expression.unnest()

    # Anything that is not a connector counts as a single predicate.
    if not isinstance(expression, exp.Connector):
        yield 1
        return

    depth += 1
    left, right = expression.args.values()

    if isinstance(expression, exp.And if dnf else exp.Or):
        # This connector has to be distributed: every left/right pairing is produced.
        for a in _predicate_lengths(left, dnf, max_, depth):
            for b in _predicate_lengths(right, dnf, max_, depth):
                yield a + b
    else:
        yield from _predicate_lengths(left, dnf, max_, depth)
        yield from _predicate_lengths(right, dnf, max_, depth)


def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: distribute towards disjunctive normal form instead of conjunctive.
        max_distance: upper bound on the estimated normalization distance.
        simplifier: optional Simplifier to reuse; one is created when omitted.

    Returns:
        The rewritten expression, or `expression` itself when nothing applies.

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Resolve the simplifier once so the recursive calls below share it instead
    # of each nested level constructing its own instance.
    if simplifier is None:
        simplifier = Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression,
        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both operands need distributing: the one with more connectors goes first.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func, simplifier)
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func, simplifier)

    return expression


def _distribute(a, b, from_func, to_func, simplifier):
    """
    Distribute `a` over the two sides of `b` (`b.left` and `b.right`).

    When `a` is a connector, each of its children `c` is replaced in place by
    to_func(from_func(c, b.left), from_func(c, b.right)); otherwise `a` itself is
    combined that way. Each from_func result goes through flatten and
    simplifier.uniq_sort before being joined.
    """
    if isinstance(a, exp.Connector):
        exp.replace_children(
            a,
            lambda c: to_func(
                simplifier.uniq_sort(flatten(from_func(c, b.left))),
                simplifier.uniq_sort(flatten(from_func(c, b.right))),
                copy=False,
            ),
        )
    else:
        a = to_func(
            simplifier.uniq_sort(flatten(from_func(a, b.left))),
            simplifier.uniq_sort(flatten(from_func(a, b.right))),
            copy=False,
        )

    return a
logger =
<Logger sqlglot (WARNING)>
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    simplifier = Simplifier(annotate_new_expressions=False)

    def is_connector(candidate):
        return isinstance(candidate, exp.Connector)

    def apply_law(candidate):
        return distributive_law(candidate, dnf, max_distance, simplifier=simplifier)

    # Snapshot the traversal first: nodes get replaced while we iterate.
    visited = tuple(expression.walk(prune=is_connector))

    for node in visited:
        # Only connectors that are not yet in the requested normal form need work.
        if not is_connector(node) or normalized(node, dnf=dnf):
            continue

        is_root = node is expression
        backup = node.copy()

        node.transform(simplifier.rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf, max_=max_distance)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            converted = while_changing(node, apply_law)
            node = node.replace(converted)
        except OptimizeError as error:
            # Conversion turned out too expensive: put the saved copy back.
            logger.info(error)
            node.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = node

    return expression
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("(x AND y) OR z") >>> normalize(expression, dnf=False).sql() '(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Tell whether an expression is already in the requested normal form.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check.
        dnf: If True, check for Disjunctive Normal Form (DNF); otherwise
            (the default) check for Conjunctive Normal Form (CNF).

    Returns:
        True when no offending nesting of connectors is found, False otherwise.
    """
    if dnf:
        outer_type, inner_type = exp.And, exp.Or
    else:
        outer_type, inner_type = exp.Or, exp.And

    # A single inner connector sitting below an outer one breaks the normal form.
    for connector in find_all_in_scope(expression, inner_type):
        if connector.find_ancestor(outer_type):
            return False

    return True
Checks whether a given expression is in a normal form of interest.
Example:
>>> from sqlglot import parse_one >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) True >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default True >>> normalized(parse_one("a AND (b OR c)"), dnf=True) False
Arguments:
- expression: The expression to check if it's normalized.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def
normalization_distance( expression: sqlglot.expressions.Expression, dnf: bool = False, max_: float = inf) -> int:
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    Estimate how far an expression is from its normalized form.

    The value is the difference in the number of predicates between the given
    expression and its normalized form, and serves as an estimate of the cost
    of the conversion, which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: If True, measure against Disjunctive Normal Form (DNF); otherwise
            (the default) against Conjunctive Normal Form (CNF).
        max_: stop early if count exceeds this.

    Returns:
        The normalization distance.
    """
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    distance = -1 - connector_count

    for predicate_length in _predicate_lengths(expression, dnf, max_):
        distance += predicate_length
        # No need to keep expanding once the budget has been blown.
        if distance > max_:
            break

    return distance
The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion which is exponential in complexity.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") >>> normalization_distance(expression) 4
Arguments:
- expression: The expression to compute the normalization distance for.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
- max_: stop early if count exceeds this.
Returns:
The normalization distance.
def
distributive_law(expression, dnf, max_distance, simplifier=None):
def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: distribute towards disjunctive normal form instead of conjunctive.
        max_distance: upper bound on the estimated normalization distance.
        simplifier: optional Simplifier to reuse; one is created when omitted.

    Returns:
        The rewritten expression, or `expression` itself when nothing applies.

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Resolve the simplifier once so the recursive calls below share it instead
    # of each nested level constructing its own instance.
    if simplifier is None:
        simplifier = Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression,
        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both operands need distributing: the one with more connectors goes first.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func, simplifier)
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func, simplifier)

    return expression
x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)