Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  9from sqlglot.optimizer.annotate_types import TypeAnnotator
 10
 11
 12def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
 13    """Converts a sql expression into a standard form.
 14
 15    This method relies on annotate_types because many of the
 16    conversions rely on type inference.
 17
 18    Args:
 19        expression: The expression to canonicalize.
 20    """
 21
 22    dialect = Dialect.get_or_raise(dialect)
 23
 24    def _canonicalize(expression: exp.Expression) -> exp.Expression:
 25        expression = add_text_to_concat(expression)
 26        expression = replace_date_funcs(expression)
 27        expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
 28        expression = remove_redundant_casts(expression)
 29        expression = ensure_bools(expression, _replace_int_predicate)
 30        expression = remove_ascending_order(expression)
 31        return expression
 32
 33    return exp.replace_tree(expression, _canonicalize)
 34
 35
 36def add_text_to_concat(node: exp.Expression) -> exp.Expression:
 37    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 38        node = exp.Concat(expressions=[node.left, node.right])
 39    return node
 40
 41
 42def replace_date_funcs(node: exp.Expression) -> exp.Expression:
 43    if (
 44        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 45        and not node.expressions
 46        and not node.args.get("zone")
 47        and node.this.is_string
 48        and is_iso_date(node.this.name)
 49    ):
 50        return exp.cast(node.this, to=exp.DataType.Type.DATE)
 51    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
 52        if not node.type:
 53            from sqlglot.optimizer.annotate_types import annotate_types
 54
 55            node = annotate_types(node)
 56        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
 57
 58    return node
 59
 60
# Binary operators whose two operands coerce_type hands to _coerce_date.
COERCIBLE_DATE_OPS = (
    exp.Add,
    exp.Sub,
    exp.EQ,
    exp.NEQ,
    exp.GT,
    exp.GTE,
    exp.LT,
    exp.LTE,
    exp.NullSafeEQ,
    exp.NullSafeNEQ,
)
 73
 74
 75def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
 76    if isinstance(node, COERCIBLE_DATE_OPS):
 77        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
 78    elif isinstance(node, exp.Between):
 79        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
 80    elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
 81        *exp.DataType.TEMPORAL_TYPES
 82    ):
 83        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
 84    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
 85        _coerce_timeunit_arg(node.this, node.unit)
 86    elif isinstance(node, exp.DateDiff):
 87        _coerce_datediff_args(node)
 88
 89    return node
 90
 91
 92def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 93    if (
 94        isinstance(expression, exp.Cast)
 95        and expression.this.type
 96        and expression.to.this == expression.this.type.this
 97    ):
 98        return expression.this
 99    if (
100        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
101        and expression.this.type
102        and expression.this.type.this == exp.DataType.Type.DATE
103    ):
104        return expression.this
105    return expression
106
107
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply `replace_func` to the sub-expressions that act as predicates.

    Args:
        expression: The node to inspect.
        replace_func: Callback invoked on each predicate operand; it is expected
            to rewrite the operand in place.

    Returns:
        The same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Connector):
        for operand in (expression.left, expression.right):
            replace_func(operand)
    elif isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)
    elif isinstance(expression, exp.If):
        parent = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        is_case_operand_branch = isinstance(parent, exp.Case) and parent.this
        if not is_case_operand_branch:
            replace_func(expression.this)

    return expression
125
126
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Strip an explicit ASC from an ordering term (ORDER BY a ASC -> ORDER BY a).

    Args:
        expression: The node to inspect.

    Returns:
        The same `expression` object; an exp.Ordered whose "desc" arg is exactly
        False has that arg reset to None.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression
133
134
def _coerce_date(
    a: exp.Expression,
    b: exp.Expression,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Cast a text operand so it matches the temporal operand it is paired with.

    Both orderings of the pair are tried, so either side may be the temporal
    one. Operands are replaced by casts in place; nothing is returned.

    Args:
        a: One operand.
        b: The other operand.
        promote_to_inferred_datetime_type: When true, the cast target may be
            inferred from the text operand (DATE/DATETIME for ISO literals,
            DATETIME for non-literals), provided the temporal operand's type
            coerces to it according to TypeAnnotator.COERCES_TO. When false, the
            temporal operand's own type is always the target.
    """
    for temporal, other in ((a, b), (b, a)):
        if isinstance(other, exp.Interval):
            temporal = _coerce_timeunit_arg(temporal, other.unit)

        temporal_type = temporal.type
        other_type = other.type
        is_temporal_vs_text = (
            temporal_type
            and temporal_type.this in exp.DataType.TEMPORAL_TYPES
            and other_type
            and other_type.this in exp.DataType.TEXT_TYPES
        )
        if not is_temporal_vs_text:
            continue

        target_type = temporal_type
        if promote_to_inferred_datetime_type:
            if not other.is_string:
                # If the other side is not a datetime string, we conservatively promote
                # it to a DATETIME, in order to ensure there are no surprising
                # truncations due to downcasting
                inferred_type = exp.DataType.Type.DATETIME
            elif is_iso_date(other.name):
                inferred_type = exp.DataType.Type.DATE
            elif is_iso_datetime(other.name):
                inferred_type = exp.DataType.Type.DATETIME
            else:
                inferred_type = temporal_type.this

            if inferred_type in TypeAnnotator.COERCES_TO.get(temporal_type.this, {}):
                target_type = inferred_type

        # The temporal side only needs a cast when it is being promoted.
        if target_type != temporal_type:
            _replace_cast(temporal, target_type)

        _replace_cast(other, target_type)
177
178
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast the operand of a unit-based date operation to DATE or DATETIME.

    Args:
        arg: The operand; left untouched when it has no type annotation.
        unit: The time unit of the operation, classified with is_date_unit.

    Returns:
        The value returned by `arg.replace(...)` when a cast was inserted,
        otherwise `arg` itself. A cast is inserted when:
        - `arg` is text whose name is an ISO date: DATE for a date unit,
          DATETIME otherwise;
        - `arg` is text whose name is an ISO datetime: DATETIME;
        - `arg` is a DATE and the unit is not a date unit: DATETIME.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    kind = arg_type.this
    target = None

    if kind in exp.DataType.TEXT_TYPES:
        text = arg.name
        if is_iso_date(text):
            # An ISO date is also an ISO datetime, but not vice versa
            if is_date_unit(unit):
                target = exp.DataType.Type.DATE
            else:
                target = exp.DataType.Type.DATETIME
        elif is_iso_datetime(text):
            target = exp.DataType.Type.DATETIME
    elif kind == exp.DataType.Type.DATE and not is_date_unit(unit):
        target = exp.DataType.Type.DATETIME

    if target is None:
        return arg

    return arg.replace(exp.cast(arg.copy(), to=target))
198
199
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast each non-temporal operand of a DATEDIFF to DATETIME, in place.

    Args:
        node: The DateDiff node whose `this` and `expression` args are checked.
    """
    for operand in (node.this, node.expression):
        operand_type = operand.type
        # An operand without a type annotation cannot be assumed temporal, so it
        # is cast as well instead of failing on `None.this`.
        if not operand_type or operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            operand.replace(exp.cast(operand.copy(), to=exp.DataType.Type.DATETIME))
204
205
def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None:
    """Replace `node` in its tree with a cast of a copy of itself to `to`."""
    casted = exp.cast(node.copy(), to=to)
    node.replace(casted)
208
209
# This was originally designed for Presto; there is a similar transform for T-SQL.
# It differs in that it only operates on integer types, because Presto has a
# boolean type whereas T-SQL doesn't (people use bits).
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate `x` as `x <> 0`, in place.

    For a COALESCE the rewrite is applied to each of its child expressions
    instead of to the COALESCE itself.

    Args:
        expression: The predicate operand to rewrite.
    """
    if isinstance(expression, exp.Coalesce):
        for operand in expression.iter_expressions():
            _replace_int_predicate(operand)
        return

    operand_type = expression.type
    if operand_type and operand_type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
def canonicalize( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Rewrite a SQL expression into a standard (canonical) form.

    Many of the rewrites inspect node types, so this relies on the tree having
    been processed by annotate_types.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect, resolved with Dialect.get_or_raise, whose
            PROMOTE_TO_INFERRED_DATETIME_TYPE setting is handed to coerce_type.

    Returns:
        The result of running every canonicalization rule over the tree with
        exp.replace_tree.
    """
    promote = Dialect.get_or_raise(dialect).PROMOTE_TO_INFERRED_DATETIME_TYPE

    # The rules run in this exact order on each node.
    rules = (
        add_text_to_concat,
        replace_date_funcs,
        lambda node: coerce_type(node, promote),
        remove_redundant_casts,
        lambda node: ensure_bools(node, _replace_int_predicate),
        remove_ascending_order,
    )

    def _canonicalize(node: exp.Expression) -> exp.Expression:
        for rule in rules:
            node = rule(node)
        return node

    return exp.replace_tree(expression, _canonicalize)

Converts a sql expression into a standard form.

This method relies on annotate_types because many of the conversions rely on type inference.

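Arguments:
  • expression: The expression to canonicalize.
  • dialect: The dialect (name, class or instance) whose PROMOTE_TO_INFERRED_DATETIME_TYPE setting is passed to coerce_type; it is resolved with Dialect.get_or_raise.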
def add_text_to_concat(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn an addition whose annotated type is textual into a CONCAT.

    Args:
        node: The node to inspect.

    Returns:
        A new exp.Concat of the two operands when `node` is an exp.Add typed as
        one of DataType.TEXT_TYPES; otherwise `node` itself.
    """
    if not isinstance(node, exp.Add):
        return node

    result_type = node.type
    if result_type and result_type.this in exp.DataType.TEXT_TYPES:
        return exp.Concat(expressions=[node.left, node.right])

    return node
def replace_date_funcs(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Replace simple date/timestamp function calls with equivalent casts.

    Args:
        node: The node to inspect.

    Returns:
        - CAST(<literal> AS DATE) for a Date/TsOrDsToDate call that has no extra
          expressions, no zone, and an ISO-date string literal argument.
        - CAST(<arg> AS <type>) for a zone-less Timestamp call, where <type> is
          the node's annotated type, falling back to TIMESTAMP.
        - `node` unchanged in every other case.
    """
    if isinstance(node, (exp.Date, exp.TsOrDsToDate)):
        is_plain_call = not node.expressions and not node.args.get("zone")
        if is_plain_call and node.this.is_string and is_iso_date(node.this.name):
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
        if not node.type:
            # The node has not been annotated yet, so annotate it now to learn
            # which type to cast to.
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node)

        target_type = node.type or exp.DataType.Type.TIMESTAMP
        return exp.cast(node.this, to=target_type)

    return node
def coerce_type( node: sqlglot.expressions.Expression, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Insert casts so that temporal operations work on temporal operands.

    The operands are modified in place (they are replaced by casts inside the
    tree); `node` itself is always returned.

    Args:
        node: The node to inspect.
        promote_to_inferred_datetime_type: Forwarded to _coerce_date, where it
            controls whether the cast target may be inferred from the other
            operand instead of always being the temporal operand's type.

    Returns:
        The same `node` object that was passed in.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        # Only the lower bound is coerced against the tested expression here.
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract):
        operand = node.expression
        operand_type = operand.type
        # An operand with no type annotation is treated as "not temporal" rather
        # than dereferencing None, so it gets the same DATETIME cast as any other
        # non-temporal operand.
        if not (operand_type and operand_type.is_type(*exp.DataType.TEMPORAL_TYPES)):
            _replace_cast(operand, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 93def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 94    if (
 95        isinstance(expression, exp.Cast)
 96        and expression.this.type
 97        and expression.to.this == expression.this.type.this
 98    ):
 99        return expression.this
100    if (
101        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
102        and expression.this.type
103        and expression.this.type.this == exp.DataType.Type.DATE
104    ):
105        return expression.this
106    return expression
def ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply `replace_func` to the sub-expressions that act as predicates.

    Args:
        expression: The node to inspect.
        replace_func: Callback invoked on each predicate operand; it is expected
            to rewrite the operand in place.

    Returns:
        The same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Connector):
        for operand in (expression.left, expression.right):
            replace_func(operand)
    elif isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)
    elif isinstance(expression, exp.If):
        parent = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        is_case_operand_branch = isinstance(parent, exp.Case) and parent.this
        if not is_case_operand_branch:
            replace_func(expression.this)

    return expression
def remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Strip an explicit ASC from an ordering term (ORDER BY a ASC -> ORDER BY a).

    Args:
        expression: The node to inspect.

    Returns:
        The same `expression` object; an exp.Ordered whose "desc" arg is exactly
        False has that arg reset to None.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression