Edit on GitHub

sqlglot.optimizer.canonicalize

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
  8
  9
 10def canonicalize(expression: exp.Expression) -> exp.Expression:
 11    """Converts a sql expression into a standard form.
 12
 13    This method relies on annotate_types because many of the
 14    conversions rely on type inference.
 15
 16    Args:
 17        expression: The expression to canonicalize.
 18    """
 19    exp.replace_children(expression, canonicalize)
 20
 21    expression = add_text_to_concat(expression)
 22    expression = replace_date_funcs(expression)
 23    expression = coerce_type(expression)
 24    expression = remove_redundant_casts(expression)
 25    expression = ensure_bools(expression, _replace_int_predicate)
 26    expression = remove_ascending_order(expression)
 27
 28    return expression
 29
 30
 31def add_text_to_concat(node: exp.Expression) -> exp.Expression:
 32    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
 33        node = exp.Concat(expressions=[node.left, node.right])
 34    return node
 35
 36
 37def replace_date_funcs(node: exp.Expression) -> exp.Expression:
 38    if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"):
 39        return exp.cast(node.this, to=exp.DataType.Type.DATE)
 40    if isinstance(node, exp.Timestamp) and not node.expression:
 41        if not node.type:
 42            from sqlglot.optimizer.annotate_types import annotate_types
 43
 44            node = annotate_types(node)
 45        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)
 46
 47    return node
 48
 49
 50COERCIBLE_DATE_OPS = (
 51    exp.Add,
 52    exp.Sub,
 53    exp.EQ,
 54    exp.NEQ,
 55    exp.GT,
 56    exp.GTE,
 57    exp.LT,
 58    exp.LTE,
 59    exp.NullSafeEQ,
 60    exp.NullSafeNEQ,
 61)
 62
 63
 64def coerce_type(node: exp.Expression) -> exp.Expression:
 65    if isinstance(node, COERCIBLE_DATE_OPS):
 66        _coerce_date(node.left, node.right)
 67    elif isinstance(node, exp.Between):
 68        _coerce_date(node.this, node.args["low"])
 69    elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
 70        *exp.DataType.TEMPORAL_TYPES
 71    ):
 72        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
 73    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
 74        _coerce_timeunit_arg(node.this, node.unit)
 75    elif isinstance(node, exp.DateDiff):
 76        _coerce_datediff_args(node)
 77
 78    return node
 79
 80
 81def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
 82    if (
 83        isinstance(expression, exp.Cast)
 84        and expression.this.type
 85        and expression.to.this == expression.this.type.this
 86    ):
 87        return expression.this
 88    return expression
 89
 90
 91def ensure_bools(
 92    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
 93) -> exp.Expression:
 94    if isinstance(expression, exp.Connector):
 95        replace_func(expression.left)
 96        replace_func(expression.right)
 97    elif isinstance(expression, exp.Not):
 98        replace_func(expression.this)
 99        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
100    elif isinstance(expression, exp.If) and not (
101        isinstance(expression.parent, exp.Case) and expression.parent.this
102    ):
103        replace_func(expression.this)
104    elif isinstance(expression, (exp.Where, exp.Having)):
105        replace_func(expression.this)
106
107    return expression
108
109
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clear an explicit `desc=False` flag on an `Ordered` node.

    This converts ORDER BY a ASC to ORDER BY a. Other nodes are returned as-is.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression
116
117
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
    """Coerce a pair of operands so that a DATE operand is compared against a DATE.

    Both orderings of the pair are examined. When one side is typed DATE and the
    other has a known type that is neither DATE nor INTERVAL, the other side is
    replaced in place by a cast to DATE. When one side is an `Interval`, the
    opposite side is first coerced according to the interval's unit.
    """
    date_type = exp.DataType.Type.DATE
    no_cast_needed = (date_type, exp.DataType.Type.INTERVAL)

    for lhs, rhs in itertools.permutations([a, b]):
        if isinstance(rhs, exp.Interval):
            lhs = _coerce_timeunit_arg(lhs, rhs.unit)

        if not lhs.type or lhs.type.this != date_type:
            continue

        if rhs.type and rhs.type.this not in no_cast_needed:
            _replace_cast(rhs, date_type)
