From 99ab8eb5d8153fd5bac9f1e246c4ed70d2a214a8 Mon Sep 17 00:00:00 2001 From: LorrensP-2158466 Date: Sun, 23 Jun 2024 17:16:09 +0200 Subject: [PATCH 1/5] Do checked negative op instead of unchecked --- datafusion/common/src/scalar/mod.rs | 53 ++++++++++++++++------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3daf347ae4ff0..13c79f237758b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1177,41 +1177,48 @@ impl ScalarValue { | ScalarValue::Float64(None) => Ok(self.clone()), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), - ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))), - ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))), - ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))), - ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))), + ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), + ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))), + ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))), + ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))), ScalarValue::IntervalYearMonth(Some(v)) => { - Ok(ScalarValue::IntervalYearMonth(Some(-v))) + Ok(ScalarValue::IntervalYearMonth(Some(v.neg_checked()?))) } ScalarValue::IntervalDayTime(Some(v)) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); - let val = IntervalDayTimeType::make_value(-days, -ms); + let val = IntervalDayTimeType::make_value( + days.neg_checked()?, + ms.neg_checked()?, + ); Ok(ScalarValue::IntervalDayTime(Some(val))) } ScalarValue::IntervalMonthDayNano(Some(v)) => { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); - let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos); + let val = IntervalMonthDayNanoType::make_value( + months.neg_checked()?, + days.neg_checked()?, + nanos.neg_checked()?, + ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } - ScalarValue::Decimal128(Some(v), precision, scale) => { - Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale)) - } + ScalarValue::Decimal128(Some(v), precision, scale) => Ok( + ScalarValue::Decimal128(Some(v.neg_checked()?), *precision, *scale), + ), ScalarValue::Decimal256(Some(v), precision, scale) => Ok( - ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), + ScalarValue::Decimal256(Some(v.neg_checked()?), *precision, *scale), + ), + ScalarValue::TimestampSecond(Some(v), tz) => Ok( + ScalarValue::TimestampSecond(Some(v.neg_checked()?), tz.clone()), + ), + ScalarValue::TimestampNanosecond(Some(v), tz) => Ok( + ScalarValue::TimestampNanosecond(Some(v.neg_checked()?), tz.clone()), + ), + ScalarValue::TimestampMicrosecond(Some(v), tz) => Ok( + ScalarValue::TimestampMicrosecond(Some(v.neg_checked()?), tz.clone()), + ), + ScalarValue::TimestampMillisecond(Some(v), tz) => Ok( + ScalarValue::TimestampMillisecond(Some(v.neg_checked()?), tz.clone()), ), - ScalarValue::TimestampSecond(Some(v), tz) => { - Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) - } - ScalarValue::TimestampNanosecond(Some(v), tz) => { - Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) - } - ScalarValue::TimestampMicrosecond(Some(v), tz) => { - Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) - } - ScalarValue::TimestampMillisecond(Some(v), tz) => { - Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) - } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" ), From 8e2087918d961dede94bccaaf994141165da50f0 Mon Sep 17 00:00:00 2001 From: LorrensP-2158466 Date: Sun, 23 Jun 2024 19:34:07 +0200 Subject: [PATCH 2/5] add tests for checking if overflow error occurs --- datafusion/common/src/scalar/mod.rs | 65 +++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 13c79f237758b..c43a2da765cec 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -29,6 +29,7 @@ use std::iter::repeat; use std::str::FromStr; use std::sync::Arc; +use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, @@ -3508,6 +3509,7 @@ mod tests { use crate::assert_batches_eq; use arrow::buffer::OffsetBuffer; use arrow::compute::{is_null, kernels}; + use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; use arrow_buffer::Buffer; use arrow_schema::Fields; @@ -5501,6 +5503,69 @@ mod tests { Ok(()) } + #[test] + #[allow(arithmetic_overflow)] // we want to test them + fn test_scalar_negative_overflows() -> Result<()> { + macro_rules! test_overflow_on_value { + ($($val:expr),* $(,)?) => {$( + { + let value: ScalarValue = $val; + let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}"); + let root_err = err.find_root(); + match root_err{ + DataFusionError::ArrowError( + ArrowError::ComputeError(_), + _, + ) => {} + _ => return Err(err), + }; + } + )*}; + } + test_overflow_on_value!( + // the integers + i8::MIN.into(), + i16::MIN.into(), + i32::MIN.into(), + i64::MIN.into(), + // for decimals, only value needs to be tested + ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?, + ScalarValue::Decimal256(Some(i256::MIN), 20, 5), + // interval, check all possible values + ScalarValue::IntervalYearMonth(Some(i32::MIN)), + ScalarValue::new_interval_dt(i32::MIN, 999), + ScalarValue::new_interval_dt(1, i32::MIN), + ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456), + ScalarValue::new_interval_mdn(12, i32::MIN, 123_456), + ScalarValue::new_interval_mdn(12, 15, i64::MIN), + // tz doesn't matter when negating + ScalarValue::TimestampSecond(Some(i64::MIN), None), + ScalarValue::TimestampMillisecond(Some(i64::MIN), None), + ScalarValue::TimestampMicrosecond(Some(i64::MIN), None), + ScalarValue::TimestampNanosecond(Some(i64::MIN), None), + ); + + let float_cases = [ + ( + ScalarValue::Float16(Some(f16::MIN)), + ScalarValue::Float16(Some(f16::MAX)), + ), + ( + ScalarValue::Float16(Some(f16::MAX)), + ScalarValue::Float16(Some(f16::MIN)), + ), + (f32::MIN.into(), f32::MAX.into()), + (f32::MAX.into(), f32::MIN.into()), + (f64::MIN.into(), f64::MAX.into()), + (f64::MAX.into(), f64::MIN.into()), + ]; + // skip float 16 because they aren't supported + for (test, expected) in float_cases.into_iter().skip(2) { + assert_eq!(test.arithmetic_negate()?, expected); + } + Ok(()) + } + macro_rules! expect_operation_error { ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => { #[test] From 61587e4536e9bcad844752adc0bb87957dada96e Mon Sep 17 00:00:00 2001 From: LorrensP-2158466 Date: Sun, 23 Jun 2024 19:50:13 +0200 Subject: [PATCH 3/5] add context to negating complexer ScalarValues --- datafusion/common/src/scalar/mod.rs | 111 ++++++++++++++++++++++------ 1 file changed, 87 insertions(+), 24 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index c43a2da765cec..f8860758b618b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1169,6 +1169,13 @@ impl ScalarValue { /// Calculate arithmetic negation for a scalar value pub fn arithmetic_negate(&self) -> Result { + fn neg_checked_with_ctx( + v: T, + ctx: impl Into, + ) -> Result { + v.neg_checked() + .map_err(|e| arrow_datafusion_err!(e).context(ctx)) + } match self { ScalarValue::Int8(None) | ScalarValue::Int16(None) @@ -1183,43 +1190,99 @@ impl ScalarValue { ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))), ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))), ScalarValue::IntervalYearMonth(Some(v)) => { - Ok(ScalarValue::IntervalYearMonth(Some(v.neg_checked()?))) + Ok(ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx( + *v, + format!("In negation of IntervalYearMonth({v})"), + )?))) } ScalarValue::IntervalDayTime(Some(v)) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); let val = IntervalDayTimeType::make_value( - days.neg_checked()?, - ms.neg_checked()?, + neg_checked_with_ctx( + days, + format!("In negation of days {days} in IntervalDayTime"), + )?, + neg_checked_with_ctx( + ms, + format!("In negation of milliseconds {ms} in IntervalDayTime"), + )?, ); Ok(ScalarValue::IntervalDayTime(Some(val))) } ScalarValue::IntervalMonthDayNano(Some(v)) => { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); let val = IntervalMonthDayNanoType::make_value( - months.neg_checked()?, - days.neg_checked()?, - nanos.neg_checked()?, + neg_checked_with_ctx( + months, + format!("In negation of months {months} of IntervalMonthDayNano"), + )?, + neg_checked_with_ctx( + days, + format!("In negation of days {days} of IntervalMonthDayNano"), + )?, + neg_checked_with_ctx( + nanos, + format!("In negation of nanos {nanos} of IntervalMonthDayNano"), + )?, ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } - ScalarValue::Decimal128(Some(v), precision, scale) => Ok( - ScalarValue::Decimal128(Some(v.neg_checked()?), *precision, *scale), - ), - ScalarValue::Decimal256(Some(v), precision, scale) => Ok( - ScalarValue::Decimal256(Some(v.neg_checked()?), *precision, *scale), - ), - ScalarValue::TimestampSecond(Some(v), tz) => Ok( - ScalarValue::TimestampSecond(Some(v.neg_checked()?), tz.clone()), - ), - ScalarValue::TimestampNanosecond(Some(v), tz) => Ok( - ScalarValue::TimestampNanosecond(Some(v.neg_checked()?), tz.clone()), - ), - ScalarValue::TimestampMicrosecond(Some(v), tz) => Ok( - ScalarValue::TimestampMicrosecond(Some(v.neg_checked()?), tz.clone()), - ), - ScalarValue::TimestampMillisecond(Some(v), tz) => Ok( - ScalarValue::TimestampMillisecond(Some(v.neg_checked()?), tz.clone()), - ), + ScalarValue::Decimal128(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal128( + Some(neg_checked_with_ctx( + *v, + format!("In negation of Decimal128({v}, {precision}, {scale})"), + )?), + *precision, + *scale, + )) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal256( + Some(neg_checked_with_ctx( + *v, + format!("In negation of Decimal256({v}, {precision}, {scale})"), + )?), + *precision, + *scale, + )) + } + ScalarValue::TimestampSecond(Some(v), tz) => { + Ok(ScalarValue::TimestampSecond( + Some(neg_checked_with_ctx( + *v, + format!("In negation of TimestampSecond({v})"), + )?), + tz.clone(), + )) + } + ScalarValue::TimestampNanosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampNanosecond( + Some(neg_checked_with_ctx( + *v, + format!("In negation of TimestampNanoSecond({v})"), + )?), + tz.clone(), + )) + } + ScalarValue::TimestampMicrosecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMicrosecond( + Some(neg_checked_with_ctx( + *v, + format!("In negation of TimestampMicroSecond({v})"), + )?), + tz.clone(), + )) + } + ScalarValue::TimestampMillisecond(Some(v), tz) => { + Ok(ScalarValue::TimestampMillisecond( + Some(neg_checked_with_ctx( + *v, + format!("In negation of TimestampMilliSecond({v})"), + )?), + tz.clone(), + )) + } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" ), From 220bf2b98b2948b829c1c236d8f810098ea7440c Mon Sep 17 00:00:00 2001 From: LorrensP-2158466 Date: Mon, 24 Jun 2024 14:49:36 +0200 Subject: [PATCH 4/5] put format! call to create error message in closure --- datafusion/common/src/scalar/mod.rs | 92 +++++++++++++---------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index f8860758b618b..8cbb8c585272e 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1171,10 +1171,10 @@ impl ScalarValue { pub fn arithmetic_negate(&self) -> Result { fn neg_checked_with_ctx( v: T, - ctx: impl Into, + ctx: impl Fn() -> String, ) -> Result { v.neg_checked() - .map_err(|e| arrow_datafusion_err!(e).context(ctx)) + .map_err(|e| arrow_datafusion_err!(e).context(ctx())) } match self { ScalarValue::Int8(None) @@ -1189,97 +1189,85 @@ impl ScalarValue { ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))), ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))), ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))), - ScalarValue::IntervalYearMonth(Some(v)) => { - Ok(ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx( - *v, - format!("In negation of IntervalYearMonth({v})"), - )?))) - } + ScalarValue::IntervalYearMonth(Some(v)) => Ok( + ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || { + format!("In negation of IntervalYearMonth({v})") + })?)), + ), ScalarValue::IntervalDayTime(Some(v)) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); let val = IntervalDayTimeType::make_value( - neg_checked_with_ctx( - days, - format!("In negation of days {days} in IntervalDayTime"), - )?, - neg_checked_with_ctx( - ms, - format!("In negation of milliseconds {ms} in IntervalDayTime"), - )?, + neg_checked_with_ctx(days, || { + format!("In negation of days {days} in IntervalDayTime") + })?, + neg_checked_with_ctx(ms, || { + format!("In negation of milliseconds {ms} in IntervalDayTime") + })?, ); Ok(ScalarValue::IntervalDayTime(Some(val))) } ScalarValue::IntervalMonthDayNano(Some(v)) => { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); let val = IntervalMonthDayNanoType::make_value( - neg_checked_with_ctx( - months, - format!("In negation of months {months} of IntervalMonthDayNano"), - )?, - neg_checked_with_ctx( - days, - format!("In negation of days {days} of IntervalMonthDayNano"), - )?, - neg_checked_with_ctx( - nanos, - format!("In negation of nanos {nanos} of IntervalMonthDayNano"), - )?, + neg_checked_with_ctx(months, || { + format!("In negation of months {months} of IntervalMonthDayNano") + })?, + neg_checked_with_ctx(days, || { + format!("In negation of days {days} of IntervalMonthDayNano") + })?, + neg_checked_with_ctx(nanos, || { + format!("In negation of nanos {nanos} of IntervalMonthDayNano") + })?, ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } ScalarValue::Decimal128(Some(v), precision, scale) => { Ok(ScalarValue::Decimal128( - Some(neg_checked_with_ctx( - *v, - format!("In negation of Decimal128({v}, {precision}, {scale})"), - )?), + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal128({v}, {precision}, {scale})") + })?), *precision, *scale, )) } ScalarValue::Decimal256(Some(v), precision, scale) => { Ok(ScalarValue::Decimal256( - Some(neg_checked_with_ctx( - *v, - format!("In negation of Decimal256({v}, {precision}, {scale})"), - )?), + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal256({v}, {precision}, {scale})") + })?), *precision, *scale, )) } ScalarValue::TimestampSecond(Some(v), tz) => { Ok(ScalarValue::TimestampSecond( - Some(neg_checked_with_ctx( - *v, - format!("In negation of TimestampSecond({v})"), - )?), + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampSecond({v})") + })?), tz.clone(), )) } ScalarValue::TimestampNanosecond(Some(v), tz) => { Ok(ScalarValue::TimestampNanosecond( - Some(neg_checked_with_ctx( - *v, - format!("In negation of TimestampNanoSecond({v})"), - )?), + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampNanoSecond({v})") + })?), tz.clone(), )) } ScalarValue::TimestampMicrosecond(Some(v), tz) => { Ok(ScalarValue::TimestampMicrosecond( - Some(neg_checked_with_ctx( - *v, - format!("In negation of TimestampMicroSecond({v})"), - )?), + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampMicroSecond({v})") + })?), tz.clone(), )) } ScalarValue::TimestampMillisecond(Some(v), tz) => { Ok(ScalarValue::TimestampMillisecond( - Some(neg_checked_with_ctx( - *v, - format!("In negation of TimestampMilliSecond({v})"), - )?), + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampMilliSecond({v})") + })?), tz.clone(), )) } From 98174bc367582b6c56a04a4bf86623bad1b3d129 Mon Sep 17 00:00:00 2001 From: LorrensP-2158466 Date: Mon, 24 Jun 2024 14:58:33 +0200 Subject: [PATCH 5/5] seperate test case for f16 that should panic with not implemented --- datafusion/common/src/scalar/mod.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8cbb8c585272e..dfd2eeefe9edf 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -5617,6 +5617,26 @@ mod tests { Ok(()) } + #[test] + #[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")] + fn f16_test_overflow() { + // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case + let cases = [ + ( + ScalarValue::Float16(Some(f16::MIN)), + ScalarValue::Float16(Some(f16::MAX)), + ), + ( + ScalarValue::Float16(Some(f16::MAX)), + ScalarValue::Float16(Some(f16::MIN)), + ), + ]; + + for (test, expected) in cases { + assert_eq!(test.arithmetic_negate().unwrap(), expected); + } + } + macro_rules! expect_operation_error { ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => { #[test]