diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 64cf7857e2234..6879a17f34bea 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -109,7 +109,7 @@ async fn main() -> Result<()> { // expects two f64 vec![DataType::Float64, DataType::Float64], // returns f64 - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, pow, ); diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index e7aa1172a8540..e229b28490706 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2772,7 +2772,7 @@ mod tests { ctx.register_udf(create_udf( "my_fn", vec![DataType::Float64], - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, my_fn, )); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index b6068e4859df3..d7995d4663be4 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -155,7 +155,7 @@ fn test_evaluate(input_expr: Expr, expected_expr: Expr) { // Make a UDF that adds its two values together, with the specified volatility fn make_udf_add(volatility: Volatility) -> Arc { let input_types = vec![DataType::Int32, DataType::Int32]; - let return_type = Arc::new(DataType::Int32); + let return_type = DataType::Int32; let fun = Arc::new(|args: &[ColumnarValue]| { let args = ColumnarValue::values_to_arrays(args)?; diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 0f1c3b8e53c4a..013aec48d5108 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -120,7 +120,7 @@ async fn scalar_udf() -> Result<()> { ctx.register_udf(create_udf( "my_add", vec![DataType::Int32, DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, myfunc, )); @@ -237,7 +237,7 @@ async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { ctx.register_udf(create_udf( "buggy_func", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, buggy_udf, )); @@ -321,7 +321,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { ctx.register_udf(create_udf( "abs", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))), )); @@ -414,7 +414,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { ctx.register_udf(create_udf( "MY_FUNC", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, myfunc, )); @@ -459,7 +459,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let udf = create_udf( "dummy", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, myfunc, ) @@ -1149,7 +1149,7 @@ fn create_udf_context() -> SessionContext { ctx.register_udf(create_udf( "custom_sqrt", vec![DataType::Float64], - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, Arc::new(custom_sqrt), )); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8d01712b95add..5fd3177bc27b9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -390,11 +390,10 @@ pub fn unnest(expr: Expr) -> Expr { pub fn create_udf( name: &str, input_types: Vec, - return_type: Arc, + return_type: DataType, volatility: Volatility, fun: ScalarFunctionImplementation, ) -> ScalarUDF { - let return_type = Arc::unwrap_or_clone(return_type); ScalarUDF::from(SimpleScalarUDF::new( name, input_types, diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 9188480431aa5..12ddb4cb2e329 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -116,7 +116,7 @@ impl Serializeable for Expr { Ok(Arc::new(create_udf( name, vec![], - Arc::new(arrow::datatypes::DataType::Null), + arrow::datatypes::DataType::Null, Volatility::Immutable, Arc::new(|_| unimplemented!()), ))) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1ff39e9e65b74..71c8dbe6ec50c 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2172,7 +2172,7 @@ fn roundtrip_scalar_udf() { let udf = create_udf( "dummy", vec![DataType::Utf8], - Arc::new(DataType::Utf8), + DataType::Utf8, Volatility::Immutable, scalar_fn, ); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 58f6015ee3361..f4b32e662ea9c 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -871,7 +871,7 @@ fn roundtrip_scalar_udf() -> Result<()> { let udf = create_udf( "dummy", vec![DataType::Int64], - Arc::new(DataType::Int64), + DataType::Int64, Volatility::Immutable, scalar_fn.clone(), ); diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index f28098d83b970..d1b50105d053d 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -238,7 +238,7 @@ fn context_with_udf() -> SessionContext { let udf = create_udf( "dummy", vec![DataType::Utf8], - Arc::new(DataType::Utf8), + DataType::Utf8, Volatility::Immutable, scalar_fn, ); diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index ef2fa863e6b03..19016d328f4cf 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -359,7 +359,7 @@ fn create_example_udf() -> ScalarUDF { // Expects two f64 values: vec![DataType::Float64, DataType::Float64], // Returns an f64 value: - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, adder, )