sqlglot.optimizer.canonicalize
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime 9from sqlglot.optimizer.annotate_types import TypeAnnotator 10 11 12def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression: 13 """Converts a sql expression into a standard form. 14 15 This method relies on annotate_types because many of the 16 conversions rely on type inference. 17 18 Args: 19 expression: The expression to canonicalize. 20 """ 21 22 dialect = Dialect.get_or_raise(dialect) 23 24 def _canonicalize(expression: exp.Expression) -> exp.Expression: 25 if not isinstance(expression, _CANONICALIZE_TYPES): 26 return expression 27 expression = add_text_to_concat(expression) 28 expression = replace_date_funcs(expression, dialect=dialect) 29 expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE) 30 expression = remove_redundant_casts(expression) 31 expression = ensure_bools(expression, _replace_int_predicate) 32 expression = remove_ascending_order(expression) 33 return expression 34 35 return exp.replace_tree(expression, _canonicalize) 36 37 38COERCIBLE_DATE_OPS = ( 39 exp.Add, 40 exp.Sub, 41 exp.EQ, 42 exp.NEQ, 43 exp.GT, 44 exp.GTE, 45 exp.LT, 46 exp.LTE, 47 exp.NullSafeEQ, 48 exp.NullSafeNEQ, 49) 50 51 52# All expression types that any of the canonicalize functions can act on 53_CANONICALIZE_TYPES = tuple( 54 { 55 # add_text_to_concat 56 exp.Add, 57 # replace_date_funcs 58 exp.Date, 59 exp.TsOrDsToDate, 60 exp.Timestamp, 61 # coerce_type (COERCIBLE_DATE_OPS + Between, Extract, DateAdd, DateSub, DateTrunc, DateDiff) 62 *COERCIBLE_DATE_OPS, 63 exp.Between, 64 exp.Extract, 65 exp.DateAdd, 66 exp.DateSub, 67 exp.DateTrunc, 68 exp.DateDiff, 69 # remove_redundant_casts 70 exp.Cast, 71 # ensure_bools (Connector, Not, If, Where, Having) 72 exp.Connector, 73 exp.Not, 74 exp.If, 
75 exp.Where, 76 exp.Having, 77 # remove_ascending_order 78 exp.Ordered, 79 } 80) 81 82 83def add_text_to_concat(node: exp.Expression) -> exp.Expression: 84 if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: 85 node = exp.Concat( 86 expressions=[node.left, node.right], 87 # All known dialects, i.e. Redshift and T-SQL, that support 88 # concatenating strings with the + operator do not coalesce NULLs. 89 coalesce=False, 90 ) 91 return node 92 93 94def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression: 95 if ( 96 isinstance(node, (exp.Date, exp.TsOrDsToDate)) 97 and not node.expressions 98 and not node.args.get("zone") 99 and node.this.is_string 100 and is_iso_date(node.this.name) 101 ): 102 return exp.cast(node.this, to=exp.DataType.Type.DATE) 103 if isinstance(node, exp.Timestamp) and not node.args.get("zone"): 104 if not node.type: 105 from sqlglot.optimizer.annotate_types import annotate_types 106 107 node = annotate_types(node, dialect=dialect) 108 return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP) 109 110 return node 111 112 113def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression: 114 if isinstance(node, COERCIBLE_DATE_OPS): 115 _coerce_date(node.left, node.right, promote_to_inferred_datetime_type) 116 elif isinstance(node, exp.Between): 117 _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type) 118 elif isinstance(node, exp.Extract) and not node.expression.is_type( 119 *exp.DataType.TEMPORAL_TYPES 120 ): 121 _replace_cast(node.expression, exp.DataType.Type.DATETIME) 122 elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): 123 _coerce_timeunit_arg(node.this, node.unit) 124 elif isinstance(node, exp.DateDiff): 125 _coerce_datediff_args(node) 126 127 return node 128 129 130def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: 131 if ( 132 isinstance(expression, exp.Cast) 133 
and expression.this.type 134 and expression.to == expression.this.type 135 ): 136 return expression.this 137 138 if ( 139 isinstance(expression, (exp.Date, exp.TsOrDsToDate)) 140 and expression.this.type 141 and expression.this.type.this == exp.DataType.Type.DATE 142 and not expression.this.type.expressions 143 ): 144 return expression.this 145 146 return expression 147 148 149def ensure_bools( 150 expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] 151) -> exp.Expression: 152 if isinstance(expression, exp.Connector): 153 replace_func(expression.left) 154 replace_func(expression.right) 155 elif isinstance(expression, exp.Not): 156 replace_func(expression.this) 157 # We can't replace num in CASE x WHEN num ..., because it's not the full predicate 158 elif isinstance(expression, exp.If) and not ( 159 isinstance(expression.parent, exp.Case) and expression.parent.this 160 ): 161 replace_func(expression.this) 162 elif isinstance(expression, (exp.Where, exp.Having)): 163 replace_func(expression.this) 164 165 return expression 166 167 168def remove_ascending_order(expression: exp.Expression) -> exp.Expression: 169 if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: 170 # Convert ORDER BY a ASC to ORDER BY a 171 expression.set("desc", None) 172 173 return expression 174 175 176def _coerce_date( 177 a: exp.Expression, 178 b: exp.Expression, 179 promote_to_inferred_datetime_type: bool, 180) -> None: 181 for a, b in itertools.permutations([a, b]): 182 if isinstance(b, exp.Interval): 183 a = _coerce_timeunit_arg(a, b.unit) 184 185 a_type = a.type 186 if ( 187 not a_type 188 or a_type.this not in exp.DataType.TEMPORAL_TYPES 189 or not b.type 190 or b.type.this not in exp.DataType.TEXT_TYPES 191 ): 192 continue 193 194 if promote_to_inferred_datetime_type: 195 if b.is_string: 196 date_text = b.name 197 if is_iso_date(date_text): 198 b_type = exp.DataType.Type.DATE 199 elif is_iso_datetime(date_text): 200 b_type = 
exp.DataType.Type.DATETIME 201 else: 202 b_type = a_type.this 203 else: 204 # If b is not a datetime string, we conservatively promote it to a DATETIME, 205 # in order to ensure there are no surprising truncations due to downcasting 206 b_type = exp.DataType.Type.DATETIME 207 208 target_type = ( 209 b_type if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type 210 ) 211 else: 212 target_type = a_type 213 214 if target_type != a_type: 215 _replace_cast(a, target_type) 216 217 _replace_cast(b, target_type) 218 219 220def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: 221 if not arg.type: 222 return arg 223 224 if arg.type.this in exp.DataType.TEXT_TYPES: 225 date_text = arg.name 226 is_iso_date_ = is_iso_date(date_text) 227 228 if is_iso_date_ and is_date_unit(unit): 229 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) 230 231 # An ISO date is also an ISO datetime, but not vice versa 232 if is_iso_date_ or is_iso_datetime(date_text): 233 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 234 235 elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): 236 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 237 238 return arg 239 240 241def _coerce_datediff_args(node: exp.DateDiff) -> None: 242 for e in (node.this, node.expression): 243 if e.type.this not in exp.DataType.TEMPORAL_TYPES: 244 e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) 245 246 247def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None: 248 node.replace(exp.cast(node.copy(), to=to)) 249 250 251# this was originally designed for presto, there is a similar transform for tsql 252# this is different in that it only operates on int types, this is because 253# presto has a boolean type whereas tsql doesn't (people use bits) 254# with y as (select true as x) select x = 0 FROM y -- illegal presto query 255def _replace_int_predicate(expression: 
exp.Expression) -> None: 256 if isinstance(expression, exp.Coalesce): 257 for child in expression.iter_expressions(): 258 _replace_int_predicate(child) 259 elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: 260 expression.replace(expression.neq(0))
def
canonicalize( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Rewrite a SQL expression tree into a standard form.

    The rewrites rely on annotate_types, since many of them look at the
    inferred type of a node.

    Args:
        expression: The expression to canonicalize.
        dialect: Dialect (or dialect name), resolved through `Dialect.get_or_raise`.

    Returns:
        The result of `exp.replace_tree` after the rewrites have been applied.
    """
    resolved_dialect = Dialect.get_or_raise(dialect)

    def _rewrite(node: exp.Expression) -> exp.Expression:
        # Only node kinds that at least one rewrite can act on enter the pipeline.
        if isinstance(node, _CANONICALIZE_TYPES):
            node = add_text_to_concat(node)
            node = replace_date_funcs(node, dialect=resolved_dialect)
            node = coerce_type(node, resolved_dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
            node = remove_redundant_casts(node)
            node = ensure_bools(node, _replace_int_predicate)
            node = remove_ascending_order(node)
        return node

    return exp.replace_tree(expression, _rewrite)
Converts a sql expression into a standard form.
This method relies on annotate_types because many of the conversions rely on type inference.
Arguments:
- expression: The expression to canonicalize.
COERCIBLE_DATE_OPS =
(<class 'sqlglot.expressions.Add'>, <class 'sqlglot.expressions.Sub'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NullSafeEQ'>, <class 'sqlglot.expressions.NullSafeNEQ'>)
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Return a `Concat` for a text-typed `Add`; return any other node as is."""
    if not isinstance(node, exp.Add):
        return node

    inferred = node.type
    if not inferred or inferred.this not in exp.DataType.TEXT_TYPES:
        return node

    # All known dialects, i.e. Redshift and T-SQL, that support
    # concatenating strings with the + operator do not coalesce NULLs.
    operands = [node.left, node.right]
    return exp.Concat(expressions=operands, coalesce=False)
def
replace_date_funcs( node: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Swap date/timestamp constructor functions for equivalent casts.

    A `Date` / `TsOrDsToDate` call with no extra expressions, no zone and an ISO
    date string literal becomes a cast of the literal to DATE. A zone-less
    `Timestamp` becomes a cast of its argument to the node's type (TIMESTAMP if
    there is none), annotating the node first when it carries no type.

    Args:
        node: The expression to inspect.
        dialect: Forwarded to `annotate_types` when the node must be annotated.

    Returns:
        The replacement cast, or `node` itself when nothing applies.
    """
    if isinstance(node, (exp.Date, exp.TsOrDsToDate)):
        is_bare_iso_literal = (
            not node.expressions
            and not node.args.get("zone")
            and node.this.is_string
            and is_iso_date(node.this.name)
        )
        if is_bare_iso_literal:
            return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if isinstance(node, exp.Timestamp) and not node.args.get("zone"):
        if not node.type:
            from sqlglot.optimizer.annotate_types import annotate_types

            node = annotate_types(node, dialect=dialect)

        target = node.type or exp.DataType.Type.TIMESTAMP
        return exp.cast(node.this, to=target)

    return node
def
coerce_type( node: sqlglot.expressions.Expression, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Coerce the operands of date-related expressions in place.

    Args:
        node: The expression whose operands may get wrapped in casts.
        promote_to_inferred_datetime_type: Passed through to `_coerce_date`.

    Returns:
        The same `node` object that was passed in.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
        return node

    if isinstance(node, exp.Between):
        # NOTE(review): only the `low` bound is coerced; `high` is not passed
        # to _coerce_date here -- confirm this is intended.
        lower_bound = node.args["low"]
        _coerce_date(node.this, lower_bound, promote_to_inferred_datetime_type)
        return node

    if isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
        return node

    if isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        # The argument is replaced inside the tree; the helper's result is unused.
        _coerce_timeunit_arg(node.this, node.unit)
        return node

    if isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def
remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Strip casts and date conversions that leave the operand's type unchanged.

    Returns the operand of a `Cast` whose target equals the operand's type, or
    the operand of a `Date` / `TsOrDsToDate` that is already typed as a plain
    DATE (no type parameters). Anything else is returned untouched.
    """
    if isinstance(expression, exp.Cast):
        operand = expression.this
        if operand.type and expression.to == operand.type:
            return operand

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand = expression.this
        operand_type = operand.type
        if (
            operand_type
            and operand_type.this == exp.DataType.Type.DATE
            and not operand_type.expressions
        ):
            return operand

    return expression
def
ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Run `replace_func` on every operand of `expression` used as a predicate.

    Args:
        expression: The node to inspect. `Connector`, `Not`, `If`, `Where` and
            `Having` nodes are handled; other nodes pass through.
        replace_func: Callback applied to each predicate operand. Its return
            value is ignored.

    Returns:
        The same `expression` object that was passed in.
    """
    if isinstance(expression, exp.Connector):
        replace_func(expression.left)
        replace_func(expression.right)
        return expression

    # Nodes whose single predicate is stored under `this`
    if isinstance(expression, (exp.Not, exp.Where, exp.Having)):
        replace_func(expression.this)
        return expression

    if isinstance(expression, exp.If):
        parent = expression.parent
        # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
        is_simple_case_branch = isinstance(parent, exp.Case) and parent.this
        if not is_simple_case_branch:
            replace_func(expression.this)

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: