Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  9from sqlglot.optimizer.annotate_types import TypeAnnotator
 10
 11
 12def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
 13    """Converts a sql expression into a standard form.
 14
 15    This method relies on annotate_types because many of the
 16    conversions rely on type inference.
 17
 18    Args:
 19        expression: The expression to canonicalize.
 20    """
 21
 22    dialect = Dialect.get_or_raise(dialect)
 23
 24    def _canonicalize(expression: exp.Expression) -> exp.Expression:
 25        if not isinstance(expression, _CANONICALIZE_TYPES):
 26            return expression
 27        expression = add_text_to_concat(expression)
 28        expression = replace_date_funcs(expression, dialect=dialect)
 29        expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
 30        expression = remove_redundant_casts(expression)
 31        expression = ensure_bools(expression, _replace_int_predicate)
 32        expression = remove_ascending_order(expression)
 33        return expression
 34
 35    return exp.replace_tree(expression, _canonicalize)
 36
 37
 38COERCIBLE_DATE_OPS = (
 39    exp.Add,
 40    exp.Sub,
 41    exp.EQ,
 42    exp.NEQ,
 43    exp.GT,
 44    exp.GTE,
 45    exp.LT,
 46    exp.LTE,
 47    exp.NullSafeEQ,
 48    exp.NullSafeNEQ,
 49)
 50
 51
 52# All expression types that any of the canonicalize functions can act on
 53_CANONICALIZE_TYPES = tuple(
 54    {
 55        # add_text_to_concat
 56        exp.Add,
 57        # replace_date_funcs
 58        exp.Date,
 59        exp.TsOrDsToDate,
 60        exp.Timestamp,
 61        # coerce_type (COERCIBLE_DATE_OPS + Between, Extract, DateAdd, DateSub, DateTrunc, DateDiff)
 62        *COERCIBLE_DATE_OPS,
 63        exp.Between,
 64        exp.Extract,
 65        exp.DateAdd,
 66        exp.DateSub,
 67        exp.DateTrunc,
 68        exp.DateDiff,
 69        # remove_redundant_casts
 70        exp.Cast,
 71        # ensure_bools (Connector, Not, If, Where, Having)
 72        exp.Connector,
 73        exp.Not,
 74        exp.If,
 75        exp.Where,
 76        exp.Having,
 77        # remove_ascending_order
 78        exp.Ordered,
 79    }
 80)
 81
 82
 83def add_text_to_concat(node: exp.Expression) -> exp.Expression:
 84    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 85        node = exp.Concat(
 86            expressions=[node.left, node.right],
 87            # All known dialects, i.e. Redshift and T-SQL, that support
 88            # concatenating strings with the + operator do not coalesce NULLs.
 89            coalesce=False,
 90        )
 91    return node
 92
 93
 94def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
 95    if (
 96        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 97        and not node.expressions
 98        and not node.args.get("zone")
 99        and node.this.is_string
100        and is_iso_date(node.this.name)
101    ):
102        return exp.cast(node.this, to=exp.DataType.Type.DATE)
103    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
104        if not node.type:
105            from sqlglot.optimizer.annotate_types import annotate_types
106
107            node = annotate_types(node, dialect=dialect)
108        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
109
110    return node
111
112
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Insert the casts needed for temporal operands, based on the node kind.

    The helpers called here rewrite the children of `node` in place; `node`
    itself is always returned.

    Args:
        node: The expression whose operands may need coercion.
        promote_to_inferred_datetime_type: Forwarded to `_coerce_date`.

    Returns:
        The same `node` object.
    """
    promote = promote_to_inferred_datetime_type

    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote)
        return node

    if isinstance(node, exp.Between):
        # NOTE(review): only the "low" bound is coerced here; "high" is left
        # untouched -- confirm this is intentional.
        _coerce_date(node.this, node.args["low"], promote)
        return node

    if isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
        return node

    if isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
        return node

    if isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
128
129
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drop a cast or date conversion whose operand already has the target type.

    Args:
        expression: The expression to inspect.

    Returns:
        The operand (`expression.this`) when the wrapper is redundant,
        otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Cast):
        operand_type = expression.this.type
        # CAST(x AS T) where x is already annotated as exactly T.
        if operand_type and expression.to == operand_type:
            return expression.this

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand_type = expression.this.type
        # Date conversion of an operand annotated as a DATE with no type parameters.
        if (
            operand_type
            and operand_type.this == exp.DataType.Type.DATE
            and not operand_type.expressions
        ):
            return expression.this

    return expression
147
148
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply `replace_func` to the operands of `expression` that act as predicates.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand. It is
            expected to rewrite the operand in place; its return value is ignored.

    Returns:
        The same `expression` object.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    # We can't replace num in CASE x WHEN num ..., because it's not the full
    # predicate, so an IF whose parent CASE has an operand is skipped.
    wraps_single_predicate = isinstance(expression, (exp.Not, exp.Where, exp.Having)) or (
        isinstance(expression, exp.If)
        and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
    )
    if wraps_single_predicate:
        replace_func(expression.this)

    return expression
166
167
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clear an explicit ascending flag, i.e. ORDER BY a ASC -> ORDER BY a.

    Args:
        expression: The expression to inspect.

    Returns:
        The same `expression` object; an `exp.Ordered` whose "desc" arg is
        exactly `False` has that arg reset to `None`.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    # Only an explicit False is cleared; a missing / None "desc" is left as is.
    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression
174
175
def _coerce_date(
    a: exp.Expression,
    b: exp.Expression,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Cast a text operand that is paired with a temporal operand.

    Both orderings of `(a, b)` are tried, so it does not matter which side is
    the temporal one. Casts are applied in place via `_replace_cast`.

    Args:
        a: One operand.
        b: The other operand.
        promote_to_inferred_datetime_type: When true, the cast target may be
            the type inferred from the text operand instead of the temporal
            operand's type, provided the temporal type coerces to it
            according to `TypeAnnotator.COERCES_TO`.
    """
    temporal_types = exp.DataType.TEMPORAL_TYPES
    text_types = exp.DataType.TEXT_TYPES

    for temporal, text in itertools.permutations([a, b]):
        if isinstance(text, exp.Interval):
            temporal = _coerce_timeunit_arg(temporal, text.unit)

        temporal_type = temporal.type
        if not temporal_type or temporal_type.this not in temporal_types:
            continue
        if not text.type or text.type.this not in text_types:
            continue

        target_type = temporal_type
        if promote_to_inferred_datetime_type:
            if not text.is_string:
                # If the text side is not a datetime string, we conservatively promote it to a
                # DATETIME, in order to ensure there are no surprising truncations due to downcasting
                inferred_type = exp.DataType.Type.DATETIME
            elif is_iso_date(text.name):
                inferred_type = exp.DataType.Type.DATE
            elif is_iso_datetime(text.name):
                inferred_type = exp.DataType.Type.DATETIME
            else:
                inferred_type = temporal_type.this

            if inferred_type in TypeAnnotator.COERCES_TO.get(temporal_type.this, {}):
                target_type = inferred_type

        # The temporal side only needs a cast when it was promoted.
        if target_type != temporal_type:
            _replace_cast(temporal, target_type)

        _replace_cast(text, target_type)
218
219
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast `arg` to DATE or DATETIME depending on its type and the time unit.

    Args:
        arg: The operand of a unit-based date operation.
        unit: The unit expression, checked with `is_date_unit`.

    Returns:
        The value returned by `arg.replace(...)` when a cast is inserted,
        otherwise `arg` unchanged.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    type_kind = arg_type.this

    if type_kind in exp.DataType.TEXT_TYPES:
        text = arg.name

        if is_iso_date(text):
            # An ISO date is also an ISO datetime, but not vice versa, so a
            # date string paired with a non-date unit is widened to DATETIME.
            cast_to = (
                exp.DataType.Type.DATE if is_date_unit(unit) else exp.DataType.Type.DATETIME
            )
        elif is_iso_datetime(text):
            cast_to = exp.DataType.Type.DATETIME
        else:
            return arg

        return arg.replace(exp.cast(arg.copy(), to=cast_to))

    if type_kind == exp.DataType.Type.DATE and not is_date_unit(unit):
        return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))

    return arg
239
240
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast the non-temporal operands of a DATEDIFF to DATETIME, in place.

    Args:
        node: The DATEDIFF expression whose `this` and `expression` operands
            are inspected.
    """
    for operand in (node.this, node.expression):
        operand_type = operand.type

        # An operand that has not been annotated has no `.this` to inspect.
        # Skip it instead of raising AttributeError, the same way
        # `_coerce_date` and `_coerce_timeunit_arg` treat untyped operands.
        if not operand_type:
            continue

        if operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            operand.replace(exp.cast(operand.copy(), to=exp.DataType.Type.DATETIME))
245
246
def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None:
    """Replace `node`, in place, with a cast of a copy of itself to `to`."""
    casted = exp.cast(node.copy(), to=to)
    node.replace(casted)
249
250
# This was originally designed for Presto; there is a similar transform for T-SQL.
# This one differs in that it only operates on integer types, because Presto
# has a boolean type whereas T-SQL doesn't (people use bits).
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate `x` as `x <> 0`, in place.

    A COALESCE is not rewritten itself; each of its child expressions is
    processed recursively instead.
    """
    if isinstance(expression, exp.Coalesce):
        for child in expression.iter_expressions():
            _replace_int_predicate(child)
        return

    annotated_type = expression.type
    if annotated_type and annotated_type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
def canonicalize( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard form.

    Many of the rewrites inspect the inferred type of a node, so this
    relies on annotate_types.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect to resolve via `Dialect.get_or_raise`. It is passed
            to `replace_date_funcs` and supplies the
            `PROMOTE_TO_INFERRED_DATETIME_TYPE` flag used by `coerce_type`.

    Returns:
        The result of running the canonicalization passes over the tree
        through `exp.replace_tree`.
    """
    resolved_dialect = Dialect.get_or_raise(dialect)

    def _canonicalize(node: exp.Expression) -> exp.Expression:
        # Pre-filter: none of the passes below act on any other node type.
        if isinstance(node, _CANONICALIZE_TYPES):
            node = add_text_to_concat(node)
            node = replace_date_funcs(node, dialect=resolved_dialect)
            node = coerce_type(node, resolved_dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
            node = remove_redundant_casts(node)
            node = ensure_bools(node, _replace_int_predicate)
            node = remove_ascending_order(node)
        return node

    return exp.replace_tree(expression, _canonicalize)

Converts a sql expression into a standard form.

This method relies on annotate_types because many of the conversions rely on type inference.

Arguments:
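  • expression: The expression to canonicalize.
  • dialect: The dialect used to resolve dialect-specific behavior; it is resolved with Dialect.get_or_raise.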
def add_text_to_concat(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Replace a text-typed `+` with an explicit CONCAT.

    Args:
        node: The expression to inspect.

    Returns:
        A new `exp.Concat` over the two operands when `node` is an `exp.Add`
        annotated with a text type; otherwise `node` itself.
    """
    if not isinstance(node, exp.Add):
        return node

    add_type = node.type
    if not add_type or add_type.this not in exp.DataType.TEXT_TYPES:
        return node

    # All known dialects, i.e. Redshift and T-SQL, that support
    # concatenating strings with the + operator do not coalesce NULLs.
    return exp.Concat(expressions=[node.left, node.right], coalesce=False)
def replace_date_funcs( node: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
 95def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
 96    if (
 97        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 98        and not node.expressions
 99        and not node.args.get("zone")
100        and node.this.is_string
101        and is_iso_date(node.this.name)
102    ):
103        return exp.cast(node.this, to=exp.DataType.Type.DATE)
104    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
105        if not node.type:
106            from sqlglot.optimizer.annotate_types import annotate_types
107
108            node = annotate_types(node, dialect=dialect)
109        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
110
111    return node
def coerce_type( node: sqlglot.expressions.Expression, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Insert the casts needed for temporal operands, based on the node kind.

    The helpers called here rewrite the children of `node` in place; `node`
    itself is always returned.

    Args:
        node: The expression whose operands may need coercion.
        promote_to_inferred_datetime_type: Forwarded to `_coerce_date`.

    Returns:
        The same `node` object.
    """
    promote = promote_to_inferred_datetime_type

    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote)
        return node

    if isinstance(node, exp.Between):
        # NOTE(review): only the "low" bound is coerced here; "high" is left
        # untouched -- confirm this is intentional.
        _coerce_date(node.this, node.args["low"], promote)
        return node

    if isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
        return node

    if isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
        return node

    if isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Drop a cast or date conversion whose operand already has the target type.

    Args:
        expression: The expression to inspect.

    Returns:
        The operand (`expression.this`) when the wrapper is redundant,
        otherwise `expression` unchanged.
    """
    if isinstance(expression, exp.Cast):
        operand_type = expression.this.type
        # CAST(x AS T) where x is already annotated as exactly T.
        if operand_type and expression.to == operand_type:
            return expression.this

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand_type = expression.this.type
        # Date conversion of an operand annotated as a DATE with no type parameters.
        if (
            operand_type
            and operand_type.this == exp.DataType.Type.DATE
            and not operand_type.expressions
        ):
            return expression.this

    return expression
def ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply `replace_func` to the operands of `expression` that act as predicates.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand. It is
            expected to rewrite the operand in place; its return value is ignored.

    Returns:
        The same `expression` object.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    # We can't replace num in CASE x WHEN num ..., because it's not the full
    # predicate, so an IF whose parent CASE has an operand is skipped.
    wraps_single_predicate = isinstance(expression, (exp.Not, exp.Where, exp.Having)) or (
        isinstance(expression, exp.If)
        and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
    )
    if wraps_single_predicate:
        replace_func(expression.this)

    return expression
def remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clear an explicit ascending flag, i.e. ORDER BY a ASC -> ORDER BY a.

    Args:
        expression: The expression to inspect.

    Returns:
        The same `expression` object; an `exp.Ordered` whose "desc" arg is
        exactly `False` has that arg reset to `None`.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    # Only an explicit False is cleared; a missing / None "desc" is left as is.
    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression