Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.helper import while_changing
  8from sqlglot.optimizer.scope import find_all_in_scope
  9from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 10
# Logger named "sqlglot"; normalize() uses it to report, at INFO level, when
# normalization is skipped (distance too large) or aborted (OptimizeError).
logger = logging.getLogger("sqlglot")
 12
 13
 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 15    """
 16    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 21        >>> normalize(expression, dnf=False).sql()
 22        '(x OR z) AND (y OR z)'
 23
 24    Args:
 25        expression: expression to normalize
 26        dnf: rewrite in disjunctive normal form instead.
 27        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 28    Returns:
 29        sqlglot.Expression: normalized expression
 30    """
 31    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
 32        if isinstance(node, exp.Connector):
 33            if normalized(node, dnf=dnf):
 34                continue
 35            root = node is expression
 36            original = node.copy()
 37
 38            node.transform(rewrite_between, copy=False)
 39            distance = normalization_distance(node, dnf=dnf)
 40
 41            if distance > max_distance:
 42                logger.info(
 43                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 44                )
 45                return expression
 46
 47            try:
 48                node = node.replace(
 49                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
 50                )
 51            except OptimizeError as e:
 52                logger.info(e)
 53                node.replace(original)
 54                if root:
 55                    return original
 56                return expression
 57
 58            if root:
 59                expression = node
 60
 61    return expression
 62
 63
 64def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
 65    """
 66    Checks whether a given expression is in a normal form of interest.
 67
 68    Example:
 69        >>> from sqlglot import parse_one
 70        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
 71        True
 72        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
 73        True
 74        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
 75        False
 76
 77    Args:
 78        expression: The expression to check if it's normalized.
 79        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
 80            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
 81    """
 82    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 83    return not any(
 84        connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root)
 85    )
 86
 87
 88def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
 89    """
 90    The difference in the number of predicates between a given expression and its normalized form.
 91
 92    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 93
 94    Example:
 95        >>> import sqlglot
 96        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 97        >>> normalization_distance(expression)
 98        4
 99
100    Args:
101        expression: The expression to compute the normalization distance for.
102        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
103            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
104
105    Returns:
106        The normalization distance.
107    """
108    return sum(_predicate_lengths(expression, dnf)) - (
109        sum(1 for _ in expression.find_all(exp.Connector)) + 1
110    )
111
112
113def _predicate_lengths(expression, dnf):
114    """
115    Returns a list of predicate lengths when expanded to normalized form.
116
117    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
118    """
119    expression = expression.unnest()
120
121    if not isinstance(expression, exp.Connector):
122        return (1,)
123
124    left, right = expression.args.values()
125
126    if isinstance(expression, exp.And if dnf else exp.Or):
127        return tuple(
128            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
129        )
130    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
131
132
133def distributive_law(expression, dnf, max_distance):
134    """
135    x OR (y AND z) -> (x OR y) AND (x OR z)
136    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
137    """
138    if normalized(expression, dnf=dnf):
139        return expression
140
141    distance = normalization_distance(expression, dnf=dnf)
142
143    if distance > max_distance:
144        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
145
146    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
147    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
148
149    if isinstance(expression, from_exp):
150        a, b = expression.unnest_operands()
151
152        from_func = exp.and_ if from_exp == exp.And else exp.or_
153        to_func = exp.and_ if to_exp == exp.And else exp.or_
154
155        if isinstance(a, to_exp) and isinstance(b, to_exp):
156            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
157                return _distribute(a, b, from_func, to_func)
158            return _distribute(b, a, from_func, to_func)
159        if isinstance(a, to_exp):
160            return _distribute(b, a, from_func, to_func)
161        if isinstance(b, to_exp):
162            return _distribute(a, b, from_func, to_func)
163
164    return expression
165
166
def _distribute(a, b, from_func, to_func):
    """
    Distribute `a` over the two operands of the binary connector `b`.

    Each term t is rewritten as ``to_func(from_func(t, b.left), from_func(t, b.right))``.
    A term is either `a` itself or, when `a` is a Connector, each of its direct children.
    Each inner combination is flattened and passed through ``uniq_sort``.

    When `a` is a Connector, its children are replaced through ``exp.replace_children``
    and `a` is returned. Otherwise the newly built expression is returned.

    Args:
        a: the expression being distributed.
        b: a binary connector whose ``left``/``right`` operands `a` is distributed over.
        from_func: builder for the connector being pushed down (``exp.and_``/``exp.or_``
            as chosen by ``distributive_law``).
        to_func: builder for the connector that ends up on top; called with ``copy=False``.
    """

    def expand(term):
        # Pair `term` with each operand of `b`, then join the two pairs with `to_func`.
        with_left = uniq_sort(flatten(from_func(term, b.left)))
        with_right = uniq_sort(flatten(from_func(term, b.right)))
        return to_func(with_left, with_right, copy=False)

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, expand)
    return a