133
134
def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
    """Cast `arg` to DATE or DATETIME depending on its type and on `unit`.

    - Text that looks like an ISO date becomes DATE for date units and
      DATETIME otherwise.
    - Text that looks like an ISO datetime (but not an ISO date) becomes DATETIME.
    - A DATE-typed argument becomes DATETIME when the unit is not a date unit.

    Args:
        arg: The argument to coerce; it is replaced in place when a cast is needed.
        unit: The time unit the argument is used with.

    Returns:
        The result of `arg.replace(...)` when a cast was inserted, otherwise `arg`.
    """
    arg_type = arg.type
    if not arg_type:
        return arg

    date = exp.DataType.Type.DATE
    datetime = exp.DataType.Type.DATETIME
    target = None

    if arg_type.this in exp.DataType.TEXT_TYPES:
        text = arg.name

        # An ISO date is also an ISO datetime, but not vice versa
        if is_iso_date(text):
            target = date if is_date_unit(unit) else datetime
        elif is_iso_datetime(text):
            target = datetime
    elif arg_type.this == date and not is_date_unit(unit):
        target = datetime

    if target is None:
        return arg

    return arg.replace(exp.cast(arg.copy(), to=target))
154
155
def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast each non-temporal operand of a `DateDiff` to DATETIME, in place.

    Operands without a type annotation are left untouched, which matches how
    `_coerce_date` and `_coerce_timeunit_arg` treat un-annotated expressions.

    Args:
        node: The DateDiff expression whose two operands should be coerced.
    """
    for operand in (node.this, node.expression):
        operand_type = operand.type

        # Skip un-annotated operands instead of failing on `None.this`.
        if operand_type and operand_type.this not in exp.DataType.TEMPORAL_TYPES:
            _replace_cast(operand, exp.DataType.Type.DATETIME)
160
161
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
    """Replace `node` in place with a cast of a copy of itself to the type `to`."""
    casted = exp.cast(node.copy(), to=to)
    node.replace(casted)
164
165
# This was originally designed for Presto; there is a similar transform for T-SQL.
# It differs in that it only operates on integer types, because Presto has a
# boolean type whereas T-SQL doesn't (people use bits there). For example:
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
    """Replace an integer-typed expression used as a predicate with `<expr> <> 0`.

    A `Coalesce` is not rewritten itself; each of its child expressions is
    processed recursively instead. Expressions with no type or a non-integer
    type are left alone. The replacement happens in place.
    """
    if isinstance(expression, exp.Coalesce):
        for _key, child in expression.iter_expressions():
            _replace_int_predicate(child)
        return

    expr_type = expression.type
    if expr_type and expr_type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
def canonicalize( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard form.

    The children of `expression` are handed to `exp.replace_children` with this
    function as the callback before any rule is applied to `expression` itself.
    Many of the rules inspect `.type`, so the tree is expected to have been
    processed by annotate_types beforehand.

    Args:
        expression: The expression to canonicalize.

    Returns:
        The canonicalized expression, which may be a different node than the input.
    """
    exp.replace_children(expression, canonicalize)

    # Rules run in this exact order; each one receives the previous rule's result.
    rules = (
        add_text_to_concat,
        replace_date_funcs,
        coerce_type,
        remove_redundant_casts,
        lambda node: ensure_bools(node, _replace_int_predicate),
        remove_ascending_order,
    )
    for rule in rules:
        expression = rule(expression)

    return expression

Converts a SQL expression into a standard form.

This method relies on annotate_types because many of the conversions rely on type inference.

Arguments:
  • expression: The expression to canonicalize.
