Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.helper import while_changing
  8from sqlglot.optimizer.scope import find_all_in_scope
  9from sqlglot.optimizer.simplify import Simplifier, flatten
 10
 11logger = logging.getLogger("sqlglot")
 12
 13
 14def normalize(
 15    expression: exp.Expression, dnf: bool = False, max_distance: int = 128
 16) -> exp.Expression:
 17    """
 18    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 19
 20    Example:
 21        >>> import sqlglot
 22        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 23        >>> normalize(expression, dnf=False).sql()
 24        '(x OR z) AND (y OR z)'
 25
 26    Args:
 27        expression: expression to normalize
 28        dnf: rewrite in disjunctive normal form instead.
 29        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 30    Returns:
 31        sqlglot.Expression: normalized expression
 32    """
 33    simplifier = Simplifier(annotate_new_expressions=False)
 34
 35    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
 36        if isinstance(node, exp.Connector):
 37            if normalized(node, dnf=dnf):
 38                continue
 39            root = node is expression
 40            original = node.copy()
 41
 42            node.transform(simplifier.rewrite_between, copy=False)
 43            distance = normalization_distance(node, dnf=dnf, max_=max_distance)
 44
 45            if distance > max_distance:
 46                logger.info(
 47                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 48                )
 49                return expression
 50
 51            try:
 52                node = node.replace(
 53                    while_changing(
 54                        node,
 55                        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
 56                    )
 57                )
 58            except OptimizeError as e:
 59                logger.info(e)
 60                node.replace(original)
 61                if root:
 62                    return original
 63                return expression
 64
 65            if root:
 66                expression = node
 67
 68    return expression
 69
 70
 71def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
 72    """
 73    Checks whether a given expression is in a normal form of interest.
 74
 75    Example:
 76        >>> from sqlglot import parse_one
 77        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
 78        True
 79        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
 80        True
 81        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
 82        False
 83
 84    Args:
 85        expression: The expression to check if it's normalized.
 86        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
 87            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
 88    """
 89    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 90    return not any(
 91        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
 92    )
 93
 94
 95def normalization_distance(
 96    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
 97) -> int:
 98    """
 99    The difference in the number of predicates between a given expression and its normalized form.
100
101    This is used as an estimate of the cost of the conversion which is exponential in complexity.
102
103    Example:
104        >>> import sqlglot
105        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
106        >>> normalization_distance(expression)
107        4
108
109    Args:
110        expression: The expression to compute the normalization distance for.
111        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
112            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
113        max_: stop early if count exceeds this.
114
115    Returns:
116        The normalization distance.
117    """
118    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
119
120    for length in _predicate_lengths(expression, dnf, max_):
121        total += length
122        if total > max_:
123            return total
124
125    return total
126
127
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
    """
    Lazily yield the length of each predicate of the expression once expanded to normalized form.

    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).

    When the recursion depth exceeds `max_`, the depth itself is yielded and the
    branch is abandoned.
    """
    if depth > max_:
        yield depth
        return

    node = expression.unnest()

    if isinstance(node, exp.Connector):
        next_depth = depth + 1
        left, right = node.args.values()

        # The connector type that has to be distributed for the target form.
        if isinstance(node, exp.And if dnf else exp.Or):
            # Every left-hand length pairs with every right-hand length.
            yield from (
                a + b
                for a in _predicate_lengths(left, dnf, max_, next_depth)
                for b in _predicate_lengths(right, dnf, max_, next_depth)
            )
        else:
            for operand in (left, right):
                yield from _predicate_lengths(operand, dnf, max_, next_depth)
    else:
        yield 1
154
155
def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: distribute AND over OR (towards DNF) instead of OR over AND (towards CNF).
        max_distance: upper bound on the estimated normalization distance.
        simplifier: Simplifier used to dedupe/sort the generated operands. When omitted,
            one is created and shared with all nested calls.

    Returns:
        The rewritten expression, or `expression` itself when it is already normalized
        or no distribution applies at this node.

    Raises:
        OptimizeError: if the estimated normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Resolve the simplifier once and hand it to the nested calls, so a caller-supplied
    # instance is honoured at every depth and no extra instances are built per level.
    simplifier = simplifier or Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression,
        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both sides need distributing: the side with more connectors is the one
            # whose children get rewritten, the other side is split into left/right.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func, simplifier)
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func, simplifier)

    return expression
