From 2267af35a52dfed52676713a36efe45c1934cdce Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sat, 30 Mar 2024 20:35:57 +0800 Subject: [PATCH 1/2] refactor: macro for the binary math function in datafusion-function --- datafusion/functions/src/macros.rs | 107 ++++++++++++++++++- datafusion/functions/src/math/atan2.rs | 140 ------------------------- datafusion/functions/src/math/mod.rs | 3 +- 3 files changed, 107 insertions(+), 143 deletions(-) delete mode 100644 datafusion/functions/src/math/atan2.rs diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 4907d74fe941a..53eee2c748d45 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -251,7 +251,112 @@ macro_rules! make_math_unary_udf { }; } -#[macro_export] +/// Macro to create a binary math UDF. +/// +/// A binary math function takes an argument of type Float32 or Float64, +/// applies a binary floating function to the argument, and returns a value of the same type. +/// +/// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` +/// $GNAME: a singleton instance of the UDF +/// $NAME: the name of the function +/// $BINARY_FUNC: the binary function to apply to the arguments +/// $MONOTONICITY: the monotonicity of the function +macro_rules! 
make_math_binary_udf { + ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $MONOTONICITY:expr) => { + make_udf_function!($NAME::$UDF, $GNAME, $NAME); + + mod $NAME { + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::DataType; + use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::TypeSignature::*; + use datafusion_expr::{ + ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, + }; + use std::any::Any; + use std::sync::Arc; + + #[derive(Debug)] + pub struct $UDF { + signature: Signature, + } + + impl $UDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], + Volatility::Immutable, + ), + } + } + } + + impl ScalarUDFImpl for $UDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + stringify!($NAME) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + // For other types (possible values float64/null/int), use Float64 + _ => Ok(DataType::Float64), + } + } + + fn monotonicity(&self) -> Result> { + Ok($MONOTONICITY) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::$BINARY_FUNC } + )), + + DataType::Float32 => Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::$BINARY_FUNC } + )), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + Ok(ColumnarValue::Array(arr)) + } + } + } + }; +} + macro_rules! 
make_function_inputs2 { ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); diff --git a/datafusion/functions/src/math/atan2.rs b/datafusion/functions/src/math/atan2.rs deleted file mode 100644 index b090c6c454fd8..0000000000000 --- a/datafusion/functions/src/math/atan2.rs +++ /dev/null @@ -1,140 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Math function: `atan2()`. 
- -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use datafusion_common::DataFusionError; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use crate::make_function_inputs2; -use crate::utils::make_scalar_function; - -#[derive(Debug)] -pub(super) struct Atan2 { - signature: Signature, -} - -impl Atan2 { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - Volatility::Immutable, - ), - } - } -} - -impl ScalarUDFImpl for Atan2 { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "atan2" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use self::DataType::*; - match &arg_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_scalar_function(atan2, vec![])(args) - } -} - -/// Atan2 SQL function -pub fn atan2(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::atan2 } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::atan2 } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function atan2"), - } -} - -#[cfg(test)] -mod test { - use super::*; - use datafusion_common::cast::{as_float32_array, as_float64_array}; - - #[test] - fn test_atan2_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = 
atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float64_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); - } - - #[test] - fn test_atan2_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float32_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); - } -} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 2ee1fffa16251..ee53fcf96a8b5 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,13 +18,11 @@ //! 
"math" DataFusion functions mod abs; -mod atan2; mod nans; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); -make_udf_function!(atan2::Atan2, ATAN2, atan2); make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); @@ -39,6 +37,7 @@ make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); +make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( From 2bf7700740c7b5779d3caeebfde5796518f830cf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 1 Apr 2024 09:48:28 -0400 Subject: [PATCH 2/2] Update datafusion/functions/src/macros.rs --- datafusion/functions/src/macros.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 53eee2c748d45..c92cb27ef5bb5 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -253,7 +253,7 @@ macro_rules! make_math_unary_udf { /// Macro to create a binary math UDF. /// -/// A binary math function takes an argument of type Float32 or Float64, +/// A binary math function takes two arguments of types Float32 or Float64, /// applies a binary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl`