diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3de7482d714c9..8f5d1584e0a9d 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -671,7 +671,7 @@ async fn window_frame_creation_type_checking() -> Result<()> { // Error is returned from the logical plan. check_query( false, - "Internal error: Optimizer rule 'type_coercion' failed due to unexpected error: Arrow error: Cast error: Cannot cast string '1 DAY' to value of UInt32 type" + "Internal error: Optimizer rule 'type_coercion' failed due to unexpected error: Execution error: Cannot cast Utf8(\"1 DAY\") to UInt32." ).await } diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt index cbbc82c91653d..64920bb3dbb16 100644 --- a/datafusion/core/tests/sqllogictests/test_files/window.slt +++ b/datafusion/core/tests/sqllogictests/test_files/window.slt @@ -527,8 +527,7 @@ LIMIT 5 #// } # async fn window_frame_ranges_preceding_following_desc -# This query should pass. Tracked in https://github.com/apache/arrow-datafusion/issues/5346 -query error DataFusion error: Internal error: Operator \+ is not implemented +query III SELECT SUM(c4) OVER(ORDER BY c2 DESC RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING), SUM(c3) OVER(ORDER BY c2 DESC RANGE BETWEEN 10000 PRECEDING AND 10000 FOLLOWING), @@ -536,6 +535,31 @@ COUNT(*) OVER(ORDER BY c2 DESC RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM aggregate_test_100 ORDER BY c9 LIMIT 5 +---- +52276 781 56 +260620 781 63 +-28623 781 37 +260620 781 63 +260620 781 63 + +# async fn window_frame_large_range +# Range offset 10000 is too big for Int8 (i.e. the type of c3). +# In this case, we should be able to still produce correct results. +# See the issue: https://github.com/apache/arrow-datafusion/issues/5346 +# below over clause is equivalent to OVER(ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +# in terms of behaviour. +query I +SELECT +SUM(c3) OVER(ORDER BY c3 DESC RANGE BETWEEN 10000 PRECEDING AND 10000 FOLLOWING) as summation1 +FROM aggregate_test_100 +ORDER BY c9 +LIMIT 5 +---- +781 +781 +781 +781 +781 # async fn window_frame_order_by_asc_desc_large query I diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs index 502925e9b8a02..ca624a877793b 100644 --- a/datafusion/expr/src/type_coercion.rs +++ b/datafusion/expr/src/type_coercion.rs @@ -33,7 +33,7 @@ use arrow::datatypes::DataType; -/// Determine if a DataType is signed numeric or not +/// Determine whether the given data type `dt` represents unsigned numeric values. pub fn is_signed_numeric(dt: &DataType) -> bool { matches!( dt, @@ -48,12 +48,12 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { ) } -// Determine if a DataType is Null or not +/// Determine whether the given data type `dt` is `Null`. pub fn is_null(dt: &DataType) -> bool { *dt == DataType::Null } -/// Determine if a DataType is numeric or not +/// Determine whether the given data type `dt` represents numeric values. pub fn is_numeric(dt: &DataType) -> bool { is_signed_numeric(dt) || matches!( @@ -62,17 +62,18 @@ pub fn is_numeric(dt: &DataType) -> bool { ) } -/// Determine if a DataType is Timestamp or not +/// Determine whether the given data type `dt` is a `Timestamp`. pub fn is_timestamp(dt: &DataType) -> bool { matches!(dt, DataType::Timestamp(_, _)) } -/// Determine if a DataType is Date or not +/// Determine whether the given data type `dt` is a `Date`. pub fn is_date(dt: &DataType) -> bool { matches!(dt, DataType::Date32 | DataType::Date64) } -pub fn is_uft8(dt: &DataType) -> bool { +/// Determine whether the given data type `dt` is a `Utf8`. +pub fn is_utf8(dt: &DataType) -> bool { matches!(dt, DataType::Utf8) } diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 6b6cea82fcecd..b1ee55f93cfb3 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -34,7 +34,7 @@ use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_when, get_coerce_type_for_list, }; -use datafusion_expr::type_coercion::{is_date, is_numeric, is_timestamp, is_uft8}; +use datafusion_expr::type_coercion::{is_date, is_numeric, is_timestamp, is_utf8}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, @@ -411,7 +411,7 @@ impl ExprRewriter for TypeCoercionRewriter { window_frame, }) => { let window_frame = - get_coerced_window_frame(window_frame, &self.schema, &order_by)?; + coerce_window_frame(window_frame, &self.schema, &order_by)?; let expr = Expr::WindowFunction(WindowFunction::new( fun, args, @@ -426,95 +426,128 @@ impl ExprRewriter for TypeCoercionRewriter { } } -/// Casts the ScalarValue `value` to coerced type. -// When coerced type is `Interval` we use `parse_interval` since `try_from_string` not -// supports conversion from string to Interval -fn convert_to_coerced_type( - coerced_type: &DataType, - value: &ScalarValue, -) -> Result { +/// Casts the given `value` to `target_type`. Note that this function +/// only considers `Null` or `Utf8` values. +fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result { match value { - // In here we do casting either for NULL types or - // ScalarValue::Utf8(Some(val)). The other types are already casted. - // The reason is that we convert the sqlparser result - // to the Utf8 for all possible cases. Hence the types other than Utf8 - // are already casted to appropriate type. Therefore they can be returned directly. + // Coerce Utf8 values: ScalarValue::Utf8(Some(val)) => { - // we need special handling for Interval types - if let DataType::Interval(..) = coerced_type { + // When `target_type` is `Interval`, we use `parse_interval` since + // `try_from_string` does not support `String` to `Interval` coercions. + if let DataType::Interval(..) = target_type { parse_interval("millisecond", val) } else { - ScalarValue::try_from_string(val.clone(), coerced_type) + ScalarValue::try_from_string(val.clone(), target_type) } } s => { if s.is_null() { - ScalarValue::try_from(coerced_type) + // Coerce `Null` values: + ScalarValue::try_from(target_type) } else { + // Values except `Utf8`/`Null` variants already have the right type + // (casted before) since we convert `sqlparser` outputs to `Utf8` + // for all possible cases. Therefore, we return a clone here. Ok(s.clone()) } } } } +/// This function coerces `value` to `target_type` in a range-aware fashion. +/// If the coercion is successful, we return an `Ok` value with the result. +/// If the coercion fails because `target_type` is not wide enough (i.e. we +/// can not coerce to `target_type`, but we can to a wider type in the same +/// family), we return a `Null` value of this type to signal this situation. +/// Downstream code uses this signal to treat these values as *unbounded*. +fn coerce_scalar_range_aware( + target_type: &DataType, + value: &ScalarValue, +) -> Result { + coerce_scalar(target_type, value).or_else(|err| { + // If type coercion fails, check if the largest type in family works: + if let Some(largest_type) = get_widest_type_in_family(target_type) { + coerce_scalar(largest_type, value).map_or_else( + |_| { + Err(DataFusionError::Execution(format!( + "Cannot cast {:?} to {:?}", + value, target_type + ))) + }, + |_| ScalarValue::try_from(target_type), + ) + } else { + Err(err) + } + }) +} + +/// This function returns the widest type in the family of `given_type`. +/// If the given type is already the widest type, it returns `None`. +/// For example, if `given_type` is `Int8`, it returns `Int64`. +fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> { + match given_type { + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64), + DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64), + DataType::Float16 | DataType::Float32 => Some(&DataType::Float64), + _ => None, + } +} + +/// Coerces the given (window frame) `bound` to `target_type`. fn coerce_frame_bound( - coerced_type: &DataType, + target_type: &DataType, bound: &WindowFrameBound, ) -> Result { - Ok(match bound { - WindowFrameBound::Preceding(val) => { - WindowFrameBound::Preceding(convert_to_coerced_type(coerced_type, val)?) + match bound { + WindowFrameBound::Preceding(v) => { + coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Preceding) } - WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, - WindowFrameBound::Following(val) => { - WindowFrameBound::Following(convert_to_coerced_type(coerced_type, val)?) + WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow), + WindowFrameBound::Following(v) => { + coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Following) } - }) + } } -fn get_coerced_window_frame( +// Coerces the given `window_frame` to use appropriate natural types. +// For example, ROWS and GROUPS frames use `UInt64` during calculations. +fn coerce_window_frame( window_frame: WindowFrame, schema: &DFSchemaRef, expressions: &[Expr], ) -> Result { - fn get_coerced_type(column_type: &DataType) -> Result { - if is_numeric(column_type) | is_uft8(column_type) { - Ok(column_type.clone()) - } else if is_timestamp(column_type) || is_date(column_type) { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } else { - Err(DataFusionError::Internal(format!( - "Cannot run range queries on datatype: {column_type:?}" - ))) - } - } - let mut window_frame = window_frame; let current_types = expressions .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - match &mut window_frame.units { + let target_type = match window_frame.units { WindowFrameUnits::Range => { - let col_type = current_types.first().ok_or_else(|| { - DataFusionError::Internal("ORDER BY column cannot be empty".to_string()) - })?; - let coerced_type = get_coerced_type(col_type)?; - window_frame.start_bound = - coerce_frame_bound(&coerced_type, &window_frame.start_bound)?; - window_frame.end_bound = - coerce_frame_bound(&coerced_type, &window_frame.end_bound)?; - } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => { - let coerced_type = DataType::UInt64; - window_frame.start_bound = - coerce_frame_bound(&coerced_type, &window_frame.start_bound)?; - window_frame.end_bound = - coerce_frame_bound(&coerced_type, &window_frame.end_bound)?; + if let Some(col_type) = current_types.first() { + if is_numeric(col_type) || is_utf8(col_type) { + col_type + } else if is_timestamp(col_type) || is_date(col_type) { + &DataType::Interval(IntervalUnit::MonthDayNano) + } else { + return Err(DataFusionError::Internal(format!( + "Cannot run range queries on datatype: {col_type:?}" + ))); + } + } else { + return Err(DataFusionError::Internal( + "ORDER BY column cannot be empty".to_string(), + )); + } } - } + WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + }; + window_frame.start_bound = + coerce_frame_bound(target_type, &window_frame.start_bound)?; + window_frame.end_bound = coerce_frame_bound(target_type, &window_frame.end_bound)?; Ok(window_frame) } + // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result {