-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Enhance ScalarUDFImpl Equality Handling with Pointer-Based Default and Customizable Logic
#16681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
288313e
37a37f1
b49303e
0184ef2
3b1bd00
d0b1318
6128779
579f3c4
83b7d35
a7a8c02
c303071
186f70d
71ba61e
88afc7a
7e41e4e
ec5dcb3
8dc7e11
676545c
dc6264b
d5b813b
06d15b8
2ee9c8f
571f333
f9c8408
73678db
ec7930c
6be196d
dd4aec2
67db2df
db39478
34a612e
0b92276
960e845
9a16939
e8f22aa
1581b76
9993abb
3e466d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,12 +25,14 @@ use crate::{ColumnarValue, Documentation, Expr, Signature}; | |
| use arrow::datatypes::{DataType, Field, FieldRef}; | ||
| use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; | ||
| use datafusion_expr_common::interval_arithmetic::Interval; | ||
| use std::any::Any; | ||
| use std::cmp::Ordering; | ||
| use std::fmt::Debug; | ||
| use std::hash::{DefaultHasher, Hash, Hasher}; | ||
| use std::sync::Arc; | ||
|
|
||
| use std::{ | ||
| any::Any, | ||
| cmp::Ordering, | ||
| fmt::Debug, | ||
| hash::{DefaultHasher, Hash, Hasher}, | ||
| ptr, | ||
| sync::Arc, | ||
| }; | ||
| /// Logical representation of a Scalar User Defined Function. | ||
| /// | ||
| /// A scalar function produces a single row output for each row of input. This | ||
|
|
@@ -60,7 +62,11 @@ pub struct ScalarUDF { | |
|
|
||
| impl PartialEq for ScalarUDF { | ||
| fn eq(&self, other: &Self) -> bool { | ||
| self.inner.equals(other.inner.as_ref()) | ||
| if Arc::ptr_eq(&self.inner, &other.inner) { | ||
| true | ||
| } else { | ||
| self.inner.equals(other.inner.as_ref()) | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -696,16 +702,98 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { | |
|
|
||
| /// Return true if this scalar UDF is equal to the other. | ||
| /// | ||
| /// Allows customizing the equality of scalar UDFs. | ||
| /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: | ||
| /// This method allows customizing the equality of scalar UDFs. It must adhere to the rules of equivalence: | ||
| /// | ||
| /// - Reflexive: `a.equals(a)` must return true. | ||
| /// - Symmetric: `a.equals(b)` implies `b.equals(a)`. | ||
| /// - Transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. | ||
| /// | ||
| /// # Default Behavior | ||
| /// By default, two UDFs are considered equal only when they are the very same instance (pointer identity). | ||
| /// Distinct instances, even of the same concrete Rust type, compare as not equal. This conservative approach | ||
| /// ensures that different instances of the same function type are not mistakenly considered equal. | ||
| /// | ||
| /// # Custom Implementation | ||
| /// If a UDF has internal state or additional properties that should be considered for equality, this method | ||
| /// should be overridden. For example, a UDF with parameters might compare those parameters in addition to | ||
| /// the default checks. | ||
| /// | ||
| /// # Example | ||
| /// ```rust | ||
| /// use std::any::Any; | ||
| /// use std::hash::{DefaultHasher, Hash, Hasher}; | ||
| /// use arrow::datatypes::DataType; | ||
| /// use datafusion_common::{not_impl_err, Result}; | ||
| /// use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature}; | ||
| /// | ||
| /// #[derive(Debug, PartialEq)] | ||
| /// struct MyUdf { | ||
| /// param: i32, | ||
| /// signature: Signature, | ||
| /// } | ||
| /// | ||
| /// - reflexive: `a.equals(a)`; | ||
| /// - symmetric: `a.equals(b)` implies `b.equals(a)`; | ||
| /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. | ||
| /// impl ScalarUDFImpl for MyUdf { | ||
| /// fn as_any(&self) -> &dyn Any { | ||
| /// self | ||
| /// } | ||
| /// fn name(&self) -> &str { | ||
| /// "my_udf" | ||
| /// } | ||
| /// fn signature(&self) -> &Signature { | ||
| /// &self.signature | ||
| /// } | ||
| /// fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| /// Ok(DataType::Int32) | ||
| /// } | ||
| /// fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| /// not_impl_err!("not used") | ||
| /// } | ||
| /// fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { | ||
| /// if let Some(other) = other.as_any().downcast_ref::<Self>() { | ||
| /// self == other | ||
| /// } else { | ||
| /// false | ||
| /// } | ||
| /// } | ||
| /// fn hash_value(&self) -> u64 { | ||
| /// let mut hasher = DefaultHasher::new(); | ||
| /// self.param.hash(&mut hasher); | ||
| /// self.name().hash(&mut hasher); | ||
| /// hasher.finish() | ||
| /// } | ||
| /// } | ||
| /// ``` | ||
| /// | ||
| /// By default, compares [`Self::name`] and [`Self::signature`]. | ||
| /// # Notes | ||
| /// - This method must be consistent with [`Self::hash_value`]. If `equals` returns true for two UDFs, | ||
| /// their hash values must also be the same. | ||
| /// - Ensure that the implementation does not panic or cause undefined behavior for any input. | ||
| fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { | ||
| self.name() == other.name() && self.signature() == other.signature() | ||
| // 1. If the pointers are identical, it’s definitely the same UDF. | ||
| if ptr::eq(self.as_any(), other.as_any()) { | ||
| return true; | ||
| } | ||
|
|
||
| // 2. Otherwise, check that they’re the same concrete Rust type. | ||
| let self_any = self.as_any(); | ||
| let other_any = other.as_any(); | ||
| if self_any.type_id() != other_any.type_id() { | ||
| // Different types can never be equal. | ||
| return false; | ||
| } | ||
|
|
||
| // 3. Now we know they're the same struct type. In theory, since Rust moves | ||
| // values by `memcpy`-ing their bytes, we could `memcmp` them byte-for-byte. | ||
| // | ||
| // However, Rust doesn't guarantee that padding bytes are set the same way, | ||
| // so two equal structs might have different padding and compare as not equal. | ||
| // Fields that own heap data (e.g. `String`, `Vec`) also store pointers, which | ||
| // differ between two values whose contents are equal. | ||
| // | ||
| // If your UDF type has no padding, or you make sure all padding is zeroed | ||
| // (for example, with #[repr(C)] and a safe initializer), you can use memcmp | ||
| // in your own `equals` override. Otherwise, it's safer to just return false. | ||
|
|
||
| // 4. Fallback: we can’t prove they’re identical, so we say “not equal.” | ||
| false | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could we do a bit-by-bit comparison of `self` and `other` if they are the same type?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. amended comment to clarify why we are not doing
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
we know typeid via as_any(), right?
i don't know if rust does something (zeros) with padding bytes
rust struct is generally moveable around. i think the move semantics are generally about memcpy-ing bits to a new location, so memcmp-ing bits should be fine. |
||
| } | ||
|
|
||
| /// Returns a hash value for this scalar UDF. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,186 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::{not_impl_err, Result}; | ||
| use datafusion_expr::{ | ||
| ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, | ||
| }; | ||
| use std::{ | ||
| any::Any, | ||
| hash::{Hash, Hasher}, | ||
| }; | ||
| /// Test UDF that carries an integer parameter alongside its signature. | ||
| /// | ||
| /// `PartialEq` is derived, so two values are equal when both `param` and | ||
| /// `signature` are equal. | ||
| #[derive(PartialEq, Debug)] | ||
| struct ParamUdf { | ||
| param: i32, | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl ParamUdf { | ||
| /// Builds a `ParamUdf` holding `param`, with an exact single-`Int32` | ||
| /// argument signature and `Immutable` volatility. | ||
| fn new(param: i32) -> Self { | ||
| let signature = Signature::exact(vec![DataType::Int32], Volatility::Immutable); | ||
| Self { param, signature } | ||
| } | ||
| } | ||
|
|
||
| // `ParamUdf` overrides both `equals` and `hash_value`, so its equality is | ||
| // driven by its fields rather than by the trait defaults. | ||
| impl ScalarUDFImpl for ParamUdf { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
| fn name(&self) -> &str { | ||
| "param_udf" | ||
| } | ||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
| fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| Ok(DataType::Int32) | ||
| } | ||
| // Never invoked by these tests; always reports "not implemented". | ||
| fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| not_impl_err!("not used") | ||
| } | ||
| /// Equal only to another `ParamUdf` whose fields compare equal via the | ||
| /// derived `PartialEq`; any other concrete type is unequal. | ||
| fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { | ||
| match other.as_any().downcast_ref::<Self>() { | ||
| Some(rhs) => self == rhs, | ||
| None => false, | ||
| } | ||
| } | ||
| /// Hashes `param` and `signature`, the same fields `equals` compares. | ||
| fn hash_value(&self) -> u64 { | ||
| let mut state = std::collections::hash_map::DefaultHasher::new(); | ||
| self.param.hash(&mut state); | ||
| self.signature.hash(&mut state); | ||
| state.finish() | ||
| } | ||
| } | ||
|
|
||
| /// Test UDF whose only state is its signature. | ||
| // NOTE(review): `signature` is read by the `ScalarUDFImpl::signature` impl for this type, | ||
| // so this `dead_code` allowance may be unnecessary; confirm before removing it. | ||
| #[allow(dead_code)] | ||
| #[derive(Debug)] | ||
| struct SignatureUdf { | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl SignatureUdf { | ||
| /// Builds a `SignatureUdf` with an exact single-`Int32` argument | ||
| /// signature and `Immutable` volatility. | ||
| fn new() -> Self { | ||
| let signature = Signature::exact(vec![DataType::Int32], Volatility::Immutable); | ||
| Self { signature } | ||
| } | ||
| } | ||
|
|
||
| // `SignatureUdf` overrides `equals` only; `hash_value` is left at the trait default. | ||
| impl ScalarUDFImpl for SignatureUdf { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
| fn name(&self) -> &str { | ||
| "signature_udf" | ||
| } | ||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
| fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| Ok(DataType::Int32) | ||
| } | ||
| // Never invoked by these tests; always reports "not implemented". | ||
| fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| not_impl_err!("not used") | ||
| } | ||
| /// Returns true exactly when `other` is also a `SignatureUdf`. | ||
| /// | ||
| /// The concrete type alone decides equality; the `signature` field is not | ||
| /// compared. The previous body downcast `other` and then compared the | ||
| /// `TypeId`s of two `SignatureUdf` references, which is always true once | ||
| /// the downcast has succeeded, so the type test is stated directly. | ||
| fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { | ||
| other.as_any().is::<Self>() | ||
| } | ||
| } | ||
|
|
||
| /// Test UDF that stores a parameter but relies on the trait's default | ||
| /// equality rather than overriding it. | ||
| // `param` is stored but never read by the code shown here, hence the | ||
| // `dead_code` allowance. | ||
| #[allow(dead_code)] | ||
| #[derive(Debug)] | ||
| struct DefaultParamUdf { | ||
| param: i32, | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl DefaultParamUdf { | ||
| /// Builds a `DefaultParamUdf` holding `param`, with an exact single-`Int32` | ||
| /// argument signature and `Immutable` volatility. | ||
| fn new(param: i32) -> Self { | ||
| let signature = Signature::exact(vec![DataType::Int32], Volatility::Immutable); | ||
| Self { param, signature } | ||
| } | ||
| } | ||
|
|
||
| // Only the required methods are implemented: `equals` and `hash_value` are | ||
| // not overridden, so the trait defaults apply to `DefaultParamUdf`. | ||
| impl ScalarUDFImpl for DefaultParamUdf { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
| fn name(&self) -> &str { | ||
| "default_param_udf" | ||
| } | ||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
| fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| Ok(DataType::Int32) | ||
| } | ||
| // Never invoked by these tests; always reports "not implemented". | ||
| fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| not_impl_err!("not used") | ||
| } | ||
| } | ||
|
|
||
| // Same concrete type, different `param`: `ParamUdf::equals` compares fields, | ||
| // so the two wrappers must differ. | ||
| #[test] | ||
| fn different_instances_not_equal() { | ||
| let lhs = ScalarUDF::from(ParamUdf::new(1)); | ||
| let rhs = ScalarUDF::from(ParamUdf::new(2)); | ||
| assert_ne!(lhs, rhs); | ||
| } | ||
|
|
||
| // Different concrete types: the downcast in `ParamUdf::equals` fails, so the | ||
| // comparison is false. | ||
| #[test] | ||
| fn different_types_not_equal() { | ||
| let param_udf = ScalarUDF::from(ParamUdf::new(1)); | ||
| let signature_udf = ScalarUDF::from(SignatureUdf::new()); | ||
| assert_ne!(param_udf, signature_udf); | ||
| } | ||
|
|
||
| // Two separately constructed `ParamUdf`s with identical fields are equal | ||
| // through the overridden `equals`. | ||
| #[test] | ||
| fn same_state_equal() { | ||
| let lhs = ScalarUDF::from(ParamUdf::new(1)); | ||
| let rhs = ScalarUDF::from(ParamUdf::new(1)); | ||
| assert_eq!(lhs, rhs); | ||
| } | ||
|
|
||
| // `SignatureUdf::equals` accepts any other `SignatureUdf`, so two separately | ||
| // constructed instances are equal. | ||
| #[test] | ||
| fn same_types_equal() { | ||
| let lhs = ScalarUDF::from(SignatureUdf::new()); | ||
| let rhs = ScalarUDF::from(SignatureUdf::new()); | ||
| assert_eq!(lhs, rhs); | ||
| } | ||
|
|
||
| // `DefaultParamUdf` does not override `equals`; two separately constructed | ||
| // instances are expected to differ even though their fields match. | ||
| #[test] | ||
| fn default_udfs_with_same_param_not_equal() { | ||
| let lhs = ScalarUDF::from(DefaultParamUdf::new(1)); | ||
| let rhs = ScalarUDF::from(DefaultParamUdf::new(1)); | ||
| assert_ne!(lhs, rhs); | ||
| } | ||
|
|
||
| // `DefaultParamUdf` does not override `equals`; two separately constructed | ||
| // instances with different parameters are expected to differ. | ||
| #[test] | ||
| fn default_udfs_with_different_param_not_equal() { | ||
| let lhs = ScalarUDF::from(DefaultParamUdf::new(1)); | ||
| let rhs = ScalarUDF::from(DefaultParamUdf::new(2)); | ||
| assert_ne!(lhs, rhs); | ||
| } | ||
|
|
||
| // Positive counterpart for the default equality: a clone of a `ScalarUDF` | ||
| // must compare equal to the original even though `DefaultParamUdf` does not | ||
| // override `equals`. | ||
| // NOTE(review): relies on `ScalarUDF: Clone` sharing the inner `Arc`; confirm. | ||
| #[test] | ||
| fn default_udf_equal_to_its_clone() { | ||
| let original = ScalarUDF::from(DefaultParamUdf::new(1)); | ||
| let shared = original.clone(); | ||
| assert_eq!(original, shared); | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.