sqlglot.optimizer.canonicalize
1from __future__ import annotations 2 3import itertools 4import typing as t 5 6from sqlglot import exp 7from sqlglot.dialects.dialect import Dialect, DialectType 8from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime 9from sqlglot.optimizer.annotate_types import TypeAnnotator 10 11 12def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression: 13 """Converts a sql expression into a standard form. 14 15 This method relies on annotate_types because many of the 16 conversions rely on type inference. 17 18 Args: 19 expression: The expression to canonicalize. 20 """ 21 22 dialect = Dialect.get_or_raise(dialect) 23 24 def _canonicalize(expression: exp.Expression) -> exp.Expression: 25 expression = add_text_to_concat(expression) 26 expression = replace_date_funcs(expression, dialect=dialect) 27 expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE) 28 expression = remove_redundant_casts(expression) 29 expression = ensure_bools(expression, _replace_int_predicate) 30 expression = remove_ascending_order(expression) 31 return expression 32 33 return exp.replace_tree(expression, _canonicalize) 34 35 36def add_text_to_concat(node: exp.Expression) -> exp.Expression: 37 if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: 38 node = exp.Concat( 39 expressions=[node.left, node.right], 40 # All known dialects, i.e. Redshift and T-SQL, that support 41 # concatenating strings with the + operator do not coalesce NULLs. 
42 coalesce=False, 43 ) 44 return node 45 46 47def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression: 48 if ( 49 isinstance(node, (exp.Date, exp.TsOrDsToDate)) 50 and not node.expressions 51 and not node.args.get("zone") 52 and node.this.is_string 53 and is_iso_date(node.this.name) 54 ): 55 return exp.cast(node.this, to=exp.DataType.Type.DATE) 56 if isinstance(node, exp.Timestamp) and not node.args.get("zone"): 57 if not node.type: 58 from sqlglot.optimizer.annotate_types import annotate_types 59 60 node = annotate_types(node, dialect=dialect) 61 return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP) 62 63 return node 64 65 66COERCIBLE_DATE_OPS = ( 67 exp.Add, 68 exp.Sub, 69 exp.EQ, 70 exp.NEQ, 71 exp.GT, 72 exp.GTE, 73 exp.LT, 74 exp.LTE, 75 exp.NullSafeEQ, 76 exp.NullSafeNEQ, 77) 78 79 80def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression: 81 if isinstance(node, COERCIBLE_DATE_OPS): 82 _coerce_date(node.left, node.right, promote_to_inferred_datetime_type) 83 elif isinstance(node, exp.Between): 84 _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type) 85 elif isinstance(node, exp.Extract) and not node.expression.is_type( 86 *exp.DataType.TEMPORAL_TYPES 87 ): 88 _replace_cast(node.expression, exp.DataType.Type.DATETIME) 89 elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): 90 _coerce_timeunit_arg(node.this, node.unit) 91 elif isinstance(node, exp.DateDiff): 92 _coerce_datediff_args(node) 93 94 return node 95 96 97def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: 98 if ( 99 isinstance(expression, exp.Cast) 100 and expression.this.type 101 and expression.to == expression.this.type 102 ): 103 return expression.this 104 105 if ( 106 isinstance(expression, (exp.Date, exp.TsOrDsToDate)) 107 and expression.this.type 108 and expression.this.type.this == exp.DataType.Type.DATE 109 and not expression.this.type.expressions 
110 ): 111 return expression.this 112 113 return expression 114 115 116def ensure_bools( 117 expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] 118) -> exp.Expression: 119 if isinstance(expression, exp.Connector): 120 replace_func(expression.left) 121 replace_func(expression.right) 122 elif isinstance(expression, exp.Not): 123 replace_func(expression.this) 124 # We can't replace num in CASE x WHEN num ..., because it's not the full predicate 125 elif isinstance(expression, exp.If) and not ( 126 isinstance(expression.parent, exp.Case) and expression.parent.this 127 ): 128 replace_func(expression.this) 129 elif isinstance(expression, (exp.Where, exp.Having)): 130 replace_func(expression.this) 131 132 return expression 133 134 135def remove_ascending_order(expression: exp.Expression) -> exp.Expression: 136 if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: 137 # Convert ORDER BY a ASC to ORDER BY a 138 expression.set("desc", None) 139 140 return expression 141 142 143def _coerce_date( 144 a: exp.Expression, 145 b: exp.Expression, 146 promote_to_inferred_datetime_type: bool, 147) -> None: 148 for a, b in itertools.permutations([a, b]): 149 if isinstance(b, exp.Interval): 150 a = _coerce_timeunit_arg(a, b.unit) 151 152 a_type = a.type 153 if ( 154 not a_type 155 or a_type.this not in exp.DataType.TEMPORAL_TYPES 156 or not b.type 157 or b.type.this not in exp.DataType.TEXT_TYPES 158 ): 159 continue 160 161 if promote_to_inferred_datetime_type: 162 if b.is_string: 163 date_text = b.name 164 if is_iso_date(date_text): 165 b_type = exp.DataType.Type.DATE 166 elif is_iso_datetime(date_text): 167 b_type = exp.DataType.Type.DATETIME 168 else: 169 b_type = a_type.this 170 else: 171 # If b is not a datetime string, we conservatively promote it to a DATETIME, 172 # in order to ensure there are no surprising truncations due to downcasting 173 b_type = exp.DataType.Type.DATETIME 174 175 target_type = ( 176 b_type if b_type in 
TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type 177 ) 178 else: 179 target_type = a_type 180 181 if target_type != a_type: 182 _replace_cast(a, target_type) 183 184 _replace_cast(b, target_type) 185 186 187def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: 188 if not arg.type: 189 return arg 190 191 if arg.type.this in exp.DataType.TEXT_TYPES: 192 date_text = arg.name 193 is_iso_date_ = is_iso_date(date_text) 194 195 if is_iso_date_ and is_date_unit(unit): 196 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) 197 198 # An ISO date is also an ISO datetime, but not vice versa 199 if is_iso_date_ or is_iso_datetime(date_text): 200 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 201 202 elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): 203 return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) 204 205 return arg 206 207 208def _coerce_datediff_args(node: exp.DateDiff) -> None: 209 for e in (node.this, node.expression): 210 if e.type.this not in exp.DataType.TEMPORAL_TYPES: 211 e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) 212 213 214def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None: 215 node.replace(exp.cast(node.copy(), to=to)) 216 217 218# this was originally designed for presto, there is a similar transform for tsql 219# this is different in that it only operates on int types, this is because 220# presto has a boolean type whereas tsql doesn't (people use bits) 221# with y as (select true as x) select x = 0 FROM y -- illegal presto query 222def _replace_int_predicate(expression: exp.Expression) -> None: 223 if isinstance(expression, exp.Coalesce): 224 for child in expression.iter_expressions(): 225 _replace_int_predicate(child) 226 elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: 227 expression.replace(expression.neq(0))
def
canonicalize( expression: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType] = None) -> sqlglot.expressions.Expression:
def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression:
    """Converts a sql expression into a standard form.

    This method relies on annotate_types because many of the
    conversions rely on type inference.

    Args:
        expression: The expression to canonicalize.
        dialect: The dialect (name, class or instance) whose settings drive the rewrites;
            it is resolved with `Dialect.get_or_raise`. Its
            `PROMOTE_TO_INFERRED_DATETIME_TYPE` flag is passed to `coerce_type`, and the
            dialect itself is passed to `replace_date_funcs`.

    Returns:
        The canonicalized expression, produced by `exp.replace_tree`.
    """

    dialect = Dialect.get_or_raise(dialect)

    def _canonicalize(expression: exp.Expression) -> exp.Expression:
        # Each step receives the node returned by the previous one.
        expression = add_text_to_concat(expression)
        expression = replace_date_funcs(expression, dialect=dialect)
        expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE)
        expression = remove_redundant_casts(expression)
        expression = ensure_bools(expression, _replace_int_predicate)
        expression = remove_ascending_order(expression)
        return expression

    return exp.replace_tree(expression, _canonicalize)
Converts a sql expression into a standard form.
This method relies on annotate_types because many of the conversions rely on type inference.
Arguments:
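- expression: The expression to canonicalize.
- dialect: The dialect (name, class or instance), resolved via `Dialect.get_or_raise`, whose settings drive type annotation and the promotion of date/time strings to inferred datetime types.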
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
    """Turns a string-typed `+` into an explicit CONCAT.

    An `exp.Add` whose inferred type is one of the text types is rebuilt as
    `exp.Concat(expressions=[left, right], coalesce=False)`; every other node is
    handed back untouched.
    """
    is_string_addition = (
        isinstance(node, exp.Add)
        and node.type
        and node.type.this in exp.DataType.TEXT_TYPES
    )
    if not is_string_addition:
        return node

    # The known dialects that allow `+` on strings (Redshift, T-SQL) do not
    # coalesce NULL operands, hence coalesce=False.
    return exp.Concat(expressions=[node.left, node.right], coalesce=False)
def
replace_date_funcs( node: sqlglot.expressions.Expression, dialect: Union[str, sqlglot.dialects.Dialect, Type[sqlglot.dialects.Dialect], NoneType]) -> sqlglot.expressions.Expression:
def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression:
    """Replaces date/timestamp constructor calls with equivalent casts.

    - DATE(<ISO date string>) / TS_OR_DS_TO_DATE(<ISO date string>) with no extra
      arguments and no zone becomes CAST(<string> AS DATE).
    - TIMESTAMP(x) with no zone becomes CAST(x AS <inferred type>), falling back to
      TIMESTAMP when no type can be inferred.

    Calls without a first argument are returned unchanged.

    Args:
        node: The node to inspect.
        dialect: Forwarded to `annotate_types` when the TIMESTAMP node is unannotated.
    """
    if (
        isinstance(node, (exp.Date, exp.TsOrDsToDate))
        and not node.expressions
        and not node.args.get("zone")
        # `this` is optional for exp.Date, so guard before dereferencing it.
        and node.this is not None
        and node.this.is_string
        and is_iso_date(node.this.name)
    ):
        return exp.cast(node.this, to=exp.DataType.Type.DATE)

    if (
        isinstance(node, exp.Timestamp)
        and not node.args.get("zone")
        # `this` is optional for exp.Timestamp; there is nothing to cast without it.
        and node.this is not None
    ):
        if not node.type:
            from sqlglot.optimizer.annotate_types import annotate_types

            # Infer the type first so the cast can target it.
            node = annotate_types(node, dialect=dialect)
        return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP)

    return node
COERCIBLE_DATE_OPS =
(<class 'sqlglot.expressions.Add'>, <class 'sqlglot.expressions.Sub'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.NullSafeEQ'>, <class 'sqlglot.expressions.NullSafeNEQ'>)
def
coerce_type( node: sqlglot.expressions.Expression, promote_to_inferred_datetime_type: bool) -> sqlglot.expressions.Expression:
def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression:
    """Inserts casts so temporal operands are combined with compatibly-typed values.

    - Binary ops in COERCIBLE_DATE_OPS: both operands go through `_coerce_date`.
    - BETWEEN: the tested value is coerced against both the low and the high bound.
    - EXTRACT from a non-temporal value: the value is cast to DATETIME.
    - DATE_ADD / DATE_SUB / DATE_TRUNC: the date argument is coerced based on the unit.
    - DATEDIFF: handled by `_coerce_datediff_args`.

    Children of `node` may be replaced in place; `node` itself is returned.

    Args:
        node: The node to inspect.
        promote_to_inferred_datetime_type: Forwarded to `_coerce_date`.
    """
    if isinstance(node, COERCIBLE_DATE_OPS):
        _coerce_date(node.left, node.right, promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Between):
        _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type)
        # Both bounds need the same treatment. `node.this` is re-read because the
        # previous call may have replaced it with a cast.
        _coerce_date(node.this, node.args["high"], promote_to_inferred_datetime_type)
    elif isinstance(node, exp.Extract) and not node.expression.is_type(
        *exp.DataType.TEMPORAL_TYPES
    ):
        _replace_cast(node.expression, exp.DataType.Type.DATETIME)
    elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
        _coerce_timeunit_arg(node.this, node.unit)
    elif isinstance(node, exp.DateDiff):
        _coerce_datediff_args(node)

    return node
def
remove_redundant_casts( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
    """Strips conversions that leave the operand's type unchanged.

    - CAST(x AS T) where x is already typed as T yields x.
    - DATE(x) / TS_OR_DS_TO_DATE(x) where x is typed as a plain DATE (no type
      parameters) yields x.

    Anything else is returned as-is.
    """
    if isinstance(expression, exp.Cast):
        operand_type = expression.this.type
        if operand_type and expression.to == operand_type:
            return expression.this
        return expression

    if isinstance(expression, (exp.Date, exp.TsOrDsToDate)):
        operand_type = expression.this.type
        is_plain_date = (
            operand_type
            and operand_type.this == exp.DataType.Type.DATE
            and not operand_type.expressions
        )
        if is_plain_date:
            return expression.this

    return expression
def
ensure_bools( expression: sqlglot.expressions.Expression, replace_func: Callable[[sqlglot.expressions.Expression], NoneType]) -> sqlglot.expressions.Expression:
def ensure_bools(
    expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
) -> exp.Expression:
    """Applies `replace_func` to every child of `expression` that acts as a predicate.

    Predicate positions are both operands of a connector (AND/OR), the operand of NOT,
    the condition of an IF (unless it belongs to a simple CASE x WHEN ...), and the
    WHERE/HAVING condition. `replace_func` returns nothing, so any rewrite it performs
    must happen in place. `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector):
        predicates: t.Tuple[exp.Expression, ...] = (expression.left, expression.right)
    elif isinstance(expression, exp.Not):
        predicates = (expression.this,)
    elif isinstance(expression, exp.If):
        case_node = expression.parent
        # In CASE x WHEN num ..., the WHEN value is compared against x and is not a
        # full predicate by itself, so it must be left alone.
        in_simple_case = isinstance(case_node, exp.Case) and case_node.this
        predicates = () if in_simple_case else (expression.this,)
    elif isinstance(expression, (exp.Where, exp.Having)):
        predicates = (expression.this,)
    else:
        predicates = ()

    for predicate in predicates:
        replace_func(predicate)

    return expression
def
remove_ascending_order( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression: