Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.helper import while_changing
  8from sqlglot.optimizer.scope import find_all_in_scope
  9from sqlglot.optimizer.simplify import Simplifier, flatten
 10
 11logger = logging.getLogger("sqlglot")
 12
 13
 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 15    """
 16    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 21        >>> normalize(expression, dnf=False).sql()
 22        '(x OR z) AND (y OR z)'
 23
 24    Args:
 25        expression: expression to normalize
 26        dnf: rewrite in disjunctive normal form instead.
 27        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 28    Returns:
 29        sqlglot.Expression: normalized expression
 30    """
 31    simplifier = Simplifier(annotate_new_expressions=False)
 32
 33    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
 34        if isinstance(node, exp.Connector):
 35            if normalized(node, dnf=dnf):
 36                continue
 37            root = node is expression
 38            original = node.copy()
 39
 40            node.transform(simplifier.rewrite_between, copy=False)
 41            distance = normalization_distance(node, dnf=dnf, max_=max_distance)
 42
 43            if distance > max_distance:
 44                logger.info(
 45                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 46                )
 47                return expression
 48
 49            try:
 50                node = node.replace(
 51                    while_changing(
 52                        node,
 53                        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
 54                    )
 55                )
 56            except OptimizeError as e:
 57                logger.info(e)
 58                node.replace(original)
 59                if root:
 60                    return original
 61                return expression
 62
 63            if root:
 64                expression = node
 65
 66    return expression
 67
 68
 69def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
 70    """
 71    Checks whether a given expression is in a normal form of interest.
 72
 73    Example:
 74        >>> from sqlglot import parse_one
 75        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
 76        True
 77        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
 78        True
 79        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
 80        False
 81
 82    Args:
 83        expression: The expression to check if it's normalized.
 84        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
 85            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
 86    """
 87    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 88    return not any(
 89        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
 90    )
 91
 92
 93def normalization_distance(
 94    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
 95) -> int:
 96    """
 97    The difference in the number of predicates between a given expression and its normalized form.
 98
 99    This is used as an estimate of the cost of the conversion which is exponential in complexity.
100
101    Example:
102        >>> import sqlglot
103        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
104        >>> normalization_distance(expression)
105        4
106
107    Args:
108        expression: The expression to compute the normalization distance for.
109        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
110            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
111        max_: stop early if count exceeds this.
112
113    Returns:
114        The normalization distance.
115    """
116    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
117
118    for length in _predicate_lengths(expression, dnf, max_):
119        total += length
120        if total > max_:
121            return total
122
123    return total
124
125
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
    """
    Returns a list of predicate lengths when expanded to normalized form.

    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).

    Args:
        expression: the (possibly parenthesized) expression to expand.
        dnf: expand towards DNF (AND is the distributing connector) instead of CNF (OR is).
        max_: once the recursion depth exceeds this, the depth itself is yielded and the
            branch stops, so callers summing the lengths can bail out early.
        depth: current recursion depth (internal).

    Yields:
        The length of each predicate of the normalized form.
    """
    if depth > max_:
        yield depth
        return

    node = expression.unnest()

    # Anything that is not a connector is a single predicate.
    if not isinstance(node, exp.Connector):
        yield 1
        return

    child_depth = depth + 1
    lhs, rhs = node.args.values()
    distributing = exp.And if dnf else exp.Or

    if not isinstance(node, distributing):
        # Non-distributing connector: its operands' predicates are simply concatenated.
        for side in (lhs, rhs):
            yield from _predicate_lengths(side, dnf, max_, child_depth)
        return

    # Distributing connector: every predicate on the left combines with every one on the right.
    for x in _predicate_lengths(lhs, dnf, max_, child_depth):
        for y in _predicate_lengths(rhs, dnf, max_, child_depth):
            yield x + y
152
153
def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    Apply the distributive law to `expression`, recursing into its children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced via
            exp.replace_children.
        dnf: distribute towards disjunctive normal form instead of conjunctive normal form.
        max_distance: the maximal normalization distance allowed for any visited subtree.
        simplifier: Simplifier used to sort/deduplicate the generated clauses. A default one
            is created when omitted; the same instance is passed to the recursive calls.

    Returns:
        The rewritten expression, which may be a different node than `expression`.

    Raises:
        OptimizeError: if a visited subtree's normalization distance exceeds max_distance.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Build the simplifier once and hand it down so nested calls honour the caller's instance
    # instead of each constructing a fresh default one.
    if simplifier is None:
        simplifier = Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression, lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier)
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Distribute the operand with more connectors over the smaller one.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func, simplifier)
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func, simplifier)

    return expression
