Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use std::mem::{size_of, size_of_val};
use std::str::FromStr;
use std::sync::Arc;

use crate::arrow_datafusion_err;
use crate::cast::{
as_decimal128_array, as_decimal256_array, as_dictionary_array,
as_fixed_size_binary_array, as_fixed_size_list_array,
Expand All @@ -41,6 +40,7 @@ use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_
use crate::format::DEFAULT_CAST_OPTIONS;
use crate::hash_utils::create_hashes;
use crate::utils::SingleRowListArrayBuilder;
use crate::{_internal_datafusion_err, arrow_datafusion_err};
use arrow::array::{
types::{IntervalDayTime, IntervalMonthDayNano},
*,
Expand Down Expand Up @@ -1849,10 +1849,6 @@ impl ScalarValue {
/// Returns an error if the iterator is empty or if the
/// [`ScalarValue`]s are not all the same type
///
/// # Panics
///
/// Panics if `self` is a dictionary with invalid key type
///
/// # Example
/// ```
/// use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -3343,6 +3339,16 @@ impl ScalarValue {
arr1 == &right
}

/// Compare `self` with `other` and return an `Ordering`.
///
/// This is the same as [`PartialOrd`] except that it returns
/// `Err` if the values cannot be compared, e.g., they have incompatible data types.
pub fn try_cmp(&self, other: &Self) -> Result<Ordering> {
self.partial_cmp(other).ok_or_else(|| {
_internal_datafusion_err!("Uncomparable values: {self:?}, {other:?}")
})
}

/// Estimate size if bytes including `Self`. For values with internal containers such as `String`
/// includes the allocated size (`capacity`) rather than the current length (`len`)
pub fn size(&self) -> usize {
Expand Down Expand Up @@ -4761,6 +4767,32 @@ mod tests {
Ok(())
}

#[test]
fn test_try_cmp() {
assert_eq!(
ScalarValue::try_cmp(
&ScalarValue::Int32(Some(1)),
&ScalarValue::Int32(Some(2))
)
.unwrap(),
Ordering::Less
);
assert_eq!(
ScalarValue::try_cmp(&ScalarValue::Int32(None), &ScalarValue::Int32(Some(2)))
.unwrap(),
Ordering::Less
);
assert_starts_with(
ScalarValue::try_cmp(
&ScalarValue::Int32(Some(1)),
&ScalarValue::Int64(Some(2)),
)
.unwrap_err()
.message(),
"Uncomparable values: Int32(1), Int64(2)",
);
}

#[test]
fn scalar_decimal_test() -> Result<()> {
let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1);
Expand Down Expand Up @@ -7669,4 +7701,15 @@ mod tests {
];
assert!(scalars.iter().all(|s| s.is_null()));
}

// `err.to_string()` depends on backtrace being present (may have backtrace appended)
// `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." stripped)
fn assert_starts_with(actual: impl AsRef<str>, expected_prefix: impl AsRef<str>) {
let actual = actual.as_ref();
let expected_prefix = expected_prefix.as_ref();
assert!(
actual.starts_with(expected_prefix),
"Expected '{actual}' to start with '{expected_prefix}'"
);
}
}
15 changes: 7 additions & 8 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod memory;
pub mod proxy;
pub mod string_utils;

use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
use crate::error::{_exec_datafusion_err, _internal_err};
use crate::{DataFusionError, Result, ScalarValue};
use arrow::array::{
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
Expand Down Expand Up @@ -120,14 +120,13 @@ pub fn compare_rows(
let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) {
(true, false, false) | (false, true, true) => Ordering::Greater,
(true, false, true) | (false, true, false) => Ordering::Less,
(false, false, _) => if sort_options.descending {
rhs.partial_cmp(lhs)
} else {
lhs.partial_cmp(rhs)
(false, false, _) => {
if sort_options.descending {
rhs.try_cmp(lhs)?
} else {
lhs.try_cmp(rhs)?
}
}
.ok_or_else(|| {
_internal_datafusion_err!("Column array shouldn't be empty")
})?,
(true, true, _) => continue,
};
if result != Ordering::Equal {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {

// workaround lack of Ord of ScalarValue
let cmp = |a: &ScalarValue, b: &ScalarValue| {
a.partial_cmp(b).expect("Can compare ScalarValues")
a.try_cmp(b).expect("Can compare ScalarValues")
};
scalars.sort_by(cmp);
assert_eq!(
Expand Down
7 changes: 3 additions & 4 deletions datafusion/functions-aggregate-common/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,9 @@ fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarV
extreme = current;
continue;
}
if let Some(cmp) = extreme.partial_cmp(&current) {
if cmp == ordering {
extreme = current;
}
let cmp = extreme.try_cmp(&current)?;
if cmp == ordering {
extreme = current;
}
}

Expand Down
10 changes: 8 additions & 2 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
}

if let Some(opts) = self.sort_options {
let mut delayed_cmp_err = Ok(());
values.sort_by(|a, b| {
if a.is_null() {
return match opts.nulls_first {
Expand All @@ -475,10 +476,15 @@ impl Accumulator for DistinctArrayAggAccumulator {
};
}
match opts.descending {
true => b.partial_cmp(a).unwrap_or(Ordering::Equal),
false => a.partial_cmp(b).unwrap_or(Ordering::Equal),
true => b.try_cmp(a),
false => a.try_cmp(b),
}
.unwrap_or_else(|err| {
delayed_cmp_err = Err(err);
Ordering::Equal
})
});
delayed_cmp_err?;
};

let arr = ScalarValue::new_list(&values, &self.datatype, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,16 @@ fn find_most_restrictive_predicate(

if let Some(scalar) = scalar_value {
if let Some(current_best) = best_value {
if let Some(comparison) = scalar.partial_cmp(current_best) {
let is_better = if find_greater {
comparison == std::cmp::Ordering::Greater
} else {
comparison == std::cmp::Ordering::Less
};

if is_better {
best_value = Some(scalar);
most_restrictive_idx = idx;
}
let comparison = scalar.try_cmp(current_best)?;
let is_better = if find_greater {
comparison == std::cmp::Ordering::Greater
} else {
comparison == std::cmp::Ordering::Less
};

if is_better {
best_value = Some(scalar);
most_restrictive_idx = idx;
}
} else {
best_value = Some(scalar);
Expand Down