diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index a2525e6ad4f4d..97108294ca6e2 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -39,7 +39,7 @@ use datafusion::{ }; use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; -use std::hash::{Hash, Hasher}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; @@ -141,6 +141,9 @@ pub struct FFI_AggregateUDF { /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(udaf: &mut Self), + /// Hash value for the UDAF used for equality comparison. + pub hash_value: u64, + /// Internal data. This is only to be accessed by the provider of the udaf. /// A [`ForeignAggregateUDF`] should never attempt to access this data. pub private_data: *mut c_void, @@ -339,6 +342,10 @@ impl From> for FFI_AggregateUDF { let is_nullable = udaf.is_nullable(); let volatility = udaf.signature().volatility.into(); + let mut hasher = DefaultHasher::new(); + udaf.hash(&mut hasher); + let hash_value = hasher.finish(); + let private_data = Box::new(AggregateUDFPrivateData { udaf }); Self { @@ -357,6 +364,7 @@ impl From> for FFI_AggregateUDF { coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, + hash_value, private_data: Box::into_raw(private_data) as *mut c_void, } } @@ -386,14 +394,20 @@ unsafe impl Sync for ForeignAggregateUDF {} impl PartialEq for ForeignAggregateUDF { fn eq(&self, other: &Self) -> bool { - // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do. - std::ptr::eq(self, other) + let Self { + signature, + aliases, + udaf, + } = self; + signature == &other.signature + && aliases == &other.aliases + && udaf.hash_value == other.udaf.hash_value } } impl Eq for ForeignAggregateUDF {} impl Hash for ForeignAggregateUDF { fn hash(&self, state: &mut H) { - std::ptr::hash(self, state) + self.udaf.hash_value.hash(state); } } @@ -740,4 +754,20 @@ mod tests { test_round_trip_order_sensitivity(AggregateOrderSensitivity::SoftRequirement); test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); } + + #[test] + fn test_eq() -> Result<()> { + // Test that identical UDAFs are equal + let sum_udaf1 = create_test_foreign_udaf(Sum::new())?; + let sum_udaf2 = create_test_foreign_udaf(Sum::new())?; + assert_eq!(sum_udaf1, sum_udaf2); + + // Test that different UDAFs are not equal + let count_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::count::Count::new(), + )?; + assert_ne!(sum_udaf1, count_udaf); + + Ok(()) + } } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 390b03fe621bb..01ec44c8db4ec 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -47,7 +47,7 @@ use datafusion::{ use return_type_args::{ FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; -use std::hash::{Hash, Hasher}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; pub mod return_type_args; @@ -111,6 +111,9 @@ pub struct FFI_ScalarUDF { /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(udf: &mut Self), + /// Hash value for the UDF used for equality comparison. + pub hash_value: u64, + /// Internal data. This is only to be accessed by the provider of the udf. /// A [`ForeignScalarUDF`] should never attempt to access this data. pub private_data: *mut c_void, @@ -248,6 +251,9 @@ impl From> for FFI_ScalarUDF { let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); let volatility = udf.signature().volatility.into(); let short_circuits = udf.short_circuits(); + let mut hasher = DefaultHasher::new(); + udf.hash(&mut hasher); + let hash_value = hasher.finish(); let private_data = Box::new(ScalarUDFPrivateData { udf }); @@ -262,6 +268,7 @@ impl From> for FFI_ScalarUDF { coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, + hash_value, private_data: Box::into_raw(private_data) as *mut c_void, } } @@ -300,24 +307,15 @@ impl PartialEq for ForeignScalarUDF { } = self; name == &other.name && aliases == &other.aliases - && std::ptr::eq(udf, &other.udf) && signature == &other.signature + && udf.hash_value == other.udf.hash_value } } impl Eq for ForeignScalarUDF {} impl Hash for ForeignScalarUDF { fn hash(&self, state: &mut H) { - let Self { - name, - aliases, - udf, - signature, - } = self; - name.hash(state); - aliases.hash(state); - std::ptr::hash(udf, state); - signature.hash(state); + self.udf.hash_value.hash(state); } } @@ -463,4 +461,30 @@ mod tests { Ok(()) } + + fn create_test_foreign_udf( + original_udf: impl ScalarUDFImpl + 'static, + ) -> Result { + let original_udf = Arc::new(ScalarUDF::from(original_udf)); + let local_udf: FFI_ScalarUDF = Arc::clone(&original_udf).into(); + let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; + Ok(foreign_udf.into()) + } + + #[test] + fn test_eq() -> Result<()> { + // Test that identical UDFs are equal + let abs_udf1 = + create_test_foreign_udf(datafusion::functions::math::abs::AbsFunc::new())?; + let abs_udf2 = + create_test_foreign_udf(datafusion::functions::math::abs::AbsFunc::new())?; + assert_eq!(abs_udf1, abs_udf2); + + // Test that different UDFs are not equal + let sqrt_udf = + create_test_foreign_udf(datafusion::functions::math::gcd::GcdFunc::new())?; + assert_ne!(abs_udf1, sqrt_udf); + + Ok(()) + } } diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs index d17999e274e2f..c926c53901edc 100644 --- a/datafusion/ffi/src/udwf/mod.rs +++ b/datafusion/ffi/src/udwf/mod.rs @@ -40,7 +40,7 @@ use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator}; use partition_evaluator_args::{ FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs, }; -use std::hash::{Hash, Hasher}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; mod partition_evaluator; @@ -99,6 +99,9 @@ pub struct FFI_WindowUDF { /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(udf: &mut Self), + /// Hash value for the UDWF used for equality comparison. + pub hash_value: u64, + /// Internal data. This is only to be accessed by the provider of the udf. /// A [`ForeignWindowUDF`] should never attempt to access this data. pub private_data: *mut c_void, @@ -177,12 +180,6 @@ unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) { } unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { - // let private_data = udf.private_data as *const WindowUDFPrivateData; - // let udf_data = &(*private_data); - - // let private_data = Box::new(WindowUDFPrivateData { - // udf: Arc::clone(&udf_data.udf), - // }); let private_data = Box::new(WindowUDFPrivateData { udf: Arc::clone(udwf.inner()), }); @@ -197,6 +194,7 @@ unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { field: field_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, + hash_value: udwf.hash_value, private_data: Box::into_raw(private_data) as *mut c_void, } } @@ -214,6 +212,10 @@ impl From> for FFI_WindowUDF { let volatility = udf.signature().volatility.into(); let sort_options = udf.sort_options().map(|v| (&v).into()).into(); + let mut hasher = DefaultHasher::new(); + udf.hash(&mut hasher); + let hash_value = hasher.finish(); + let private_data = Box::new(WindowUDFPrivateData { udf }); Self { @@ -226,6 +228,7 @@ impl From> for FFI_WindowUDF { field: field_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, + hash_value, private_data: Box::into_raw(private_data) as *mut c_void, } } @@ -256,14 +259,22 @@ unsafe impl Sync for ForeignWindowUDF {} impl PartialEq for ForeignWindowUDF { fn eq(&self, other: &Self) -> bool { - // FFI_WindowUDF cannot be compared, so identity equality is the best we can do. - std::ptr::eq(self, other) + let Self { + name, + aliases, + udf, + signature, + } = self; + name == &other.name + && aliases == &other.aliases + && signature == &other.signature + && udf.hash_value == other.udf.hash_value } } impl Eq for ForeignWindowUDF {} impl Hash for ForeignWindowUDF { fn hash(&self, state: &mut H) { - std::ptr::hash(self, state) + self.udf.hash_value.hash(state); } } @@ -443,4 +454,18 @@ mod tests { Ok(()) } + + #[test] + fn test_eq() -> datafusion::common::Result<()> { + // Test that identical UDWFs are equal (using hash-based comparison) + let lag_udwf1 = create_test_foreign_udwf(WindowShift::lag())?; + let lag_udwf2 = create_test_foreign_udwf(WindowShift::lag())?; + assert_eq!(lag_udwf1, lag_udwf2); + + // Test that different UDWFs are not equal + let lead_udwf = create_test_foreign_udwf(WindowShift::lead())?; + assert_ne!(lag_udwf1, lead_udwf); + + Ok(()) + } }