188
189
def _distribute(a, b, from_func, to_func, simplifier):
    """
    Distribute `a` over the two operands of the connector `b`.

    Each term t (either `a` itself, or every child of `a` when `a` is a connector) becomes
    to_func(from_func(t, b.left), from_func(t, b.right)), with both halves passed through
    flatten and simplifier.uniq_sort.

    Returns:
        `a` with its children rewritten in place when it is a connector, otherwise the newly
        built to_func node.
    """

    def expand(term):
        return to_func(
            simplifier.uniq_sort(flatten(from_func(term, b.left))),
            simplifier.uniq_sort(flatten(from_func(term, b.right))),
            copy=False,
        )

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, expand)
    return a
logger = <Logger sqlglot (WARNING)>
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression. If a connector's estimated distance exceeds
        `max_distance`, or its conversion raises OptimizeError, processing stops there: that
        connector is restored to its state before this call and no further connectors are
        normalized.
    """
    simplifier = Simplifier(annotate_new_expressions=False)

    # Snapshot the candidates up front because nodes are replaced inside the loop. Pruning at
    # connectors means only the outermost connector of each boolean expression is visited.
    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        root = node is expression
        original = node.copy()

        node.transform(simplifier.rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf, max_=max_distance)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            # Undo the BETWEEN rewrite above so a skipped expression is really left untouched,
            # the same way the OptimizeError path below restores it.
            node.replace(original)
            return original if root else expression

        try:
            node = node.replace(
                while_changing(
                    node,
                    lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
                )
            )
        except OptimizeError as e:
            logger.info(e)
            node.replace(original)
            return original if root else expression

        if root:
            expression = node

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:

sqlglot.Expression: normalized expression

def normalized(expression: sqlglot.expressions.Expression, dnf: bool = False) -> bool:
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Checks whether a given expression is in a normal form of interest.

    The expression is considered normalized when no inner connector (OR for DNF, AND for CNF)
    found in its scope has an outer connector (AND for DNF, OR for CNF) among its ancestors.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        inner, outer = exp.Or, exp.And
    else:
        inner, outer = exp.And, exp.Or

    # NOTE(review): find_ancestor is called on the connector itself, so the search presumably
    # continues above `expression` when it is a subtree of a larger tree — confirm whether an
    # enclosing connector outside `expression` is meant to count here.
    for connector in find_all_in_scope(expression, inner):
        if connector.find_ancestor(outer):
            return False
    return True

Checks whether a given expression is in a normal form of interest.

Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
  • expression: The expression to check if it's normalized.
  • dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance( expression: sqlglot.expressions.Expression, dnf: bool = False, max_: float = inf) -> int:
 94def normalization_distance(
 95    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
 96) -> int:
 97    """
 98    The difference in the number of predicates between a given expression and its normalized form.
 99
100    This is used as an estimate of the cost of the conversion which is exponential in complexity.
101
102    Example:
103        >>> import sqlglot
104        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
105        >>> normalization_distance(expression)
106        4
107
108    Args:
109        expression: The expression to compute the normalization distance for.
110        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
111            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
112        max_: stop early if count exceeds this.
113
114    Returns:
115        The normalization distance.
116    """
117    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
118
119    for length in _predicate_lengths(expression, dnf, max_):
120        total += length
121        if total > max_:
122            return total
123
124    return total

The difference in the number of predicates between a given expression and its normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression: The expression to compute the normalization distance for.
  • dnf: Whether to measure the distance to Disjunctive Normal Form (DNF). Default: False, i.e. the distance to Conjunctive Normal Form (CNF) is measured.
  • max_: stop early if count exceeds this.
Returns:

The normalization distance.

def distributive_law(expression, dnf, max_distance, simplifier=None):
def distributive_law(expression, dnf, max_distance, simplifier=None):
    """
    Apply the distributive law to `expression`, recursing into its children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the expression to rewrite; its children are replaced via
            exp.replace_children.
        dnf: distribute towards disjunctive normal form instead of conjunctive normal form.
        max_distance: the maximal normalization distance allowed for any visited subtree.
        simplifier: Simplifier used to sort/deduplicate the generated clauses. A default one
            is created when omitted; the same instance is passed to the recursive calls.

    Returns:
        The rewritten expression, which may be a different node than `expression`.

    Raises:
        OptimizeError: if a visited subtree's normalization distance exceeds max_distance.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Build the simplifier once and hand it down so nested calls honour the caller's instance
    # instead of each constructing a fresh default one.
    if simplifier is None:
        simplifier = Simplifier(annotate_new_expressions=False)

    exp.replace_children(
        expression, lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier)
    )
    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)

    if isinstance(expression, from_exp):
        a, b = expression.unnest_operands()

        from_func = exp.and_ if from_exp == exp.And else exp.or_
        to_func = exp.and_ if to_exp == exp.And else exp.or_

        if isinstance(a, to_exp) and isinstance(b, to_exp):
            # Distribute the operand with more connectors over the smaller one.
            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
                return _distribute(a, b, from_func, to_func, simplifier)
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(a, to_exp):
            return _distribute(b, a, from_func, to_func, simplifier)
        if isinstance(b, to_exp):
            return _distribute(a, b, from_func, to_func, simplifier)

    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)