sqlglot.optimizer.canonicalize
"""Rewrite SQL expression trees into a canonical form.

Most rewrites rely on the types attached to nodes by ``annotate_types``.
"""

from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
from sqlglot.optimizer.annotate_types import TypeAnnotator


def canonicalize(expression: exp.Expr, dialect: DialectType = None) -> exp.Expr:
    """Converts a sql expression into a standard form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect to canonicalize for. It is used to annotate
            TIMESTAMP(...) nodes that have no type yet. Its
            PROMOTE_TO_INFERRED_DATETIME_TYPE setting decides whether string
            operands combined with temporal values are promoted to their
            inferred date/datetime type.

    Returns:
        The root of the canonicalized expression tree.
    """

    _dialect = Dialect.get_or_raise(dialect)

    def _canonicalize(expression: exp.Expr) -> exp.Expr:
        # Fast path: none of the rewrites below can act on any other node type.
        if not isinstance(expression, _CANONICALIZE_TYPES):
            return expression
        # Each rewrite receives the (possibly replaced) result of the previous one.
        expression = add_text_to_concat(expression)
        expression = replace_date_funcs(expression, dialect=_dialect)
        expression = coerce_type(expression, _dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
        expression = remove_redundant_casts(expression)
        expression = ensure_bools(expression, _replace_int_predicate)
        expression = remove_ascending_order(expression)
        return expression

    return exp.replace_tree(expression, _canonicalize)


# Binary operators whose operands are reconciled by _coerce_date when one side is
# temporal and the other is text.
COERCIBLE_DATE_OPS = (
    exp.Add,
    exp.Sub,
    exp.EQ,
    exp.NEQ,
    exp.GT,
    exp.GTE,
    exp.LT,
    exp.LTE,
    exp.NullSafeEQ,
    exp.NullSafeNEQ,
)


# All expression types that any of the canonicalize functions can act on
_CANONICALIZE_TYPES = tuple(
    {
        # add_text_to_concat
        exp.Add,
        # replace_date_funcs
        exp.Date,
        exp.TsOrDsToDate,
        exp.Timestamp,
        # coerce_type (COERCIBLE_DATE_OPS + Between, Extract, DateAdd, DateSub, DateTrunc, DateDiff)
        *COERCIBLE_DATE_OPS,
        exp.Between,
        exp.Extract,
        exp.DateAdd,
        exp.DateSub,
        exp.DateTrunc,
        exp.DateDiff,
        # remove_redundant_casts
        exp.Cast,
        # ensure_bools (Connector, Not, If, Where, Having)
        exp.Connector,
        exp.Not,
        exp.If,
        exp.Where,
        exp.Having,
        # remove_ascending_order
        exp.Ordered,
    }
)


def add_text_to_concat(node: exp.Expr) -> exp.Expr:
    """Rewrite a text-typed ``+`` into a ``CONCAT`` that does not coalesce NULLs.

    Any other node is returned unchanged.
    """
    if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
        node = exp.Concat(
            expressions=[node.left, node.right],
            # All known dialects, i.e. Redshift and T-SQL, that support
            # concatenating strings with the + operator do not coalesce NULLs.
            coalesce=False,
        )
    return node


def replace_date_funcs(node: exp.Expr, dialect: DialectType) -> exp.Expr:
    """Replace date/timestamp constructor nodes with equivalent casts.

    - An ``exp.Date`` / ``exp.TsOrDsToDate`` over an ISO-date string literal,
      with no extra expressions and no zone, becomes ``CAST(<literal> AS DATE)``.
    - An ``exp.Timestamp`` without a zone becomes a cast of its argument to the
      node's annotated type (annotating it with ``dialect`` first if untyped),
      falling back to TIMESTAMP.

    Any other node is returned unchanged.
    """
    if (
        isinstance(node, (exp.Date, exp.TsOrDsToDate))
        and not node.expressions
        and not node.args.get("zone")
        and node.this.is_string
        and is_iso_date(node.this.name)
    ):
        return exp.cast(node.this, to=exp.DType.DATE)
    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
        if not node.type:
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node, dialect=dialect)
        return exp.cast(node.this, to=node.type or exp.DType.TIMESTAMP)

    return node


def coerce_type(node: exp.Expr, promote_to_inferred_datetime_type: bool) -> exp.Expr:
    """Insert casts so that temporal operations see consistently typed operands.

    Children of ``node`` are replaced in place; ``node`` itself is returned.

    - Comparisons and +/- (COERCIBLE_DATE_OPS): reconcile a temporal operand with
      a text operand via _coerce_date.
    - BETWEEN: the same, for the tested value and the lower bound.
      NOTE(review): the upper bound ("high") is not coerced; confirm intended.
    - EXTRACT: cast a non-temporal source to DATETIME.
    - DATE_ADD / DATE_SUB / DATE_TRUNC: cast the date argument according to its unit.
    - DATEDIFF: cast non-temporal operands to DATETIME.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DType.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node


def remove_redundant_casts(expression: exp.Expr) -> exp.Expr:
    """Drop conversions that cannot change their operand's type.

    - A CAST whose target type equals the operand's annotated type.
    - An ``exp.Date`` / ``exp.TsOrDsToDate`` whose operand is already typed as a
      plain DATE (no type parameters).

    Returns the operand in those cases, otherwise ``expression`` unchanged.
    """
    if (
        isinstance(expression, exp.Cast)
        and expression.this.type
        and expression.to == expression.this.type
    ):
        return expression.this

    if (
        isinstance(expression, (exp.Date, exp.TsOrDsToDate))
        and expression.this.type
        and expression.this.type.this == exp.DType.DATE
        and not expression.this.type.expressions
    ):
        return expression.this

    return expression


def ensure_bools(expression: exp.Expr, replace_func: t.Callable[[exp.Expr], None]) -> exp.Expr:
    """Call ``replace_func`` on each operand that sits in a boolean position.

    Those positions are:
    - both sides of a Connector;
    - the operand of NOT;
    - the condition of an If node, except inside a simple ``CASE x WHEN ...``;
    - the predicate of WHERE and HAVING.

    ``replace_func`` is expected to replace the operand in place. ``expression``
    is returned.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
    elif isinstance(expression, exp.Not):
        replace_func(expression.this)
    # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
    elif isinstance(expression, exp.If) and not (
        isinstance(expression.parent, exp.Case) and expression.parent.this
    ):
        replace_func(expression.this)
    elif isinstance(expression, (exp.Where, exp.Having)):
        replace_func(expression.this)

    return expression


def remove_ascending_order(expression: exp.Expr) -> exp.Expr:
    """Clear an explicit ASC on an ordering term; the node is modified in place."""
    if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False:
        # Convert ORDER BY a ASC to ORDER BY a
        expression.set("desc", None)

    return expression


def _coerce_date(
    a: exp.Expr,
    b: exp.Expr,
    promote_to_inferred_datetime_type: bool,
) -> None:
    """Cast a text operand that is combined with a temporal operand.

    Both orders of (a, b) are tried. If the second element is an INTERVAL, the
    first is coerced according to the interval's unit. The cast happens only
    when the first element is temporal and the second is text; the text side
    is then cast to a target type chosen as follows:

    - Without ``promote_to_inferred_datetime_type``, the target is the temporal
      side's own type.
    - With it, the target is a type inferred from the text side: DATE for an ISO
      date literal, DATETIME for an ISO datetime literal or a non-literal, and
      the temporal type itself for any other literal. The inferred type is used
      only if TypeAnnotator.COERCES_TO says the temporal type coerces to it; in
      that case the temporal side is cast to it as well.
    """
    # The loop deliberately rebinds a/b; permutations() already captured the
    # original operands, so both orientations are still visited.
    for a, b in itertools.permutations([a, b]):
        if isinstance(b, exp.Interval):
            a = _coerce_timeunit_arg(a, b.unit)

        a_type = a.type
        if (
            not a_type
            or a_type.this not in exp.DataType.TEMPORAL_TYPES
            or not b.type
            or b.type.this not in exp.DataType.TEXT_TYPES
        ):
            continue

        if promote_to_inferred_datetime_type:
            if b.is_string:
                date_text = b.name
                if is_iso_date(date_text):
                    b_type = exp.DType.DATE
                elif is_iso_datetime(date_text):
                    b_type = exp.DType.DATETIME
                else:
                    b_type = a_type.this
            else:
                # If b is not a datetime string, we conservatively promote it to a DATETIME,
                # in order to ensure there are no surprising truncations due to downcasting
                b_type = exp.DType.DATETIME

            target_type = (
                b_type if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type
            )
        else:
            target_type = a_type

        if target_type != a_type:
            _replace_cast(a, target_type)

        _replace_cast(b, target_type)


def _coerce_timeunit_arg(arg: exp.Expr, unit: exp.Expr | None) -> exp.Expr:
    """Cast a date-function argument according to the unit it is used with.

    The cases are:
    - A text arg holding an ISO date, used with a date unit, becomes a DATE cast.
    - A text arg holding an ISO date or an ISO datetime otherwise becomes a
      DATETIME cast.
    - A DATE arg used with a non-date unit becomes a DATETIME cast.

    Returns the replacement node, or ``arg`` if untyped or not rewritten.
    """
    if not arg.type:
        return arg

    if arg.type.this in exp.DataType.TEXT_TYPES:
        date_text = arg.name
        is_iso_date_ = is_iso_date(date_text)

        if is_iso_date_ and is_date_unit(unit):
            return arg.replace(exp.cast(arg.copy(), to=exp.DType.DATE))

        # An ISO date is also an ISO datetime, but not vice versa
        if is_iso_date_ or is_iso_datetime(date_text):
            return arg.replace(exp.cast(arg.copy(), to=exp.DType.DATETIME))

    elif arg.type.this == exp.DType.DATE and not is_date_unit(unit):
        return arg.replace(exp.cast(arg.copy(), to=exp.DType.DATETIME))

    return arg


def _coerce_datediff_args(node: exp.DateDiff) -> None:
    """Cast every DATEDIFF operand that is not known to be temporal to DATETIME."""
    for e in (node.this, node.expression):
        e_type = e.type
        # An operand may carry no type at all, e.g. a Concat freshly produced by
        # add_text_to_concat for a child node. Treat it like an UNKNOWN type
        # instead of failing on `None.this`.
        if e_type is not None and e_type.this in exp.DataType.TEMPORAL_TYPES:
            continue
        _replace_cast(e, exp.DType.DATETIME)


def _replace_cast(node: exp.Expr, to: exp.DATA_TYPE) -> None:
    """Replace ``node`` in its tree with ``CAST(<copy of node> AS to)``."""
    node.replace(exp.cast(node.copy(), to=to))


# this was originally designed for presto, there is a similar transform for tsql
# this is different in that it only operates on int types, this is because
# presto has a boolean type whereas tsql doesn't (people use bits)
# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expr) -> None:
    """Replace an integer-typed predicate ``x`` with ``x <> 0``.

    For COALESCE, the rule is applied to each child instead.
    """
    if isinstance(expression, exp.Coalesce):
        for child in expression.iter_expressions():
            _replace_int_predicate(child)
    elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
        expression.replace(expression.neq(0))
def
canonicalize( expression: sqlglot.expressions.core.Expr, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.core.Expr:
def canonicalize(expression: exp.Expr, dialect: DialectType = None) -> exp.Expr:
    """Rewrite ``expression`` into a standard (canonical) form.

    Many of the rewrites depend on inferred types, so the tree is expected to
    have been processed by annotate_types beforehand.

    Args:
        expression: The expression to canonicalize.
        dialect: Dialect name, instance or class. It controls the
            dialect-specific parts of the rewrites.

    Returns:
        The root of the rewritten tree.
    """
    target_dialect = Dialect.get_or_raise(dialect)

    def _rewrite_node(node: exp.Expr) -> exp.Expr:
        if isinstance(node, _CANONICALIZE_TYPES):
            # Fixed pipeline: every step receives the previous step's result.
            node = add_text_to_concat(node)
            node = replace_date_funcs(node, dialect=target_dialect)
            node = coerce_type(node, target_dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
            node = remove_redundant_casts(node)
            node = ensure_bools(node, _replace_int_predicate)
            node = remove_ascending_order(node)
        return node

    return exp.replace_tree(expression, _rewrite_node)
Converts a sql expression into a standard form.
This method relies on annotate_types because many of the conversions rely on type inference.
Arguments:
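- expression: The expression to canonicalize.
- dialect: The dialect to canonicalize for. It controls dialect-specific rewrites, such as whether string operands are promoted to their inferred date/datetime type.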
COERCIBLE_DATE_OPS =
(<class 'sqlglot.expressions.core.Add'>, <class 'sqlglot.expressions.core.Sub'>, <class 'sqlglot.expressions.core.EQ'>, <class 'sqlglot.expressions.core.NEQ'>, <class 'sqlglot.expressions.core.GT'>, <class 'sqlglot.expressions.core.GTE'>, <class 'sqlglot.expressions.core.LT'>, <class 'sqlglot.expressions.core.LTE'>, <class 'sqlglot.expressions.core.NullSafeEQ'>, <class 'sqlglot.expressions.core.NullSafeNEQ'>)
def add_text_to_concat(node: exp.Expr) -> exp.Expr:
    """Turn a text-typed ``+`` into a ``CONCAT`` that does not coalesce NULLs.

    Any other node is returned unchanged.
    """
    if not isinstance(node, exp.Add):
        return node

    node_type = node.type
    if not node_type or node_type.this not in exp.DataType.TEXT_TYPES:
        return node

    # Dialects that concatenate strings with ``+`` (e.g. Redshift, T-SQL) do not
    # coalesce NULL operands, so the resulting CONCAT must not either.
    return exp.Concat(expressions=[node.left, node.right], coalesce=False)
def
replace_date_funcs( node: sqlglot.expressions.core.Expr, dialect: Union[str, sqlglot.dialects.Dialect, type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.core.Expr:
def replace_date_funcs(node: exp.Expr, dialect: DialectType) -> exp.Expr:
    """Replace date/timestamp constructor nodes with equivalent casts.

    - ``exp.Date`` / ``exp.TsOrDsToDate`` over an ISO-date string literal, with
      no extra expressions and no zone, becomes ``CAST(<literal> AS DATE)``.
    - ``exp.Timestamp`` without a zone becomes a cast of its argument to the
      node's annotated type (annotating with ``dialect`` if needed), or
      TIMESTAMP when no type is available.

    Any other node is returned unchanged.
    """
    if (
        isinstance(node, (exp.Date, exp.TsOrDsToDate))
        and not node.expressions
        and not node.args.get("zone")
    ):
        literal = node.this
        if literal.is_string and is_iso_date(literal.name):
            return exp.cast(literal, to=exp.DType.DATE)

    if not isinstance(node, exp.Timestamp) or node.args.get("zone"):
        return node

    if not node.type:
        from sqlglot.optimizer.annotate_types import annotate_types

        node = annotate_types(node, dialect=dialect)

    return exp.cast(node.this, to=node.type or exp.DType.TIMESTAMP)
def
coerce_type( node: sqlglot.expressions.core.Expr, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.core.Expr:
def coerce_type(node: exp.Expr, promote_to_inferred_datetime_type: bool) -> exp.Expr:
    """Insert casts so that temporal operations see consistently typed operands.

    Children of ``node`` may be replaced in place; ``node`` itself is returned.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
        return node

    if isinstance(node, exp.Between):
        # Only the tested value and the lower bound are reconciled.
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
        return node

    if isinstance(node, exp.Extract):
        if not node.expression.is_type(*exp.DataType.TEMPORAL_TYPES):
            _replace_cast(node.expression, exp.DType.DATETIME)
        return node

    if isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def
remove_redundant_casts( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_redundant_casts(expression: exp.Expr) -> exp.Expr:
    """Strip conversions that cannot change their operand's type.

    It handles two cases:
    - A CAST is dropped when its target equals the operand's annotated type.
    - An ``exp.Date`` / ``exp.TsOrDsToDate`` is dropped when its operand is
      already a plain DATE (no type parameters).

    Returns the operand when a conversion is stripped, else ``expression``.
    """
    if isinstance(expression, exp.Cast):
        inner_type = expression.this.type
        if inner_type and expression.to == inner_type:
            return expression.this

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        inner_type = expression.this.type
        if inner_type and inner_type.this == exp.DType.DATE and not inner_type.expressions:
            return expression.this

    return expression
def
ensure_bools( expression: sqlglot.expressions.core.Expr, replace_func: Callable[[sqlglot.expressions.core.Expr], NoneType]) -> sqlglot.expressions.core.Expr:
def ensure_bools(expression: exp.Expr, replace_func: t.Callable[[exp.Expr], None]) -> exp.Expr:
    """Call ``replace_func`` on each operand that sits in a boolean position.

    Those positions are:
    - both sides of a Connector;
    - the operand of NOT;
    - the predicate of WHERE and HAVING;
    - the condition of an If node, except inside a simple ``CASE x WHEN ...``.

    ``expression`` itself is returned.
    """
    if isinstance(expression, exp.Connector):
        attrs: tuple[str, ...] = ("left", "right")
    elif isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        attrs = ("this",)
    elif isinstance(expression, exp.If) and not (
        # In CASE x WHEN num ..., num is compared against x, so it is not a full predicate.
        isinstance(expression.parent, exp.Case) and expression.parent.this
    ):
        attrs = ("this",)
    else:
        attrs = ()

    # Operands are fetched lazily so each lookup happens after earlier replacements.
    for attr in attrs:
        replace_func(getattr(expression, attr))

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr: