Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.dialects.dialect import Dialect, DialectType
  8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  9from sqlglot.optimizer.annotate_types import TypeAnnotator
 10
 11
 12def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
 13    """Converts a sql expression into a standard form.
 14
 15    This method relies on annotate_types because many of the
 16    conversions rely on type inference.
 17
 18    Args:
 19        expression: The expression to canonicalize.
 20    """
 21
 22    dialect = Dialect.get_or_raise(dialect)
 23
 24    def _canonicalize(expression: exp.Expression) -> exp.Expression:
 25        expression = add_text_to_concat(expression)
 26        expression = replace_date_funcs(expression, dialect=dialect)
 27        expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
 28        expression = remove_redundant_casts(expression)
 29        expression = ensure_bools(expression, _replace_int_predicate)
 30        expression = remove_ascending_order(expression)
 31        return expression
 32
 33    return exp.replace_tree(expression, _canonicalize)
 34
 35
 36def add_text_to_concat(node: exp.Expression) -> exp.Expression:
 37    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 38        node = exp.Concat(
 39            expressions=[node.left, node.right],
 40            # All known dialects, i.e. Redshift and T-SQL, that support
 41            # concatenating strings with the + operator do not coalesce NULLs.
 42            coalesce=False,
 43        )
 44    return node
 45
 46
 47def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
 48    if (
 49        isinstance(node, (exp.Date, exp.TsOrDsToDate))
 50        and not node.expressions
 51        and not node.args.get("zone")
 52        and node.this.is_string
 53        and is_iso_date(node.this.name)
 54    ):
 55        return exp.cast(node.this, to=exp.DataType.Type.DATE)
 56    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
 57        if not node.type:
 58            from sqlglot.optimizer.annotate_types import annotate_types
 59
 60            node = annotate_types(node, dialect=dialect)
 61        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
 62
 63    return node
 64
 65
 66COERCIBLE_DATE_OPS = (
 67    exp.Add,
 68    exp.Sub,
 69    exp.EQ,
 70    exp.NEQ,
 71    exp.GT,
 72    exp.GTE,
 73    exp.LT,
 74    exp.LTE,
 75    exp.NullSafeEQ,
 76    exp.NullSafeNEQ,
 77)
 78
 79
 80def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
 81    if isinstance(node, COERCIBLE_DATE_OPS):
 82        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
 83    elif isinstance(node, exp.Between):
 84        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
 85    elif isinstance(node, exp.Extract) and not node.expression.is_type(
 86        *exp.DataType.TEMPORAL_TYPES
 87    ):
 88        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
 89    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
 90        _coerce_timeunit_arg(node.this, node.unit)
 91    elif isinstance(node, exp.DateDiff):
 92        _coerce_datediff_args(node)
 93
 94    return node
 95
 96
 97def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 98    if (
 99        isinstance(expression, exp.Cast)
100        and expression.this.type
101        and expression.to == expression.this.type
102    ):
103        return expression.this
104
105    if (
106        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
107        and expression.this.type
108        and expression.this.type.this == exp.DataType.Type.DATE
109        and not expression.this.type.expressions
110    ):
111        return expression.this
112
113    return expression
114
115
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply ``replace_func`` to every operand that sits in a predicate position.

    Predicate positions are: both sides of a connector (AND/OR), the operand of
    NOT, the condition of an IF/WHEN, and the WHERE/HAVING condition.

    Args:
        expression: The node to inspect.
        replace_func: Callback that may replace the given operand in the tree.

    Returns:
        ``expression`` (operands are replaced in place by ``replace_func``).
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
    elif isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)
    elif isinstance(expression, exp.If):
        parent = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the
        # full predicate: the WHEN value is compared against x.
        if not (isinstance(parent, exp.Case) and parent.this):
            replace_func(expression.this)

    return expression
133
134
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Turn an explicit ``ORDER BY a ASC`` into ``ORDER BY a``.

    An ``exp.Ordered`` whose ``desc`` arg is exactly ``False`` has that arg
    cleared (set to None). Other nodes are left untouched.

    Returns:
        ``expression`` (modified in place if applicable).
    """
    explicit_asc = (
        isinstance(expression, exp.Ordered) and expression.args.get("desc") is False
    )
    if explicit_asc:
        expression.set("desc", None)
    return expression
141
142
def _coerce_date(
    a: exp.Expression,
    b: exp.Expression,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Cast a text operand to match a temporal operand of the same operation.

    Both orientations are tried: first (a, b), then (b, a). In each pass the
    first operand is ``left`` and the second is ``right``.

    * If ``right`` is an INTERVAL, ``left`` is first adjusted for the
      interval's unit by ``_coerce_timeunit_arg``.
    * The pair is coerced only when ``left`` has a temporal type and ``right``
      has a text type.
    * Without promotion, the target is ``left``'s own type. With promotion, a
      type is inferred from ``right``:
      - DATE for an ISO-date literal;
      - DATETIME for an ISO-datetime literal or any non-literal;
      - ``left``'s type otherwise.
      The inferred type is used only if ``left``'s type coerces to it per
      ``TypeAnnotator.COERCES_TO``.
    * ``right`` is always wrapped in a CAST to the target. ``left`` is wrapped
      only when the target differs from its own type.

    Nodes are replaced in the tree in place; nothing is returned.
    """
    # The list [a, b] is built once, so rebinding `left` inside the loop does
    # not change which nodes the second pass looks at.
    for left, right in itertools.permutations([a, b]):
        if isinstance(right, exp.Interval):
            left = _coerce_timeunit_arg(left, right.unit)

        left_type = left.type
        if not left_type or left_type.this not in exp.DataType.TEMPORAL_TYPES:
            continue
        right_type = right.type
        if not right_type or right_type.this not in exp.DataType.TEXT_TYPES:
            continue

        if not promote_to_inferred_datetime_type:
            target_type = left_type
        else:
            if not right.is_string:
                # A non-literal text value could hold anything; promoting it to
                # DATETIME avoids surprising truncations from downcasting.
                inferred = exp.DataType.Type.DATETIME
            else:
                literal = right.name
                if is_iso_date(literal):
                    inferred = exp.DataType.Type.DATE
                elif is_iso_datetime(literal):
                    inferred = exp.DataType.Type.DATETIME
                else:
                    inferred = left_type.this
            widenable_to = TypeAnnotator.COERCES_TO.get(left_type.this, {})
            target_type = inferred if inferred in widenable_to else left_type

        if target_type != left_type:
            _replace_cast(left, target_type)
        _replace_cast(right, target_type)
185
186
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast ``arg`` so it suits a date operation driven by ``unit``.

    * Text value that is an ISO date, with a date unit -> DATE.
    * Otherwise, text value that is an ISO date or ISO datetime -> DATETIME.
    * DATE value with a non-date unit -> DATETIME.

    Untyped arguments and anything else are returned unchanged.

    NOTE(review): the text check reads ``arg.name``. For non-literal nodes
    (e.g. a column) that is the node's name, not its value; confirm intended.

    Returns:
        The CAST returned by ``arg.replace``, or ``arg`` when nothing changed.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    target = None
    if arg_type.this in exp.DataType.TEXT_TYPES:
        text = arg.name
        looks_like_date = is_iso_date(text)
        if looks_like_date and is_date_unit(unit):
            target = exp.DataType.Type.DATE
        # An ISO date is also an ISO datetime, but not vice versa.
        elif looks_like_date or is_iso_datetime(text):
            target = exp.DataType.Type.DATETIME
    elif arg_type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
        target = exp.DataType.Type.DATETIME

    if target is None:
        return arg
    return arg.replace(exp.cast(arg.copy(), to=target))
206
207
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast each DATEDIFF operand that is not of a temporal type to DATETIME.

    Missing or unannotated operands (``type`` is None) are left untouched,
    consistent with ``_coerce_timeunit_arg``. Accessing ``.type.this`` on them
    would otherwise raise AttributeError.
    """
    for operand in (node.this, node.expression):
        if operand is None or not operand.type:
            continue
        if operand.type.this not in exp.DataType.TEMPORAL_TYPES:
            operand.replace(exp.cast(operand.copy(), to=exp.DataType.Type.DATETIME))
212
213
def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None:
    """Replace ``node`` in its parent with ``CAST(<copy of node> AS to)``."""
    wrapped = exp.cast(node.copy(), to=to)
    node.replace(wrapped)
216
217
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Rewrite an integer-typed predicate operand as ``<operand> <> 0``.

    Originally designed for Presto, which has a real boolean type, so e.g.
    ``WITH y AS (SELECT TRUE AS x) SELECT x = 0 FROM y`` is an illegal Presto
    query. A similar transform exists for T-SQL. This one differs in that it
    only operates on integer types, because T-SQL has no boolean type (people
    use BIT instead).

    COALESCE is looked through: each of its arguments is handled individually.
    """
    if not isinstance(expression, exp.Coalesce):
        expr_type = expression.type
        if expr_type and expr_type.this in exp.DataType.INTEGER_TYPES:
            expression.replace(expression.neq(0))
        return

    for arg in expression.iter_expressions():
        _replace_int_predicate(arg)
def canonicalize( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Convert a SQL expression into a standard (canonical) form.

    ``exp.replace_tree`` applies a fixed pipeline of rewrites to the nodes of
    the tree: text ``+`` becomes CONCAT, literal date functions become CASTs,
    temporal operands are coerced, redundant casts are dropped, integer
    predicates become ``<> 0`` comparisons, and explicit ``ASC`` is removed.
    Many of these conversions rely on type inference, so the tree is expected
    to have been processed by ``annotate_types`` beforehand.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect (name, class, instance or None) whose settings
            drive dialect-dependent rewrites, e.g.
            ``PROMOTE_TO_INFERRED_DATETIME_TYPE``.

    Returns:
        The canonicalized expression returned by ``exp.replace_tree``.
    """
    resolved = Dialect.get_or_raise(dialect)

    # Order matters: each step sees the output of the previous one.
    steps = (
        add_text_to_concat,
        lambda node: replace_date_funcs(node, dialect=resolved),
        lambda node: coerce_type(node, resolved.PROMOTE_TO_INFERRED_DATETIME_TYPE),
        remove_redundant_casts,
        lambda node: ensure_bools(node, _replace_int_predicate),
        remove_ascending_order,
    )

    def _canonicalize(node: exp.Expression) -> exp.Expression:
        for step in steps:
            node = step(node)
        return node

    return exp.replace_tree(expression, _canonicalize)

Converts a SQL expression into a standard form.

This method relies on annotate_types because many of the conversions rely on type inference.

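Arguments:
  • expression: The expression to canonicalize.
  • dialect: The SQL dialect whose settings (such as PROMOTE_TO_INFERRED_DATETIME_TYPE) control the dialect-dependent conversions.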
def add_text_to_concat(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Rewrite a text-typed ``a + b`` as ``CONCAT(a, b)``.

    Only an ``exp.Add`` whose annotated type is a text type is rewritten; any
    other node is returned unchanged.

    Returns:
        A new ``exp.Concat`` over the two operands, or ``node`` itself.
    """
    if not isinstance(node, exp.Add):
        return node

    node_type = node.type
    if not node_type or node_type.this not in exp.DataType.TEXT_TYPES:
        return node

    # All known dialects that concatenate strings with `+` (i.e. Redshift and
    # T-SQL) do not coalesce NULLs, hence coalesce=False.
    return exp.Concat(expressions=[node.left, node.right], coalesce=False)
def replace_date_funcs( node: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Rewrite literal date/timestamp constructor calls as plain CASTs.

    * ``DATE('<iso date>')`` / ``TS_OR_DS_TO_DATE('<iso date>')`` with no zone
      and no extra arguments becomes ``CAST('<iso date>' AS DATE)``.
    * ``TIMESTAMP(x)`` with no zone becomes ``CAST(x AS <type>)``. The type is
      the node's annotated type (annotated on demand), or TIMESTAMP when none
      can be inferred.

    Calls whose first argument is missing are returned unchanged. Without this
    check, ``node.this`` (None) would be dereferenced.

    Args:
        node: The node to inspect.
        dialect: Dialect used if the TIMESTAMP node must be type-annotated.

    Returns:
        The replacement CAST, or ``node`` itself when no rewrite applies.
    """
    if (
        isinstance(node, (exp.Date, exp.TsOrDsToDate))
        and node.this is not None
        and not node.expressions
        and not node.args.get("zone")
        and node.this.is_string
        and is_iso_date(node.this.name)
    ):
        return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if (
        isinstance(node, exp.Timestamp)
        and node.this is not None
        and not node.args.get("zone")
    ):
        if not node.type:
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node, dialect=dialect)
        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)

    return node
def coerce_type( node: sqlglot.expressions.Expression, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Insert CASTs so temporal values are combined with compatible types.

    Handles:
    * binary operators listed in ``COERCIBLE_DATE_OPS``;
    * BETWEEN, coercing the tested value against *both* bounds. This matches
      how the equivalent ``x >= low AND x <= high`` would be treated. Only the
      low bound used to be considered, which left the high bound as text;
    * EXTRACT from a non-temporal value (cast to DATETIME);
    * DATE_ADD / DATE_SUB / DATE_TRUNC arguments, based on their time unit;
    * DATEDIFF operands.

    Child nodes are replaced in the tree in place; ``node`` itself is returned.

    Args:
        node: The node to inspect.
        promote_to_inferred_datetime_type: Forwarded to the date coercion.
            When true, the shape of a text literal (ISO date vs. datetime) may
            widen the target type.

    Returns:
        ``node``.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        for bound_key in ("low", "high"):
            bound = node.args.get(bound_key)
            if bound is not None:
                # node.this is re-read on each pass: coercing against the low
                # bound may already have wrapped it in a CAST.
                _coerce_date(node.this, bound, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 98def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 99    if (
100        isinstance(expression, exp.Cast)
101        and expression.this.type
102        and expression.to == expression.this.type
103    ):
104        return expression.this
105
106    if (
107        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
108        and expression.this.type
109        and expression.this.type.this == exp.DataType.Type.DATE
110        and not expression.this.type.expressions
111    ):
112        return expression.this
113
114    return expression
def ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Apply ``replace_func`` to every operand that sits in a predicate position.

    Predicate positions are: both sides of a connector (AND/OR), the operand of
    NOT, the condition of an IF/WHEN, and the WHERE/HAVING condition.

    Args:
        expression: The node to inspect.
        replace_func: Callback that may replace the given operand in the tree.

    Returns:
        ``expression`` (operands are replaced in place by ``replace_func``).
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
    elif isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)
    elif isinstance(expression, exp.If):
        parent = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the
        # full predicate: the WHEN value is compared against x.
        if not (isinstance(parent, exp.Case) and parent.this):
            replace_func(expression.this)

    return expression
def remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Turn an explicit ``ORDER BY a ASC`` into ``ORDER BY a``.

    An ``exp.Ordered`` whose ``desc`` arg is exactly ``False`` has that arg
    cleared (set to None). Other nodes are left untouched.

    Returns:
        ``expression`` (modified in place if applicable).
    """
    explicit_asc = (
        isinstance(expression, exp.Ordered) and expression.args.get("desc") is False
    )
    if explicit_asc:
        expression.set("desc", None)
    return expression