diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 6964965a6431a..ced8e450df28d 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -100,6 +100,7 @@ impl ScalarUDFImpl for DummyUDF { #[test] fn test_update_matching_exprs() -> Result<()> { + let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())); let exprs: Vec> = vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 3)), @@ -114,7 +115,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + udf.clone(), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -179,7 +180,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + udf, vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -231,8 +232,19 @@ fn test_update_matching_exprs() -> Result<()> { Ok(()) } +#[test] +fn test_scalar_udf_pointer_equality() { + let udf_a = ScalarUDF::new_from_impl(DummyUDF::new()); + let udf_b = ScalarUDF::new_from_impl(DummyUDF::new()); + assert_ne!(udf_a, udf_b); + + let udf_a_clone = udf_a.clone(); + assert_eq!(udf_a, udf_a_clone); +} + #[test] fn test_update_projected_exprs() -> Result<()> { + let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())); let exprs: Vec> = vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 3)), @@ -247,7 +259,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + udf.clone(), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -312,7 +324,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), + udf, vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 81865a836d2cf..6d2a62333b5a5 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -25,12 +25,14 @@ use crate::{ColumnarValue, Documentation, Expr, Signature}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::Arc; - +use std::{ + any::Any, + cmp::Ordering, + fmt::Debug, + hash::{DefaultHasher, Hash, Hasher}, + ptr, + sync::Arc, +}; /// Logical representation of a Scalar User Defined Function. /// /// A scalar function produces a single row output for each row of input. This @@ -60,7 +62,11 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.inner.equals(other.inner.as_ref()) + if Arc::ptr_eq(&self.inner, &other.inner) { + true + } else { + self.inner.equals(other.inner.as_ref()) + } } } @@ -696,16 +702,98 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Return true if this scalar UDF is equal to the other. /// - /// Allows customizing the equality of scalar UDFs. - /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// This method allows customizing the equality of scalar UDFs. It must adhere to the rules of equivalence: + /// + /// - Reflexive: `a.equals(a)` must return true. + /// - Symmetric: `a.equals(b)` implies `b.equals(a)`. + /// - Transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// # Default Behavior + /// By default, this method compares the type IDs, names, and signatures of the two UDFs. If these match, + /// the method assumes the UDFs are not equal unless their pointers are the same. This conservative approach + /// ensures that different instances of the same function type are not mistakenly considered equal. + /// + /// # Custom Implementation + /// If a UDF has internal state or additional properties that should be considered for equality, this method + /// should be overridden. For example, a UDF with parameters might compare those parameters in addition to + /// the default checks. + /// + /// # Example + /// ```rust + /// use std::any::Any; + /// use std::hash::{DefaultHasher, Hash, Hasher}; + /// use arrow::datatypes::DataType; + /// use datafusion_common::{not_impl_err, Result}; + /// use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature}; + /// + /// #[derive(Debug, PartialEq)] + /// struct MyUdf { + /// param: i32, + /// signature: Signature, + /// } /// - /// - reflexive: `a.equals(a)`; - /// - symmetric: `a.equals(b)` implies `b.equals(a)`; - /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// impl ScalarUDFImpl for MyUdf { + /// fn as_any(&self) -> &dyn Any { + /// self + /// } + /// fn name(&self) -> &str { + /// "my_udf" + /// } + /// fn signature(&self) -> &Signature { + /// &self.signature + /// } + /// fn return_type(&self, _arg_types: &[DataType]) -> Result { + /// Ok(DataType::Int32) + /// } + /// fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + /// not_impl_err!("not used") + /// } + /// fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + /// if let Some(other) = other.as_any().downcast_ref::() { + /// self == other + /// } else { + /// false + /// } + /// } + /// fn hash_value(&self) -> u64 { + /// let mut hasher = DefaultHasher::new(); + /// self.param.hash(&mut hasher); + /// self.name().hash(&mut hasher); + /// hasher.finish() + /// } + /// } + /// ``` /// - /// By default, compares [`Self::name`] and [`Self::signature`]. + /// # Notes + /// - This method must be consistent with [`Self::hash_value`]. If `equals` returns true for two UDFs, + /// their hash values must also be the same. + /// - Ensure that the implementation does not panic or cause undefined behavior for any input. fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - self.name() == other.name() && self.signature() == other.signature() + // 1. If the pointers are identical, it’s definitely the same UDF. + if ptr::eq(self.as_any(), other.as_any()) { + return true; + } + + // 2. Otherwise, check that they’re the same concrete Rust type. + let self_any = self.as_any(); + let other_any = other.as_any(); + if self_any.type_id() != other_any.type_id() { + // Different types can never be equal. + return false; + } + + // 3. Now we know they're the same struct type. In theory, since Rust moves + // values by `memcpy`-ing their bytes, we could `memcmp` them byte-for-byte: + // + // However, Rust doesn't guarantee that padding bytes are set the same way, + // so two equal structs might have different padding and compare as not equal. + // + // If your UDF type has no padding, or you make sure all padding is zeroed + // (for example, with #[repr(C)] and a safe initializer), you can use memcmp + // Otherwise, it's safer to just return false. + + // 4. Fallback: we can’t prove they’re identical, so we say “not equal.” + false } /// Returns a hash value for this scalar UDF. diff --git a/datafusion/expr/tests/udf_equals.rs b/datafusion/expr/tests/udf_equals.rs new file mode 100644 index 0000000000000..7fc2ab6dc00f8 --- /dev/null +++ b/datafusion/expr/tests/udf_equals.rs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, +}; +use std::{ + any::Any, + hash::{Hash, Hasher}, +}; +#[derive(Debug, PartialEq)] +struct ParamUdf { + param: i32, + signature: Signature, +} + +impl ParamUdf { + fn new(param: i32) -> Self { + Self { + param, + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ParamUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "param_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("not used") + } + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self == other + } else { + false + } + } + fn hash_value(&self) -> u64 { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + self.param.hash(&mut hasher); + self.signature.hash(&mut hasher); + hasher.finish() + } +} + +#[derive(Debug)] +#[allow(dead_code)] +struct SignatureUdf { + signature: Signature, +} + +impl SignatureUdf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SignatureUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "signature_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("not used") + } + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.type_id() == other.type_id() + } else { + false + } + } +} + +#[derive(Debug)] +#[allow(dead_code)] +struct DefaultParamUdf { + param: i32, + signature: Signature, +} + +impl DefaultParamUdf { + fn new(param: i32) -> Self { + Self { + param, + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for DefaultParamUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "default_param_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("not used") + } +} + +#[test] +fn different_instances_not_equal() { + let udf1 = ScalarUDF::from(ParamUdf::new(1)); + let udf2 = ScalarUDF::from(ParamUdf::new(2)); + assert_ne!(udf1, udf2); +} + +#[test] +fn different_types_not_equal() { + let udf1 = ScalarUDF::from(ParamUdf::new(1)); + let udf2 = ScalarUDF::from(SignatureUdf::new()); + assert_ne!(udf1, udf2); +} + +#[test] +fn same_state_equal() { + let udf1 = ScalarUDF::from(ParamUdf::new(1)); + let udf2 = ScalarUDF::from(ParamUdf::new(1)); + assert_eq!(udf1, udf2); +} + +#[test] +fn same_types_equal() { + let udf1 = ScalarUDF::from(SignatureUdf::new()); + let udf2 = ScalarUDF::from(SignatureUdf::new()); + assert_eq!(udf1, udf2); +} + +#[test] +fn default_udfs_with_same_param_not_equal() { + let udf1 = ScalarUDF::from(DefaultParamUdf::new(1)); + let udf2 = ScalarUDF::from(DefaultParamUdf::new(1)); + assert_ne!(udf1, udf2); +} + +#[test] +fn default_udfs_with_different_param_not_equal() { + let udf1 = ScalarUDF::from(DefaultParamUdf::new(1)); + let udf2 = ScalarUDF::from(DefaultParamUdf::new(2)); + assert_ne!(udf1, udf2); +} diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index a28686e01fc39..ed02f830fc4aa 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -62,6 +62,71 @@ DataFusionError::SchemaError( [#16652]: https://github.com/apache/datafusion/issues/16652 +#### `ScalarUDFImpl::equals` Default Implementation + +The default implementation of the `equals` method in the `ScalarUDFImpl` trait has been updated. Previously, it compared only the type IDs, names, and signatures of UDFs. Now, it assumes UDFs are not equal unless their pointers are the same. + +**Impact:** + +- This change may affect any custom UDF implementations relying on the default `equals` behavior. +- If your UDFs have internal state or additional properties that should be considered for equality, you must override the `equals` method to include those comparisons. + +**Action Required:** + +- Review your UDF implementations and ensure the `equals` method is overridden where necessary. +- Update any tests or logic that depend on the previous default behavior. + +**Example:** + +```rust +# use datafusion::logical_expr::{ScalarUDFImpl, Signature, Volatility}; +# use datafusion_common::{DataFusionError, Result}; +# use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +# use arrow::datatypes::DataType; +# use std::any::Any; +# +# #[derive(Debug)] +# struct MyUdf { +# param: i32, +# } +# +# impl MyUdf { +# fn name(&self) -> &str { "my_udf" } +# } +# +impl ScalarUDFImpl for MyUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_udf" + } + + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.param == other.param && self.name() == other.name() + } else { + false + } + } +} +``` + +[#16677] https://github.com/apache/datafusion/issues/16677 + ### Metadata on Arrow Types is now represented by `FieldMetadata` Metadata from the Arrow `Field` is now stored using the `FieldMetadata`