Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.helper import while_changing
  8from sqlglot.optimizer.scope import find_all_in_scope
  9from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 10
  11logger = logging.getLogger("sqlglot")  # "sqlglot" logger; normalize() reports skipped/aborted conversions on it at INFO level
 12
 13
 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 15    """
 16    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 21        >>> normalize(expression, dnf=False).sql()
 22        '(x OR z) AND (y OR z)'
 23
 24    Args:
 25        expression: expression to normalize
 26        dnf: rewrite in disjunctive normal form instead.
 27        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 28    Returns:
 29        sqlglot.Expression: normalized expression
 30    """
 31    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
 32        if isinstance(node, exp.Connector):
 33            if normalized(node, dnf=dnf):
 34                continue
 35            root = node is expression
 36            original = node.copy()
 37
 38            node.transform(rewrite_between, copy=False)
 39            distance = normalization_distance(node, dnf=dnf, max_=max_distance)
 40
 41            if distance > max_distance:
 42                logger.info(
 43                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 44                )
 45                return expression
 46
 47            try:
 48                node = node.replace(
 49                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
 50                )
 51            except OptimizeError as e:
 52                logger.info(e)
 53                node.replace(original)
 54                if root:
 55                    return original
 56                return expression
 57
 58            if root:
 59                expression = node
 60
 61    return expression
 62
 63
 64def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
 65    """
 66    Checks whether a given expression is in a normal form of interest.
 67
 68    Example:
 69        >>> from sqlglot import parse_one
 70        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
 71        True
 72        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
 73        True
 74        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
 75        False
 76
 77    Args:
 78        expression: The expression to check if it's normalized.
 79        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
 80            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
 81    """
 82    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 83    return not any(
 84        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
 85    )
 86
 87
 88def normalization_distance(
 89    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
 90) -> int:
 91    """
 92    The difference in the number of predicates between a given expression and its normalized form.
 93
 94    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 95
 96    Example:
 97        >>> import sqlglot
 98        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 99        >>> normalization_distance(expression)
100        4
101
102    Args:
103        expression: The expression to compute the normalization distance for.
104        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
105            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
106        max_: stop early if count exceeds this.
107
108    Returns:
109        The normalization distance.
110    """
111    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)
112
113    for length in _predicate_lengths(expression, dnf, max_):
114        total += length
115        if total > max_:
116            return total
117
118    return total
119
120
def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
    """
    Yield the length of each clause `expression` expands to in the target normal form.

    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).

    This is a generator, not a list. Once `depth` exceeds `max_`, recursion stops and the
    current depth is yielded as a stand-in value, so a caller summing the lengths still
    sees a large total.
    """
    if depth > max_:
        yield depth
        return

    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        yield 1
        return

    depth += 1
    # Use the binary operand accessors instead of unpacking `args.values()`, which would
    # fail if the node ever carried additional args (and matches `_distribute`).
    left, right = expression.left, expression.right

    if isinstance(expression, exp.And if dnf else exp.Or):
        # This connector gets distributed: every clause from the left pairs with every
        # clause from the right, so lengths combine as a cartesian product.
        for a in _predicate_lengths(left, dnf, max_, depth):
            for b in _predicate_lengths(right, dnf, max_, depth):
                yield a + b
    else:
        # This connector is already the outer connector of the normal form: clauses from
        # both sides are simply concatenated.
        yield from _predicate_lengths(left, dnf, max_, depth)
        yield from _predicate_lengths(right, dnf, max_, depth)
147
148
def distributive_law(expression, dnf, max_distance):
    """
    Apply the distributive law to `expression`, rewriting its children first.

    Examples (CNF):
        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: The node to rewrite; its children are replaced in place.
        dnf: Target Disjunctive Normal Form instead of Conjunctive Normal Form.
        max_distance: Upper bound on `normalization_distance` for this node.

    Returns:
        The distributed node, or `expression` itself when it is already normalized or
        neither operand can be distributed over.

    Raises:
        OptimizeError: If the estimated normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Bottom-up: normalize the children before looking at this node.
    exp.replace_children(expression, lambda child: distributive_law(child, dnf, max_distance))

    # "outer" is the connector that must end up on top, "inner" the one pushed down.
    if dnf:
        outer_kind, inner_kind = exp.Or, exp.And
        build_outer, build_inner = exp.or_, exp.and_
    else:
        outer_kind, inner_kind = exp.And, exp.Or
        build_outer, build_inner = exp.and_, exp.or_

    if not isinstance(expression, inner_kind):
        return expression

    left, right = expression.unnest_operands()
    left_is_outer = isinstance(left, outer_kind)
    right_is_outer = isinstance(right, outer_kind)

    if left_is_outer and right_is_outer:
        # Both sides qualify: distribute the side with fewer connectors over the other.
        left_size = sum(1 for _ in left.find_all(exp.Connector))
        right_size = sum(1 for _ in right.find_all(exp.Connector))
        if left_size > right_size:
            return _distribute(left, right, build_inner, build_outer)
        return _distribute(right, left, build_inner, build_outer)
    if left_is_outer:
        return _distribute(right, left, build_inner, build_outer)
    if right_is_outer:
        return _distribute(left, right, build_inner, build_outer)

    return expression
181
182
def _distribute(a, b, from_func, to_func):
    """
    Distribute `from_func` over the two operands of `b`.

    Each term `t` becomes `to_func(from_func(t, b.left), from_func(t, b.right))`, with
    both sides passed through `flatten` and `uniq_sort`. If `a` is a connector, the
    transformation is applied to each of its direct children in place and `a` is
    returned; otherwise the newly built node for `a` itself is returned.
    """

    def expand(term):
        return to_func(
            uniq_sort(flatten(from_func(term, b.left))),
            uniq_sort(flatten(from_func(term, b.right))),
            copy=False,
        )

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, expand)
    return a
logger = <Logger sqlglot (WARNING)>
def normalize(
    expression: exp.Expression, dnf: bool = False, max_distance: int = 128
) -> exp.Expression:
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Every top-most AND/OR connector that is not already in the requested normal form is
    rewritten by repeatedly applying the distributive law. If a connector's estimated
    normalization distance exceeds `max_distance`, processing stops and `expression` is
    returned immediately; that connector's BETWEEN predicates may already have been
    rewritten by `rewrite_between` at that point.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize.
        dnf: rewrite in disjunctive normal form instead.
        max_distance: the maximum estimated normalization distance (as computed by
            `normalization_distance`) for which a conversion is attempted.

    Returns:
        The normalized expression. If `expression` itself is a connector that gets
        rewritten, the new root node is returned instead of `expression`.
    """
    # Snapshot the walk: nodes are replaced in the tree while we iterate. Pruning at
    # connectors yields only the top-most AND/OR of each predicate chain.
    for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        root = node is expression
        # Pristine copy used to roll back if distribution is aborted part-way.
        original = node.copy()

        node.transform(rewrite_between, copy=False)
        distance = normalization_distance(node, dnf=dnf, max_=max_distance)

        if distance > max_distance:
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info(
                "Skipping normalization because distance %s exceeds max %s",
                distance,
                max_distance,
            )
            return expression

        try:
            node = node.replace(
                while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
            )
        except OptimizeError as e:
            # distributive_law raises when an intermediate step exceeds max_distance;
            # undo any partial in-place rewrite by restoring the snapshot.
            logger.info(e)
            node.replace(original)
            if root:
                return original
            return expression

        if root:
            expression = node

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximum estimated normalization distance from CNF/DNF for which a conversion is attempted
Returns:

