Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
288313e
Add UDF equality checks and new test module for UDFs
kosiew Jul 4, 2025
37a37f1
Enhance equality check for Scalar UDFs to improve accuracy in compari…
kosiew Jul 4, 2025
b49303e
Add equality tests for Scalar UDF instances and types
kosiew Jul 4, 2025
0184ef2
Refactor equality check in ScalarUDFImpl to simplify comparison logic
kosiew Jul 4, 2025
3b1bd00
Implement equality checks for ParamUdf and AnotherParamUdf to enhance…
kosiew Jul 4, 2025
d0b1318
Refactor AnotherParamUdf to SignatureUdf for consistency in UDF imple…
kosiew Jul 4, 2025
6128779
Add DefaultParamUdf implementation and equality tests for comparison …
kosiew Jul 4, 2025
579f3c4
Enhance documentation for equals method in ScalarUDFImpl to clarify e…
kosiew Jul 4, 2025
83b7d35
Add hash_value method to ParamUdf for improved equality checks
kosiew Jul 4, 2025
a7a8c02
Document breaking change in ScalarUDFImpl::equals method default impl…
kosiew Jul 4, 2025
c303071
Add missing import for ScalarUDFImpl in udf.rs
kosiew Jul 4, 2025
186f70d
Add hash_value method to ScalarUDFImpl for improved hashing functiona…
kosiew Jul 4, 2025
71ba61e
Merge branch 'main' into udf-16677
kosiew Jul 4, 2025
88afc7a
Add license header to udf_equals.rs
kosiew Jul 4, 2025
7e41e4e
fix prettier errors
kosiew Jul 4, 2025
ec5dcb3
Improve ScalarUDF equality check for pointer equality optimization
kosiew Jul 4, 2025
8dc7e11
Refactor: derive PartialEq ParamUdf equality and hashing implementati…
kosiew Jul 7, 2025
676545c
Enhance ScalarUDFImpl equality method documentation with alternative …
kosiew Jul 7, 2025
dc6264b
Minor comment adjustment of use statement
kosiew Jul 7, 2025
d5b813b
Add blankline
kosiew Jul 7, 2025
06d15b8
Merge branch 'udf-16677' of github.com:kosiew/datafusion into udf-16677
kosiew Jul 7, 2025
2ee9c8f
Refactor: derive PartialEq for MyUdf struct and simplify equality che…
kosiew Jul 7, 2025
571f333
Merge branch 'main' into udf-16677
kosiew Jul 7, 2025
f9c8408
Move udf_equals module to datafusion/expr/tests
kosiew Jul 7, 2025
73678db
Fix doc tests
kosiew Jul 7, 2025
ec7930c
Fix prettier error
kosiew Jul 7, 2025
6be196d
Merge branch 'main' into udf-16677
kosiew Jul 10, 2025
dd4aec2
resolve merge conflict
kosiew Jul 10, 2025
67db2df
fix prettier error
kosiew Jul 10, 2025
db39478
Merge branch 'main' into udf-16677
kosiew Jul 11, 2025
34a612e
Merge branch 'main' into udf-16677
kosiew Jul 12, 2025
0b92276
Merge branch 'main' into udf-16677
kosiew Jul 12, 2025
960e845
Merge branch 'main' into udf-16677
kosiew Jul 12, 2025
9a16939
Merge branch 'main' into udf-16677
kosiew Jul 14, 2025
e8f22aa
refactor: improve equality check for ScalarUDFImpl to ensure type saf…
kosiew Jul 15, 2025
1581b76
refactor: streamline use
kosiew Jul 15, 2025
9993abb
Merge branch 'main' into udf-16677
kosiew Jul 15, 2025
3e466d7
fix clippy error
kosiew Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions datafusion/core/tests/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ impl ScalarUDFImpl for DummyUDF {

#[test]
fn test_update_matching_exprs() -> Result<()> {
let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new()));
let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 3)),
Expand All @@ -114,7 +115,7 @@ fn test_update_matching_exprs() -> Result<()> {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
udf.clone(),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -179,7 +180,7 @@ fn test_update_matching_exprs() -> Result<()> {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
udf,
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -231,8 +232,19 @@ fn test_update_matching_exprs() -> Result<()> {
Ok(())
}

#[test]
fn test_scalar_udf_pointer_equality() {
let udf_a = ScalarUDF::new_from_impl(DummyUDF::new());
let udf_b = ScalarUDF::new_from_impl(DummyUDF::new());
assert_ne!(udf_a, udf_b);

let udf_a_clone = udf_a.clone();
assert_eq!(udf_a, udf_a_clone);
}

#[test]
fn test_update_projected_exprs() -> Result<()> {
let udf = Arc::new(ScalarUDF::new_from_impl(DummyUDF::new()));
let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 3)),
Expand All @@ -247,7 +259,7 @@ fn test_update_projected_exprs() -> Result<()> {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
udf.clone(),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Expand Down Expand Up @@ -312,7 +324,7 @@ fn test_update_projected_exprs() -> Result<()> {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
udf,
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b_new", 1)),
Expand Down
116 changes: 102 additions & 14 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ use crate::{ColumnarValue, Documentation, Expr, Signature};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
use datafusion_expr_common::interval_arithmetic::Interval;
use std::any::Any;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use std::{
any::Any,
cmp::Ordering,
fmt::Debug,
hash::{DefaultHasher, Hash, Hasher},
ptr,
sync::Arc,
};
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
Expand Down Expand Up @@ -60,7 +62,11 @@ pub struct ScalarUDF {

impl PartialEq for ScalarUDF {
fn eq(&self, other: &Self) -> bool {
self.inner.equals(other.inner.as_ref())
if Arc::ptr_eq(&self.inner, &other.inner) {
true
} else {
self.inner.equals(other.inner.as_ref())
}
}
}

Expand Down Expand Up @@ -696,16 +702,98 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {

/// Return true if this scalar UDF is equal to the other.
///
/// Allows customizing the equality of scalar UDFs.
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
/// This method allows customizing the equality of scalar UDFs. It must adhere to the rules of equivalence:
///
/// - Reflexive: `a.equals(a)` must return true.
/// - Symmetric: `a.equals(b)` implies `b.equals(a)`.
/// - Transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
///
/// # Default Behavior
/// By default, this method compares the type IDs, names, and signatures of the two UDFs. If these match,
/// the method assumes the UDFs are not equal unless their pointers are the same. This conservative approach
/// ensures that different instances of the same function type are not mistakenly considered equal.
///
/// # Custom Implementation
/// If a UDF has internal state or additional properties that should be considered for equality, this method
/// should be overridden. For example, a UDF with parameters might compare those parameters in addition to
/// the default checks.
///
/// # Example
/// ```rust
/// use std::any::Any;
/// use std::hash::{DefaultHasher, Hash, Hasher};
/// use arrow::datatypes::DataType;
/// use datafusion_common::{not_impl_err, Result};
/// use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
///
/// #[derive(Debug, PartialEq)]
/// struct MyUdf {
/// param: i32,
/// signature: Signature,
/// }
///
/// - reflexive: `a.equals(a)`;
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
/// impl ScalarUDFImpl for MyUdf {
/// fn as_any(&self) -> &dyn Any {
/// self
/// }
/// fn name(&self) -> &str {
/// "my_udf"
/// }
/// fn signature(&self) -> &Signature {
/// &self.signature
/// }
/// fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
/// Ok(DataType::Int32)
/// }
/// fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
/// not_impl_err!("not used")
/// }
/// fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
/// if let Some(other) = other.as_any().downcast_ref::<Self>() {
/// self == other
/// } else {
/// false
/// }
/// }
/// fn hash_value(&self) -> u64 {
/// let mut hasher = DefaultHasher::new();
/// self.param.hash(&mut hasher);
/// self.name().hash(&mut hasher);
/// hasher.finish()
/// }
/// }
/// ```
///
/// By default, compares [`Self::name`] and [`Self::signature`].
/// # Notes
/// - This method must be consistent with [`Self::hash_value`]. If `equals` returns true for two UDFs,
/// their hash values must also be the same.
/// - Ensure that the implementation does not panic or cause undefined behavior for any input.
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
self.name() == other.name() && self.signature() == other.signature()
// 1. If the pointers are identical, it’s definitely the same UDF.
if ptr::eq(self.as_any(), other.as_any()) {
return true;
}

// 2. Otherwise, check that they’re the same concrete Rust type.
let self_any = self.as_any();
let other_any = other.as_any();
if self_any.type_id() != other_any.type_id() {
// Different types can never be equal.
return false;
}

// 3. Now we know they're the same struct type. In theory, since Rust moves
// values by `memcpy`-ing their bytes, we could `memcmp` them byte-for-byte:
//
// However, Rust doesn't guarantee that padding bytes are set the same way,
// so two equal structs might have different padding and compare as not equal.
//
// If your UDF type has no padding, or you make sure all padding is zeroed
// (for example, with #[repr(C)] and a safe initializer), you can use memcmp
// Otherwise, it's safer to just return false.

// 4. Fallback: we can’t prove they’re identical, so we say “not equal.”
false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we do bit-by-bit comparison of the self & other, if they are the same type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

amended comment to clarify why we are not doing
bit-by-bit comparison

        // Alternative approach: we could potentially do bit-by-bit comparison if both objects
        // are the same concrete type, but this requires:
        // 1. Both objects to have identical TypeId
        // 2. Careful handling of potential padding bytes in structs
        // 3. The concrete type to be safely comparable via memcmp
        // For now, we use the conservative approach of returning false

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// 1. Both objects to have identical TypeId

we know typeid via as_any(), right?

// 2. Careful handling of potential padding bytes in structs

i don't know if rust does something (zeros) with padding bytes

// 3. The concrete type to be safely comparable via memcmp

rust struct is generally moveable around. i think the move semantics are generally about memcpy-ing bits to a new location, so memcmp-ing bits should be fine.

}

/// Returns a hash value for this scalar UDF.
Expand Down
186 changes: 186 additions & 0 deletions datafusion/expr/tests/udf_equals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use std::{
any::Any,
hash::{Hash, Hasher},
};
#[derive(Debug, PartialEq)]
struct ParamUdf {
param: i32,
signature: Signature,
}

impl ParamUdf {
fn new(param: i32) -> Self {
Self {
param,
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for ParamUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"param_udf"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
not_impl_err!("not used")
}
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<ParamUdf>() {
self == other
} else {
false
}
}
fn hash_value(&self) -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
self.param.hash(&mut hasher);
self.signature.hash(&mut hasher);
hasher.finish()
}
}

#[derive(Debug)]
#[allow(dead_code)]
struct SignatureUdf {
signature: Signature,
}

impl SignatureUdf {
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for SignatureUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"signature_udf"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
not_impl_err!("not used")
}
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<SignatureUdf>() {
self.type_id() == other.type_id()
} else {
false
}
}
}

#[derive(Debug)]
#[allow(dead_code)]
struct DefaultParamUdf {
param: i32,
signature: Signature,
}

impl DefaultParamUdf {
fn new(param: i32) -> Self {
Self {
param,
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for DefaultParamUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"default_param_udf"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
not_impl_err!("not used")
}
}

#[test]
fn different_instances_not_equal() {
let udf1 = ScalarUDF::from(ParamUdf::new(1));
let udf2 = ScalarUDF::from(ParamUdf::new(2));
assert_ne!(udf1, udf2);
}

#[test]
fn different_types_not_equal() {
let udf1 = ScalarUDF::from(ParamUdf::new(1));
let udf2 = ScalarUDF::from(SignatureUdf::new());
assert_ne!(udf1, udf2);
}

#[test]
fn same_state_equal() {
let udf1 = ScalarUDF::from(ParamUdf::new(1));
let udf2 = ScalarUDF::from(ParamUdf::new(1));
assert_eq!(udf1, udf2);
}

#[test]
fn same_types_equal() {
let udf1 = ScalarUDF::from(SignatureUdf::new());
let udf2 = ScalarUDF::from(SignatureUdf::new());
assert_eq!(udf1, udf2);
}

#[test]
fn default_udfs_with_same_param_not_equal() {
let udf1 = ScalarUDF::from(DefaultParamUdf::new(1));
let udf2 = ScalarUDF::from(DefaultParamUdf::new(1));
assert_ne!(udf1, udf2);
}

#[test]
fn default_udfs_with_different_param_not_equal() {
let udf1 = ScalarUDF::from(DefaultParamUdf::new(1));
let udf2 = ScalarUDF::from(DefaultParamUdf::new(2));
assert_ne!(udf1, udf2);
}
Loading