def add_text_to_concat(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turn an `Add` whose annotated type is a text type into a `Concat` of its operands.

    Any other node, including an `Add` without a type, is returned unchanged.
    """
    if not isinstance(node, exp.Add):
        return node

    node_type = node.type
    if node_type and node_type.this in exp.DataType.TEXT_TYPES:
        return exp.Concat(expressions=[node.left, node.right])

    return node
def replace_date_funcs(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression) -> exp.Expression:
    """Replace single-argument DATE / TIMESTAMP calls with casts of their argument.

    A `Date` is only rewritten when it has no extra `expressions` and no `zone`
    argument; a `Timestamp` is only rewritten when it has no second `expression`.
    """
    if isinstance(node, exp.Date):
        has_extra_args = node.expressions or node.args.get("zone")
        if not has_extra_args:
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp) and not node.expression:
        if not node.type:
            # Imported at call time rather than at module level.
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node)

        # Fall back to a plain TIMESTAMP when annotation did not produce a type.
        target_type = node.type or exp.DataType.Type.TIMESTAMP
        return exp.cast(node.this, to=target_type)

    return node
def coerce_type(node: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression) -> exp.Expression:
    """Insert casts so the operands of date-related expressions have compatible types.

    The operands are modified in place (they are replaced inside `node`); `node`
    itself is always returned.

    - Operators in COERCIBLE_DATE_OPS: both operands go through `_coerce_date`.
    - BETWEEN: the tested value is coerced against both the low and the high bound.
    - EXTRACT: an operand with a known, non-temporal type is cast to DATETIME.
    - DateAdd / DateSub / DateTrunc: the first argument is coerced based on the unit.
    - DateDiff: handled by `_coerce_datediff_args`.

    Args:
        node: The (type-annotated) expression to process.

    Returns:
        The same `node` that was passed in.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right)
    elif isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"])

        # The upper bound needs the same treatment as the lower one. `node.this`
        # is read again because the first call may have replaced it with a cast.
        high = node.args.get("high")
        if high is not None:
            _coerce_date(node.this, high)
    elif isinstance(node, exp.Extract):
        operand_type = node.expression.type

        # An un-annotated operand has no type to inspect; leave it untouched
        # instead of failing on `None.is_type`, like the other helpers here do.
        if operand_type and not operand_type.is_type(*exp.DataType.TEMPORAL_TYPES):
            _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Strip a cast whose target type equals the annotated type of its operand.

    The full data types are compared, not just the type kind, so a cast that
    changes parameters (e.g. DECIMAL(38, 0) -> DECIMAL(10, 2), or
    VARCHAR -> VARCHAR(5)) is kept because it is not a no-op.

    Args:
        expression: The expression to inspect.

    Returns:
        The cast's operand when the cast is redundant, otherwise `expression`.
    """
    if not isinstance(expression, exp.Cast):
        return expression

    operand = expression.this
    operand_type = operand.type

    # Without a known operand type there is no way to tell whether the cast is a no-op.
    if operand_type and expression.to == operand_type:
        return operand

    return expression
def ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
 92def ensure_bools(
 93    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
 94) -> exp.Expression:
 95    if isinstance(expression, exp.Connector):
 96        replace_func(expression.left)
 97        replace_func(expression.right)
 98    elif isinstance(expression, exp.Not):
 99        replace_func(expression.this)
100        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
101    elif isinstance(expression, exp.If) and not (
102        isinstance(expression.parent, exp.Case) and expression.parent.this
103    ):
104        replace_func(expression.this)
105    elif isinstance(expression, (exp.Where, exp.Having)):
106        replace_func(expression.this)
107
108    return expression
def remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
    """Clear an explicit `desc=False` flag on an `Ordered` node.

    This converts ORDER BY a ASC to ORDER BY a. Other nodes are returned as-is.
    """
    if not isinstance(expression, exp.Ordered):
        return expression

    if expression.args.get("desc") is False:
        expression.set("desc", None)

    return expression