sqlglot.Expression: normalized expression

def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Tell whether `expression` is already in the requested normal form.

    In CNF (the default) no AND connector may have an OR ancestor; in DNF no OR
    connector may have an AND ancestor. Candidate connectors are collected with
    `find_all_in_scope`.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to inspect.
        dnf: Check for Disjunctive Normal Form (DNF) instead of the default
            Conjunctive Normal Form (CNF).

    Returns:
        False as soon as an offending connector nesting is found, True otherwise.
    """
    if dnf:
        outer_kind, inner_kind = exp.And, exp.Or
    else:
        outer_kind, inner_kind = exp.Or, exp.And

    for connector in find_all_in_scope(expression, inner_kind):
        if connector.find_ancestor(outer_kind):
            return False
    return True

Checks whether a given expression is in a normal form of interest.

Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
  • expression: The expression to check if it's normalized.
  • dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    Estimate how many predicates converting `expression` to normal form would add.

    The estimate is the summed length of all clauses produced by fully expanding the
    expression into CNF (or DNF), minus (number of connectors + 1). Since that expansion
    can grow exponentially, the value is used as a cheap cost estimate before attempting
    the real conversion.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to measure.
        dnf: Measure against Disjunctive Normal Form (DNF) instead of the default
            Conjunctive Normal Form (CNF).
        max_: Stop summing as soon as the running total exceeds this bound; the partial
            total (already greater than `max_`) is returned.

    Returns:
        The (possibly truncated) normalization distance.
    """
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    distance = -(connector_count + 1)

    for clause_length in _predicate_lengths(expression, dnf, max_):
        distance += clause_length
        if distance > max_:
            break

    return distance

The difference in the number of predicates between a given expression and its normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression: The expression to compute the normalization distance for.
  • dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
  • max_: stop early if count exceeds this.
Returns:

The normalization distance.

def distributive_law(expression, dnf, max_distance):
    """
    Apply the distributive law to `expression`, rewriting its children first.

    Examples (CNF):
        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: The node to rewrite; its children are replaced in place.
        dnf: Target Disjunctive Normal Form instead of Conjunctive Normal Form.
        max_distance: Upper bound on `normalization_distance` for this node.

    Returns:
        The distributed node, or `expression` itself when it is already normalized or
        neither operand can be distributed over.

    Raises:
        OptimizeError: If the estimated normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf, max_=max_distance)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Bottom-up: normalize the children before looking at this node.
    exp.replace_children(expression, lambda child: distributive_law(child, dnf, max_distance))

    # "outer" is the connector that must end up on top, "inner" the one pushed down.
    if dnf:
        outer_kind, inner_kind = exp.Or, exp.And
        build_outer, build_inner = exp.or_, exp.and_
    else:
        outer_kind, inner_kind = exp.And, exp.Or
        build_outer, build_inner = exp.and_, exp.or_

    if not isinstance(expression, inner_kind):
        return expression

    left, right = expression.unnest_operands()
    left_is_outer = isinstance(left, outer_kind)
    right_is_outer = isinstance(right, outer_kind)

    if left_is_outer and right_is_outer:
        # Both sides qualify: distribute the side with fewer connectors over the other.
        left_size = sum(1 for _ in left.find_all(exp.Connector))
        right_size = sum(1 for _ in right.find_all(exp.Connector))
        if left_size > right_size:
            return _distribute(left, right, build_inner, build_outer)
        return _distribute(right, left, build_inner, build_outer)
    if left_is_outer:
        return _distribute(right, left, build_inner, build_outer)
    if right_is_outer:
        return _distribute(left, right, build_inner, build_outer)

    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)

(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)