Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  9from sqlglot.optimizer.annotate_types import TypeAnnotator
 10
 11
 12def canonicalize(expression: exp.Expr, dialect: DialectType = None) -> exp.Expr:
 13    """Converts a sql expression into a standard form.
 14
 15    This method relies on annotate_types because many of the
 16    conversions rely on type inference.
 17
 18    Args:
 19        expression: The expression to canonicalize.
 20    """
 21
 22    _dialect = Dialect.get_or_raise(dialect)
 23
 24    def _canonicalize(expression: exp.Expr) -> exp.Expr:
 25        if not isinstance(expression, _CANONICALIZE_TYPES):
 26            return expression
 27        expression = add_text_to_concat(expression)
 28        expression = replace_date_funcs(expression, dialect=_dialect)
 29        expression = coerce_type(expression, _dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
 30        expression = remove_redundant_casts(expression)
 31        expression = ensure_bools(expression, _replace_int_predicate)
 32        expression = remove_ascending_order(expression)
 33        return expression
 34
 35    return exp.replace_tree(expression, _canonicalize)
 36
 37
# Binary operators whose left/right operands coerce_type hands to _coerce_date.
# The tuple is also spliced into _CANONICALIZE_TYPES.
COERCIBLE_DATE_OPS = (
    exp.Add,
    exp.Sub,
    exp.EQ,
    exp.NEQ,
    exp.GT,
    exp.GTE,
    exp.LT,
    exp.LTE,
    exp.NullSafeEQ,
    exp.NullSafeNEQ,
)
 50
 51
# All expression types that any of the canonicalize functions can act on.
# canonicalize returns a node of any other type unchanged without running the
# rules. The types are collected in a set first, so entries listed more than
# once (exp.Add is also part of COERCIBLE_DATE_OPS) are deduplicated before the
# tuple used for the isinstance check is built.
_CANONICALIZE_TYPES = tuple(
    {
        # add_text_to_concat
        exp.Add,
        # replace_date_funcs (Date and TsOrDsToDate are also checked by remove_redundant_casts)
        exp.Date,
        exp.TsOrDsToDate,
        exp.Timestamp,
        # coerce_type (COERCIBLE_DATE_OPS + Between, Extract, DateAdd, DateSub, DateTrunc, DateDiff)
        *COERCIBLE_DATE_OPS,
        exp.Between,
        exp.Extract,
        exp.DateAdd,
        exp.DateSub,
        exp.DateTrunc,
        exp.DateDiff,
        # remove_redundant_casts
        exp.Cast,
        # ensure_bools (Connector, Not, If, Where, Having)
        exp.Connector,
        exp.Not,
        exp.If,
        exp.Where,
        exp.Having,
        # remove_ascending_order
        exp.Ordered,
    }
)
 81
 82
 83def add_text_to_concat(node: exp.Expr) -> exp.Expr:
 84    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 85        node = exp.Concat(
 86            expressions=[node.left, node.right],
 87            # All known dialects, i.e. Redshift and T-SQL, that support
 88            # concatenating strings with the + operator do not coalesce NULLs.
 89            coalesce=False,
 90        )
 91    return node
 92
 93
 94def replace_date_funcs(node: exp.Expr, dialect: DialectType) -> exp.Expr:
 95    if (
 96        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 97        and not node.expressions
 98        and not node.args.get("zone")
 99        and node.this.is_string
100        and is_iso_date(node.this.name)
101    ):
102        return exp.cast(node.this, to=exp.DType.DATE)
103    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
104        if not node.type:
105            from sqlglot.optimizer.annotate_types import annotate_types
106
107            node = annotate_types(node, dialect=dialect)
108        return exp.cast(node.this, to=node.type or exp.DType.TIMESTAMP)
109
110    return node
111
112
def coerce_type(node: exp.Expr, promote_to_inferred_datetime_type: bool) -> exp.Expr:
    """Insert casts so that the operands of date-related expressions agree on type.

    The operands are rewritten in place inside ``node``:

    * Operators in ``COERCIBLE_DATE_OPS``: left and right go through ``_coerce_date``.
    * ``Between``: the tested value is coerced against both the lower and the
      upper bound.
    * ``Extract``: an operand that is not of a temporal type is cast to DATETIME.
    * ``DateAdd`` / ``DateSub`` / ``DateTrunc``: the first argument goes through
      ``_coerce_timeunit_arg`` together with the node's unit.
    * ``DateDiff``: handled by ``_coerce_datediff_args``.

    Args:
        node: The expression to inspect.
        promote_to_inferred_datetime_type: Forwarded to ``_coerce_date``.

    Returns:
        ``node`` itself.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        # Coerce both bounds, not only the lower one. `node.this` is looked up
        # again for each bound because _coerce_date may replace it with a cast.
        for bound in ("low", "high"):
            _coerce_date(node.this, node.args[bound], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DType.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
128
129
def remove_redundant_casts(expression: exp.Expr) -> exp.Expr:
    """Drop casts and date conversions that do not change the operand's type.

    Args:
        expression: The expression to inspect.

    Returns:
        The inner operand when ``expression`` is a ``Cast`` to the type its
        operand already has, or a ``Date`` / ``TsOrDsToDate`` call whose operand
        is already typed as a plain DATE (no type parameters); otherwise
        ``expression`` itself.
    """
    if isinstance(expression, exp.Cast):
        operand = expression.this
        if operand.type and expression.to == operand.type:
            return operand

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand_type = expression.this.type
        if (
            operand_type
            and operand_type.this == exp.DType.DATE
            and not operand_type.expressions
        ):
            return expression.this

    return expression
147
148
def ensure_bools(expression: exp.Expr, replace_func: t.Callable[[exp.Expr], None]) -> exp.Expr:
    """Apply ``replace_func`` to the operands of ``expression`` that act as predicates.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand: both sides of
            a ``Connector``, and the ``this`` operand of ``Not``, ``If``,
            ``Where`` and ``Having``.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    if isinstance(expression, exp.Not):
        replace_func(expression.this)
        return expression

    if isinstance(expression, exp.If):
        enclosing = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        in_operand_case = isinstance(enclosing, exp.Case) and enclosing.this
        if not in_operand_case:
            replace_func(expression.this)
            return expression

    if isinstance(expression, (exp.Where, exp.Having)):
        replace_func(expression.this)

    return expression
164
165
def remove_ascending_order(expression: exp.Expr) -> exp.Expr:
    """Clear an explicit ascending flag, i.e. convert ORDER BY a ASC to ORDER BY a.

    Args:
        expression: The expression to inspect.

    Returns:
        ``expression`` itself. For an ``Ordered`` node whose ``desc`` arg is
        exactly ``False``, that arg is reset to ``None`` first.
    """
    explicit_ascending = (
        isinstance(expression, exp.Ordered) and expression.args.get("desc") is False
    )
    if explicit_ascending:
        expression.set("desc", None)

    return expression
172
173
def _coerce_date(
    a: exp.Expr,
    b: exp.Expr,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Cast a text operand so it matches the temporal operand it is paired with.

    Both orderings of ``(a, b)`` are tried. In an ordering where the first
    operand has a temporal type and the second a text type, the text operand is
    replaced in place by a cast to the target type; the temporal operand is
    cast as well when the target differs from its current type.

    Args:
        a: One operand.
        b: The other operand.
        promote_to_inferred_datetime_type: When true, the target is the type
            inferred for the text operand, provided the temporal operand's type
            lists it in ``TypeAnnotator.COERCES_TO``; in every other case the
            target is the temporal operand's own type.
    """
    temporal_types = exp.DataType.TEMPORAL_TYPES
    text_types = exp.DataType.TEXT_TYPES

    for temporal, other in ((a, b), (b, a)):
        if isinstance(other, exp.Interval):
            temporal = _coerce_timeunit_arg(temporal, other.unit)

        temporal_type = temporal.type
        is_temporal = temporal_type and temporal_type.this in temporal_types
        is_text = is_temporal and other.type and other.type.this in text_types
        if not is_text:
            continue

        target_type = temporal_type

        if promote_to_inferred_datetime_type:
            if not other.is_string:
                # If the other operand is not a datetime string, we conservatively
                # promote it to a DATETIME, in order to ensure there are no
                # surprising truncations due to downcasting
                inferred = exp.DType.DATETIME
            else:
                literal = other.name
                if is_iso_date(literal):
                    inferred = exp.DType.DATE
                elif is_iso_datetime(literal):
                    inferred = exp.DType.DATETIME
                else:
                    inferred = temporal_type.this

            if inferred in TypeAnnotator.COERCES_TO.get(temporal_type.this, {}):
                target_type = inferred

        if target_type != temporal_type:
            _replace_cast(temporal, target_type)

        _replace_cast(other, target_type)
216
217
def _coerce_timeunit_arg(arg: exp.Expr, unit: exp.Expr | None) -> exp.Expr:
    """Cast ``arg`` to DATE or DATETIME depending on its type and the time unit.

    * Text whose ``name`` is an ISO date: cast to DATE for a date unit,
      otherwise to DATETIME.
    * Text whose ``name`` is an ISO datetime: cast to DATETIME.
    * A DATE combined with a non-date unit: cast to DATETIME.

    Args:
        arg: The argument to coerce; it needs a type annotation to be touched.
        unit: The time unit the argument is used with.

    Returns:
        The value returned by ``arg.replace`` when a cast was inserted
        (presumably the new cast node -- confirm against ``Expr.replace``),
        otherwise ``arg`` unchanged.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    kind = arg_type.this
    if kind in exp.DataType.TEXT_TYPES:
        literal = arg.name
        if is_iso_date(literal):
            # An ISO date is also an ISO datetime, but not vice versa
            target = exp.DType.DATE if is_date_unit(unit) else exp.DType.DATETIME
        elif is_iso_datetime(literal):
            target = exp.DType.DATETIME
        else:
            return arg
    elif kind == exp.DType.DATE and not is_date_unit(unit):
        target = exp.DType.DATETIME
    else:
        return arg

    return arg.replace(exp.cast(arg.copy(), to=target))
237
238
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast the non-temporal operands of a DATEDIFF to DATETIME, in place.

    Operands without a type annotation are left untouched, as the other
    coercion helpers in this module do, instead of raising an AttributeError
    on ``None.this``.

    Args:
        node: The DateDiff whose ``this`` and ``expression`` operands are checked.
    """
    for operand in (node.this, node.expression):
        operand_type = operand.type
        if not operand_type:
            # Nothing is known about this operand; do not guess a cast.
            continue

        if operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            operand.replace(exp.cast(operand.copy(), to=exp.DType.DATETIME))
243
244
def _replace_cast(node: exp.Expr, to: exp.DATA_TYPE) -> None:
    """Replace ``node`` in place with a cast of a copy of itself to ``to``."""
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)
247
248
# This was originally designed for Presto; there is a similar transform for T-SQL.
# This one is different in that it only operates on int types, because Presto
# has a boolean type whereas T-SQL doesn't (people use bits), e.g.:
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expr) -> None:
    """Replace an integer-typed expression with ``expression.neq(0)``, in place.

    A ``Coalesce`` is not replaced itself; the check is applied recursively to
    each of its child expressions instead. Expressions without a type, or with
    a non-integer type, are left alone.
    """
    if isinstance(expression, exp.Coalesce):
        for operand in expression.iter_expressions():
            _replace_int_predicate(operand)
        return

    inferred = expression.type
    if inferred and inferred.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
def canonicalize( expression: sqlglot.expressions.core.Expr, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.core.Expr:
def canonicalize(expression: exp.Expr, dialect: DialectType = None) -> exp.Expr:
    """Rewrite a SQL expression tree into a standard form.

    Many of the rewrites depend on inferred types, so this relies on
    annotate_types.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect, resolved through ``Dialect.get_or_raise``. It is
            forwarded to the date-function rewrite and supplies the
            ``PROMOTE_TO_INFERRED_DATETIME_TYPE`` setting used by type coercion.

    Returns:
        The result of ``exp.replace_tree`` applied to ``expression`` with the
        per-node rewrite defined below.
    """
    resolved_dialect = Dialect.get_or_raise(dialect)

    def _rewrite(node: exp.Expr) -> exp.Expr:
        # Nodes of a type that no rule acts on are handed back untouched.
        if isinstance(node, _CANONICALIZE_TYPES):
            # Each rule receives the output of the previous one, so a node that
            # was swapped for another one keeps flowing through the later rules.
            node = add_text_to_concat(node)
            node = replace_date_funcs(node, dialect=resolved_dialect)
            node = coerce_type(node, resolved_dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
            node = remove_redundant_casts(node)
            node = ensure_bools(node, _replace_int_predicate)
            node = remove_ascending_order(node)
        return node

    return exp.replace_tree(expression, _rewrite)

Converts a sql expression into a standard form.

This method relies on annotate_types because many of the conversions rely on type inference.

Arguments:
  • expression: The expression to canonicalize.
def add_text_to_concat(node: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def add_text_to_concat(node: exp.Expr) -> exp.Expr:
    """Turn a text-typed ``+`` into a CONCAT of its two operands.

    Args:
        node: The expression to inspect.

    Returns:
        A new ``exp.Concat`` over the left and right operands when ``node`` is
        an ``exp.Add`` whose annotated type is a text type; otherwise ``node``
        itself.
    """
    if not isinstance(node, exp.Add):
        return node

    inferred = node.type
    if not inferred or inferred.this not in exp.DataType.TEXT_TYPES:
        return node

    # All known dialects, i.e. Redshift and T-SQL, that support
    # concatenating strings with the + operator do not coalesce NULLs.
    return exp.Concat(expressions=[node.left, node.right], coalesce=False)
def replace_date_funcs( node: sqlglot.expressions.core.Expr, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.core.Expr:
 95def replace_date_funcs(node: exp.Expr, dialect: DialectType) -> exp.Expr:
 96    if (
 97        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 98        and not node.expressions
 99        and not node.args.get("zone")
100        and node.this.is_string
101        and is_iso_date(node.this.name)
102    ):
103        return exp.cast(node.this, to=exp.DType.DATE)
104    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
105        if not node.type:
106            from sqlglot.optimizer.annotate_types import annotate_types
107
108            node = annotate_types(node, dialect=dialect)
109        return exp.cast(node.this, to=node.type or exp.DType.TIMESTAMP)
110
111    return node
def coerce_type( node: sqlglot.expressions.core.Expr, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.core.Expr:
def coerce_type(node: exp.Expr, promote_to_inferred_datetime_type: bool) -> exp.Expr:
    """Insert casts so that the operands of date-related expressions agree on type.

    The operands are rewritten in place inside ``node``:

    * Operators in ``COERCIBLE_DATE_OPS``: left and right go through ``_coerce_date``.
    * ``Between``: the tested value is coerced against both the lower and the
      upper bound.
    * ``Extract``: an operand that is not of a temporal type is cast to DATETIME.
    * ``DateAdd`` / ``DateSub`` / ``DateTrunc``: the first argument goes through
      ``_coerce_timeunit_arg`` together with the node's unit.
    * ``DateDiff``: handled by ``_coerce_datediff_args``.

    Args:
        node: The expression to inspect.
        promote_to_inferred_datetime_type: Forwarded to ``_coerce_date``.

    Returns:
        ``node`` itself.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        # Coerce both bounds, not only the lower one. `node.this` is looked up
        # again for each bound because _coerce_date may replace it with a cast.
        for bound in ("low", "high"):
            _coerce_date(node.this, node.args[bound], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DType.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def remove_redundant_casts( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_redundant_casts(expression: exp.Expr) -> exp.Expr:
    """Drop casts and date conversions that do not change the operand's type.

    Args:
        expression: The expression to inspect.

    Returns:
        The inner operand when ``expression`` is a ``Cast`` to the type its
        operand already has, or a ``Date`` / ``TsOrDsToDate`` call whose operand
        is already typed as a plain DATE (no type parameters); otherwise
        ``expression`` itself.
    """
    if isinstance(expression, exp.Cast):
        operand = expression.this
        if operand.type and expression.to == operand.type:
            return operand

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand_type = expression.this.type
        if (
            operand_type
            and operand_type.this == exp.DType.DATE
            and not operand_type.expressions
        ):
            return expression.this

    return expression
def ensure_bools( expression: sqlglot.expressions.core.Expr, replace_func: Callable[[sqlglot.expressions.core.Expr], NoneType]) -> sqlglot.expressions.core.Expr:
def ensure_bools(expression: exp.Expr, replace_func: t.Callable[[exp.Expr], None]) -> exp.Expr:
    """Apply ``replace_func`` to the operands of ``expression`` that act as predicates.

    Args:
        expression: The expression to inspect.
        replace_func: Callback invoked on each predicate operand: both sides of
            a ``Connector``, and the ``this`` operand of ``Not``, ``If``,
            ``Where`` and ``Having``.

    Returns:
        ``expression`` itself.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    if isinstance(expression, exp.Not):
        replace_func(expression.this)
        return expression

    if isinstance(expression, exp.If):
        enclosing = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        in_operand_case = isinstance(enclosing, exp.Case) and enclosing.this
        if not in_operand_case:
            replace_func(expression.this)
            return expression

    if isinstance(expression, (exp.Where, exp.Having)):
        replace_func(expression.this)

    return expression
def remove_ascending_order( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_ascending_order(expression: exp.Expr) -> exp.Expr:
    """Clear an explicit ascending flag, i.e. convert ORDER BY a ASC to ORDER BY a.

    Args:
        expression: The expression to inspect.

    Returns:
        ``expression`` itself. For an ``Ordered`` node whose ``desc`` arg is
        exactly ``False``, that arg is reset to ``None`` first.
    """
    explicit_ascending = (
        isinstance(expression, exp.Ordered) and expression.args.get("desc") is False
    )
    if explicit_ascending:
        expression.set("desc", None)

    return expression