logger = <Logger sqlglot (WARNING)>
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance: the maximum estimated distance from CNF/DNF for which conversion is
            still attempted. Normalization stops at the first connector that exceeds it
            (or whose rewrite raises OptimizeError); connectors handled before that one
            stay normalized.

    Returns:
        The normalized expression. This may be a different node than `expression` when
        `expression` itself is a connector that gets rewritten.
    """
    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue
            root = node is expression
            # Snapshot taken before any in-place rewrite so it can be restored on failure.
            original = node.copy()

            # copy=False: `node` is modified in place.
            node.transform(rewrite_between, copy=False)
            distance = normalization_distance(node, dnf=dnf)

            if distance > max_distance:
                # Lazy %-style args: the message is only formatted if INFO is enabled.
                logger.info(
                    "Skipping normalization because distance %s exceeds max %s",
                    distance,
                    max_distance,
                )
                return expression

            try:
                node = node.replace(
                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
                )
            except OptimizeError as e:
                logger.info(e)
                node.replace(original)
                if root:
                    return original
                return expression

            if root:
                expression = node

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximum estimated distance from CNF/DNF for which conversion is still attempted; beyond it, normalization is skipped
Returns:

sqlglot.expressions.Expression: the normalized expression

def normalized(expression: sqlglot.expressions.Expression, dnf: bool = False) -> bool:
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Report whether `expression` is already in the requested normal form.

    CNF (the default) forbids an AND below an OR; DNF forbids an OR below an AND.
    Every such inner connector found in `expression`'s scope is checked for a
    forbidden ancestor.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to inspect.
        dnf: True to test for Disjunctive Normal Form (DNF); False (default) to test
            for Conjunctive Normal Form (CNF).

    Returns:
        False as soon as one offending connector is found, True otherwise.
    """
    if dnf:
        forbidden_above, inner = exp.And, exp.Or
    else:
        forbidden_above, inner = exp.Or, exp.And

    for connector in find_all_in_scope(expression, inner):
        if connector.find_ancestor(forbidden_above):
            return False
    return True

Checks whether a given expression is in a normal form of interest.

Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
  • expression: The expression to check if it's normalized.
  • dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance(expression: sqlglot.expressions.Expression, dnf: bool = False) -> int:
 89def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int:
 90    """
 91    The difference in the number of predicates between a given expression and its normalized form.
 92
 93    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 94
 95    Example:
 96        >>> import sqlglot
 97        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 98        >>> normalization_distance(expression)
 99        4
100
101    Args:
102        expression: The expression to compute the normalization distance for.
103        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
104            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
105
106    Returns:
107        The normalization distance.
108    """
109    return sum(_predicate_lengths(expression, dnf)) - (
110        sum(1 for _ in expression.find_all(exp.Connector)) + 1
111    )

The difference in the number of predicates between a given expression and its normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression: The expression to compute the normalization distance for.
  • dnf: Whether to measure the distance to Disjunctive Normal Form (DNF). Default: False, i.e. the distance to Conjunctive Normal Form (CNF) is measured.
Returns:

The normalization distance.

def distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance):
    """
    Perform one round of distribution on `expression`, rewriting its children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Those examples show the CNF direction (dnf=False); dnf=True swaps AND and OR.
    Input that is already normalized is returned untouched.

    Raises:
        OptimizeError: if the estimated normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    exp.replace_children(expression, lambda child: distributive_law(child, dnf, max_distance))

    # `outer` is the connector being pushed down; `inner` is the one that ends up on top.
    if dnf:
        outer, inner, outer_func, inner_func = exp.And, exp.Or, exp.and_, exp.or_
    else:
        outer, inner, outer_func, inner_func = exp.Or, exp.And, exp.or_, exp.and_

    if not isinstance(expression, outer):
        return expression

    left, right = expression.unnest_operands()
    left_is_inner = isinstance(left, inner)
    right_is_inner = isinstance(right, inner)

    if left_is_inner and right_is_inner:
        # Distribute the operand with more connectors over the other one.
        left_size = sum(1 for _ in left.find_all(exp.Connector))
        right_size = sum(1 for _ in right.find_all(exp.Connector))
        if left_size > right_size:
            return _distribute(left, right, outer_func, inner_func)
        return _distribute(right, left, outer_func, inner_func)
    if left_is_inner:
        return _distribute(right, left, outer_func, inner_func)
    if right_is_inner:
        return _distribute(left, right, outer_func, inner_func)
    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)