diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index f774f46b424d5..6716cb5f9be89 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -32,7 +32,6 @@ use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; -use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, @@ -41,6 +40,7 @@ use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_ use crate::format::DEFAULT_CAST_OPTIONS; use crate::hash_utils::create_hashes; use crate::utils::SingleRowListArrayBuilder; +use crate::{_internal_datafusion_err, arrow_datafusion_err}; use arrow::array::{ types::{IntervalDayTime, IntervalMonthDayNano}, *, @@ -1849,10 +1849,6 @@ impl ScalarValue { /// Returns an error if the iterator is empty or if the /// [`ScalarValue`]s are not all the same type /// - /// # Panics - /// - /// Panics if `self` is a dictionary with invalid key type - /// /// # Example /// ``` /// use datafusion_common::ScalarValue; @@ -3343,6 +3339,16 @@ impl ScalarValue { arr1 == &right } + /// Compare `self` with `other` and return an `Ordering`. + /// + /// This is the same as [`PartialOrd`] except that it returns + /// `Err` if the values cannot be compared, e.g., they have incompatible data types. + pub fn try_cmp(&self, other: &Self) -> Result<Ordering> { + self.partial_cmp(other).ok_or_else(|| { + _internal_datafusion_err!("Uncomparable values: {self:?}, {other:?}") + }) + } + /// Estimate size if bytes including `Self`.
For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { @@ -4761,6 +4767,32 @@ mod tests { Ok(()) } + #[test] + fn test_try_cmp() { + assert_eq!( + ScalarValue::try_cmp( + &ScalarValue::Int32(Some(1)), + &ScalarValue::Int32(Some(2)) + ) + .unwrap(), + Ordering::Less + ); + assert_eq!( + ScalarValue::try_cmp(&ScalarValue::Int32(None), &ScalarValue::Int32(Some(2))) + .unwrap(), + Ordering::Less + ); + assert_starts_with( + ScalarValue::try_cmp( + &ScalarValue::Int32(Some(1)), + &ScalarValue::Int64(Some(2)), + ) + .unwrap_err() + .message(), + "Uncomparable values: Int32(1), Int64(2)", + ); + } + #[test] fn scalar_decimal_test() -> Result<()> { let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1); @@ -7669,4 +7701,15 @@ mod tests { ]; assert!(scalars.iter().all(|s| s.is_null())); } + + // `err.to_string()` depends on backtrace being present (may have backtrace appended) + // `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." 
stripped) + fn assert_starts_with(actual: impl AsRef<str>, expected_prefix: impl AsRef<str>) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{actual}' to start with '{expected_prefix}'" + ); + } } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index c09859c46e15f..ad2bab879a26f 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,7 +22,7 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; +use crate::error::{_exec_datafusion_err, _internal_err}; use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, @@ -120,14 +120,13 @@ pub fn compare_rows( let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) { (true, false, false) | (false, true, true) => Ordering::Greater, (true, false, true) | (false, true, false) => Ordering::Less, - (false, false, _) => if sort_options.descending { - rhs.partial_cmp(lhs) - } else { - lhs.partial_cmp(rhs) + (false, false, _) => { + if sort_options.descending { + rhs.try_cmp(lhs)? + } else { + lhs.try_cmp(rhs)?
+ } } - .ok_or_else(|| { - _internal_datafusion_err!("Column array shouldn't be empty") - })?, (true, true, _) => continue, }; if result != Ordering::Equal { diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index b705448203d70..c6ed260e714e8 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -52,7 +52,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { // workaround lack of Ord of ScalarValue let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") + a.try_cmp(b).expect("Can compare ScalarValues") }; scalars.sort_by(cmp); assert_eq!( diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index 6d9f7f4643626..aa37abd618557 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -291,10 +291,9 @@ fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result b.partial_cmp(a).unwrap_or(Ordering::Equal), - false => a.partial_cmp(b).unwrap_or(Ordering::Equal), + true => b.try_cmp(a), + false => a.try_cmp(b), } + .unwrap_or_else(|err| { + delayed_cmp_err = Err(err); + Ordering::Equal + }) }); + delayed_cmp_err?; }; let arr = ScalarValue::new_list(&values, &self.datatype, true); diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs index 32b2315e15d58..9d9e840636b8a 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs @@ -204,17 +204,16 @@ fn find_most_restrictive_predicate( if let Some(scalar) = scalar_value { if let Some(current_best) = best_value { - if let Some(comparison) = scalar.partial_cmp(current_best) { - let is_better = if find_greater { - comparison == 
std::cmp::Ordering::Greater - } else { - comparison == std::cmp::Ordering::Less - }; - - if is_better { - best_value = Some(scalar); - most_restrictive_idx = idx; - } + let comparison = scalar.try_cmp(current_best)?; + let is_better = if find_greater { + comparison == std::cmp::Ordering::Greater + } else { + comparison == std::cmp::Ordering::Less + }; + + if is_better { + best_value = Some(scalar); + most_restrictive_idx = idx; } } else { best_value = Some(scalar);