190
191
def _distribute(a, b, from_func, to_func, simplifier):
    """
    Distribute `a` over the two operands of `b`.

    Each operand is combined with `b.left` and with `b.right` via `from_func`; the two
    results are flattened, passed through `simplifier.uniq_sort`, and joined with `to_func`.
    If `a` is a connector the expansion is applied to each of its children in place and
    `a` is returned; otherwise the expansion of `a` itself is returned.
    """

    def expand(operand):
        return to_func(
            simplifier.uniq_sort(flatten(from_func(operand, b.left))),
            simplifier.uniq_sort(flatten(from_func(operand, b.right))),
            copy=False,
        )

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, expand)
    return a
logger = <Logger sqlglot (WARNING)>
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128) -> sqlglot.expressions.Expression:
def normalize(
    expression: exp.Expression, dnf: bool = False, max_distance: int = 128
) -> exp.Expression:
    """
    Convert a sqlglot AST into conjunctive normal form (default) or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    simplifier = Simplifier(annotate_new_expressions=False)

    def apply_law(e):
        return distributive_law(e, dnf, max_distance, simplifier=simplifier)

    # Materialize the walk up front, since nodes are replaced inside the loop.
    candidates = tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector)))

    for node in candidates:
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        is_root = node is expression
        # Untouched copy used to undo the in-place rewrites if conversion is abandoned.
        backup = node.copy()

        node.transform(simplifier.rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf, max_=max_distance)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            node = node.replace(while_changing(node, apply_law))
        except OptimizeError as e:
            # distributive_law raises once its own distance estimate exceeds max_distance.
            logger.info(e)
            node.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = node

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:

sqlglot.Expression: normalized expression

def normalized(expression: sqlglot.expressions.Expression, dnf: bool = False) -> bool:
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Tell whether an expression is already in the requested normal form.

    For CNF the check fails when an AND connector has an OR ancestor; for DNF it fails
    when an OR connector has an AND ancestor.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        forbidden_ancestor, connector_type = exp.And, exp.Or
    else:
        forbidden_ancestor, connector_type = exp.Or, exp.And

    for connector in find_all_in_scope(expression, connector_type):
        if connector.find_ancestor(forbidden_ancestor):
            return False

    return True

Checks whether a given expression is in a normal form of interest.

Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
  • expression: The expression to check if it's normalized.
  • dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance( expression: sqlglot.expressions.Expression, dnf: bool = False, max_: float = inf) -> int:
 96def normalization_distance(
 97    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
 98) -> int:
 99    """
100    The difference in the number of predicates between a given expression and its normalized form.
101
102    This is used as an estimate of the cost of the conversion which is exponential in complexity.
103
104    Example:
105        >>> import sqlglot
106        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
107        >>> normalization_distance(expression)
108        4
109
110    Args:
111        expression: The expression to compute the normalization distance for.
112        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
113            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
114        max_: stop early if count exceeds this.
115
116    Returns:
117        The normalization distance.
118    """
119    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
120
121    for length in _predicate_lengths(expression, dnf, max_):
122        total += length
123        if total > max_:
124            return total
125
126    return total

The difference in the number of predicates between a given expression and its normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression: The expression to compute the normalization distance for.
  • dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
  • max_: stop early if count exceeds this.
Returns:

The normalization distance.

def distributive_law(expression, dnf, max_distance, simplifier=None):
def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced in place.
        dnf: distribute AND over OR (towards DNF) instead of OR over AND (towards CNF).
        max_distance: upper bound on the estimated normalization distance.
        simplifier: Simplifier used to dedupe/sort the generated operands. When omitted,
            one is created and shared with all nested calls.

    Returns:
        The rewritten expression, or `expression` itself when it is already normalized
        or no distribution applies at this node.

    Raises:
        OptimizeError: if the estimated normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Resolve the simplifier once and hand it to the nested calls, so a caller-supplied
    # instance is honoured at every depth and no extra instances are built per level.
    simplifier = simplifier or Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression,
        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Both sides need distributing: the side with more connectors is the one
            # whose children get rewritten, the other side is split into left/right.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func, simplifier)
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func, simplifier)

    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)