From bf0c6128ffba548a0f869214adea5d4d8a15a18a Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sat, 30 Mar 2024 20:23:07 +0800 Subject: [PATCH 1/5] Refactor math functions in datafusion code --- datafusion/expr/src/built_in_function.rs | 34 ++----------------- datafusion/expr/src/expr_fn.rs | 8 ----- datafusion/functions/src/math/mod.rs | 11 +++++- datafusion/physical-expr/src/functions.rs | 4 --- datafusion/proto/proto/datafusion.proto | 8 ++--- datafusion/proto/src/generated/pbjson.rs | 12 ------- datafusion/proto/src/generated/prost.rs | 16 +++------ .../proto/src/logical_plan/from_proto.rs | 18 ++-------- datafusion/proto/src/logical_plan/to_proto.rs | 4 --- 9 files changed, 23 insertions(+), 92 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f07e840275529..d8ea093ff0af1 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -37,16 +37,8 @@ use strum_macros::EnumIter; #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)] pub enum BuiltinScalarFunction { // math functions - /// atan - Atan, /// atan2 Atan2, - /// acosh - Acosh, - /// asinh - Asinh, - /// atanh - Atanh, /// cbrt Cbrt, /// ceil @@ -165,11 +157,7 @@ impl BuiltinScalarFunction { pub fn volatility(&self) -> Volatility { match self { // Immutable scalar builtins - BuiltinScalarFunction::Atan => Volatility::Immutable, BuiltinScalarFunction::Atan2 => Volatility::Immutable, - BuiltinScalarFunction::Acosh => Volatility::Immutable, - BuiltinScalarFunction::Asinh => Volatility::Immutable, - BuiltinScalarFunction::Atanh => Volatility::Immutable, BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, @@ -273,11 +261,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Iszero => Ok(Boolean), - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | 
BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Cos | BuiltinScalarFunction::Cosh | BuiltinScalarFunction::Degrees @@ -389,11 +373,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { Signature::uniform(2, vec![Int64], self.volatility()) } - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Cbrt + BuiltinScalarFunction::Cbrt | BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Cos | BuiltinScalarFunction::Cosh @@ -426,11 +406,7 @@ impl BuiltinScalarFunction { pub fn monotonicity(&self) -> Option { if matches!( &self, - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial @@ -455,10 +431,6 @@ impl BuiltinScalarFunction { /// Returns all names that can be used to call this function pub fn aliases(&self) -> &'static [&'static str] { match self { - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], BuiltinScalarFunction::Atan2 => &["atan2"], BuiltinScalarFunction::Cbrt => &["cbrt"], BuiltinScalarFunction::Ceil => &["ceil"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index e216e4e86dc13..39afaabe15d80 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -541,10 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine"); scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); -scalar_expr!(Atan, atan, num, "inverse tangent"); 
-scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine"); -scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine"); -scalar_expr!(Atanh, atanh, num, "inverse hyperbolic tangent"); scalar_expr!(Factorial, factorial, num, "factorial"); scalar_expr!( Floor, @@ -983,10 +979,6 @@ mod test { test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); - test_unary_scalar_expr!(Atan, atan); - test_unary_scalar_expr!(Asinh, asinh); - test_unary_scalar_expr!(Acosh, acosh); - test_unary_scalar_expr!(Atanh, atanh); test_unary_scalar_expr!(Factorial, factorial); test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 3a4c1b1e8710d..cccc790409645 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -33,6 +33,11 @@ make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); make_math_unary_udf!(TanFunc, TAN, tan, tan, None); +make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh); +make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh); +make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh); +make_math_unary_udf!(AtanFunc, ATAN, atan, atan); + // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( ( @@ -55,5 +60,9 @@ export_functions!( "returns the arc sine or inverse sine of a number" ), (tan, num, "returns the tangent of a number"), - (tanh, num, "returns the hyperbolic tangent of a number") + (tanh, num, "returns the hyperbolic tangent of a number"), + (atanh, num, "returnd inverse hyperbolic tangent"), + (asinh, num, "returnd inverse hyperbolic sine"), + (acosh, num, "returnd inverse hyperbolic cosine"), + (atan, num, "returnd inverse tangent") ); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 
515511b15fbbf..5cd429d4c19e3 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -199,10 +199,6 @@ pub fn create_physical_fun( ) -> Result { Ok(match fun { // math functions - BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan), - BuiltinScalarFunction::Acosh => Arc::new(math_expressions::acosh), - BuiltinScalarFunction::Asinh => Arc::new(math_expressions::asinh), - BuiltinScalarFunction::Atanh => Arc::new(math_expressions::atanh), BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil), BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos), BuiltinScalarFunction::Cosh => Arc::new(math_expressions::cosh), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 81451e40aa50c..342094e9add28 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -544,7 +544,7 @@ enum ScalarFunction { unknown = 0; // 1 was Acos // 2 was Asin - Atan = 3; + // 3 was Atan // 4 was Ascii Ceil = 5; Cos = 6; @@ -615,9 +615,9 @@ enum ScalarFunction { // 71 was CurrentTime // 72 was Uuid Cbrt = 73; - Acosh = 74; - Asinh = 75; - Atanh = 76; + // 74 Acosh + // 75 was Asinh + // 76 was Atanh Sinh = 77; Cosh = 78; // Tanh = 79; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 2949ab807e048..57539dbdefd7d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22914,7 +22914,6 @@ impl serde::Serialize for ScalarFunction { { let variant = match self { Self::Unknown => "unknown", - Self::Atan => "Atan", Self::Ceil => "Ceil", Self::Cos => "Cos", Self::Exp => "Exp", @@ -22934,9 +22933,6 @@ impl serde::Serialize for ScalarFunction { Self::Power => "Power", Self::Atan2 => "Atan2", Self::Cbrt => "Cbrt", - Self::Acosh => "Acosh", - Self::Asinh => "Asinh", - Self::Atanh => "Atanh", Self::Sinh => "Sinh", Self::Cosh => "Cosh", Self::Pi 
=> "Pi", @@ -22963,7 +22959,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { const FIELDS: &[&str] = &[ "unknown", - "Atan", "Ceil", "Cos", "Exp", @@ -22983,9 +22978,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Power", "Atan2", "Cbrt", - "Acosh", - "Asinh", - "Atanh", "Sinh", "Cosh", "Pi", @@ -23041,7 +23033,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { { match value { "unknown" => Ok(ScalarFunction::Unknown), - "Atan" => Ok(ScalarFunction::Atan), "Ceil" => Ok(ScalarFunction::Ceil), "Cos" => Ok(ScalarFunction::Cos), "Exp" => Ok(ScalarFunction::Exp), @@ -23061,9 +23052,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Power" => Ok(ScalarFunction::Power), "Atan2" => Ok(ScalarFunction::Atan2), "Cbrt" => Ok(ScalarFunction::Cbrt), - "Acosh" => Ok(ScalarFunction::Acosh), - "Asinh" => Ok(ScalarFunction::Asinh), - "Atanh" => Ok(ScalarFunction::Atanh), "Sinh" => Ok(ScalarFunction::Sinh), "Cosh" => Ok(ScalarFunction::Cosh), "Pi" => Ok(ScalarFunction::Pi), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6f7e8a9789a6d..1523fdc93cd5a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2843,7 +2843,7 @@ pub enum ScalarFunction { Unknown = 0, /// 1 was Acos /// 2 was Asin - Atan = 3, + /// 3 was Atan /// 4 was Ascii Ceil = 5, Cos = 6, @@ -2914,9 +2914,9 @@ pub enum ScalarFunction { /// 71 was CurrentTime /// 72 was Uuid Cbrt = 73, - Acosh = 74, - Asinh = 75, - Atanh = 76, + /// 74 Acosh + /// 75 was Asinh + /// 76 was Atanh Sinh = 77, Cosh = 78, /// Tanh = 79; @@ -2987,7 +2987,6 @@ impl ScalarFunction { pub fn as_str_name(&self) -> &'static str { match self { ScalarFunction::Unknown => "unknown", - ScalarFunction::Atan => "Atan", ScalarFunction::Ceil => "Ceil", ScalarFunction::Cos => "Cos", ScalarFunction::Exp => "Exp", @@ -3007,9 +3006,6 @@ impl ScalarFunction { ScalarFunction::Power => "Power", ScalarFunction::Atan2 => 
"Atan2", ScalarFunction::Cbrt => "Cbrt", - ScalarFunction::Acosh => "Acosh", - ScalarFunction::Asinh => "Asinh", - ScalarFunction::Atanh => "Atanh", ScalarFunction::Sinh => "Sinh", ScalarFunction::Cosh => "Cosh", ScalarFunction::Pi => "Pi", @@ -3030,7 +3026,6 @@ impl ScalarFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "unknown" => Some(Self::Unknown), - "Atan" => Some(Self::Atan), "Ceil" => Some(Self::Ceil), "Cos" => Some(Self::Cos), "Exp" => Some(Self::Exp), @@ -3050,9 +3045,6 @@ impl ScalarFunction { "Power" => Some(Self::Power), "Atan2" => Some(Self::Atan2), "Cbrt" => Some(Self::Cbrt), - "Acosh" => Some(Self::Acosh), - "Asinh" => Some(Self::Asinh), - "Atanh" => Some(Self::Atanh), "Sinh" => Some(Self::Sinh), "Cosh" => Some(Self::Cosh), "Pi" => Some(Self::Pi), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d372cb428c73a..3458605464e1f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,8 +37,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, - cos, cosh, cot, degrees, ends_with, exp, + atan2, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, + ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, log, logical_plan::{PlanType, StringifiedPlan}, @@ -429,12 +429,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sin => Self::Sin, ScalarFunction::Cos => Self::Cos, ScalarFunction::Cot => Self::Cot, - ScalarFunction::Atan => Self::Atan, ScalarFunction::Sinh => Self::Sinh, ScalarFunction::Cosh => Self::Cosh, - ScalarFunction::Asinh => 
Self::Asinh, - ScalarFunction::Acosh => Self::Acosh, - ScalarFunction::Atanh => Self::Atanh, ScalarFunction::Exp => Self::Exp, ScalarFunction::Log => Self::Log, ScalarFunction::Degrees => Self::Degrees, @@ -1322,22 +1318,12 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), - ScalarFunction::Asinh => { - Ok(asinh(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Acosh => { - Ok(acosh(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Atanh => { - Ok(atanh(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Degrees => { Ok(degrees(parse_expr(&args[0], registry, codec)?)) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1e4e85c51f70f..fbe1b20cacbd8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1422,10 +1422,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Sinh => Self::Sinh, BuiltinScalarFunction::Cosh => Self::Cosh, - BuiltinScalarFunction::Atan => Self::Atan, - BuiltinScalarFunction::Asinh => Self::Asinh, - BuiltinScalarFunction::Acosh => Self::Acosh, - BuiltinScalarFunction::Atanh => Self::Atanh, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => 
Self::Factorial, BuiltinScalarFunction::Gcd => Self::Gcd, From e058945ba5f20edc86c4dc8a5f858daa65ad4496 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sat, 30 Mar 2024 20:35:57 +0800 Subject: [PATCH 2/5] fic ci --- datafusion/sqllogictest/test_files/order.slt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index f63179a369c58..f0147137dce94 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -527,9 +527,10 @@ Sort: atan_c11 ASC NULLS LAST ----TableScan: aggregate_test_100 projection=[c11] physical_plan SortPreservingMergeExec: [atan_c11@0 ASC NULLS LAST] ---ProjectionExec: expr=[atan(c11@0) as atan_c11] -----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +--SortExec: expr=[atan_c11@0 ASC NULLS LAST] +----ProjectionExec: expr=[atan(c11@0) as atan_c11] +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true query TT EXPLAIN SELECT CEIL(c11) as ceil_c11 From a77e8825791bce513e8897909d3e9460d3db4927 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sun, 31 Mar 2024 10:07:29 +0800 Subject: [PATCH 3/5] fix: avoid regression --- datafusion/functions/src/macros.rs | 1 + datafusion/functions/src/math/mod.rs | 8 ++++---- datafusion/sqllogictest/test_files/order.slt | 7 +++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index b23baeeacf235..98845593b8fa5 100644 --- a/datafusion/functions/src/macros.rs +++ 
b/datafusion/functions/src/macros.rs @@ -156,6 +156,7 @@ macro_rules! downcast_arg { /// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument +/// $MONOTONICITY: the monotonicity of the function macro_rules! make_math_unary_udf { ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index cccc790409645..9d1eb4f1fce7a 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -33,10 +33,10 @@ make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); make_math_unary_udf!(TanFunc, TAN, tan, tan, None); -make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh); -make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh); -make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh); -make_math_unary_udf!(AtanFunc, ATAN, atan, atan); +make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); +make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); +make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); +make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index f0147137dce94..f63179a369c58 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -527,10 +527,9 @@ Sort: atan_c11 ASC NULLS LAST ----TableScan: aggregate_test_100 projection=[c11] physical_plan SortPreservingMergeExec: [atan_c11@0 ASC NULLS LAST] ---SortExec: expr=[atan_c11@0 ASC NULLS LAST] -----ProjectionExec: expr=[atan(c11@0) as
atan_c11] -------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true +--ProjectionExec: expr=[atan(c11@0) as atan_c11] +----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true query TT EXPLAIN SELECT CEIL(c11) as ceil_c11 From 1a0da9d46c6b29e6632f753a6290e619f1a8c3a2 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sun, 31 Mar 2024 10:43:48 +0800 Subject: [PATCH 4/5] refactor: move atan2 function --- datafusion/expr/src/built_in_function.rs | 14 +-- datafusion/expr/src/expr_fn.rs | 2 - datafusion/functions/src/macros.rs | 28 ++++++ datafusion/functions/src/math/atan2.rs | 98 +++++++++++++++++++ datafusion/functions/src/math/mod.rs | 11 ++- datafusion/functions/src/utils.rs | 3 + datafusion/physical-expr/src/functions.rs | 3 - .../physical-expr/src/math_expressions.rs | 25 ----- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 7 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - 13 files changed, 140 insertions(+), 61 deletions(-) create mode 100644 datafusion/functions/src/math/atan2.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d8ea093ff0af1..d0009c33a997f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -37,8 +37,6 @@ use strum_macros::EnumIter; #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)] pub enum BuiltinScalarFunction { // math functions - /// atan2 - Atan2, /// cbrt Cbrt, /// ceil @@ -157,7 +155,6 @@ impl 
BuiltinScalarFunction { pub fn volatility(&self) -> Volatility { match self { // Immutable scalar builtins - BuiltinScalarFunction::Atan2 => Volatility::Immutable, BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, @@ -244,11 +241,6 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, - BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - BuiltinScalarFunction::Log => match &input_expr_types[0] { Float32 => Ok(Float32), _ => Ok(Float64), @@ -350,10 +342,7 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - self.volatility(), - ), + BuiltinScalarFunction::Log => Signature::one_of( vec![ Exact(vec![Float32]), @@ -431,7 +420,6 @@ impl BuiltinScalarFunction { /// Returns all names that can be used to call this function pub fn aliases(&self) -> &'static [&'static str] { match self { - BuiltinScalarFunction::Atan2 => &["atan2"], BuiltinScalarFunction::Cbrt => &["cbrt"], BuiltinScalarFunction::Ceil => &["ceil"], BuiltinScalarFunction::Cos => &["cos"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 39afaabe15d80..774f6a01f9d8d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -567,7 +567,6 @@ scalar_expr!(Exp, exp, num, "exponential"); scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor"); scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); -scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in 
uppercase and the remaining characters in lowercase"); @@ -990,7 +989,6 @@ mod test { test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Signum, signum); test_unary_scalar_expr!(Exp, exp); - test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 98845593b8fa5..4907d74fe941a 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -250,3 +250,31 @@ macro_rules! make_math_unary_udf { } }; } + +#[macro_export] +macro_rules! make_function_inputs2 { + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE1>() + }}; +} diff --git a/datafusion/functions/src/math/atan2.rs b/datafusion/functions/src/math/atan2.rs new file mode 100644 index 0000000000000..ad2a67a5954f5 --- /dev/null +++ b/datafusion/functions/src/math/atan2.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `atan2()`. + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use crate::make_function_inputs2; +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub(super) struct Atan2 { + signature: Signature, +} + +impl Atan2 { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for Atan2 { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "atan2" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use self::DataType::*; + match &arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(atan2, vec![])(args) + } +} + +/// Atan2 SQL function +pub fn atan2(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", 
+ Float64Array, + { f64::atan2 } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::atan2 } + )) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function atan2"), + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 9d1eb4f1fce7a..2ee1fffa16251 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,11 +18,13 @@ //! "math" DataFusion functions mod abs; +mod atan2; mod nans; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(atan2::Atan2, ATAN2, atan2); make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); @@ -61,8 +63,9 @@ export_functions!( ), (tan, num, "returns the tangent of a number"), (tanh, num, "returns the hyperbolic tangent of a number"), - (atanh, num, "returnd inverse hyperbolic tangent"), - (asinh, num, "returnd inverse hyperbolic sine"), - (acosh, num, "returnd inverse hyperbolic cosine"), - (atan, num, "returnd inverse tangent") + (atanh, num, "returns inverse hyperbolic tangent"), + (asinh, num, "returns inverse hyperbolic sine"), + (acosh, num, "returns inverse hyperbolic cosine"), + (atan, num, "returns inverse tangent"), + (atan2, y x, "returns inverse tangent of a division given in the argument") ); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index f45deafdb37a0..9b7144b483bd6 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -68,6 +68,9 @@ get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); // `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. 
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); +/// Creates a scalar function implementation for the given function. +/// * `inner` - the function to be executed +/// * `hints` - hints to be used when expanding scalars to arrays pub(super) fn make_scalar_function( inner: F, hints: Vec, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5cd429d4c19e3..0cfe910225326 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -237,9 +237,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) } - BuiltinScalarFunction::Atan2 => { - Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args)) - } BuiltinScalarFunction::Log => { Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args)) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index db8855cb5400b..428308a392964 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -492,31 +492,6 @@ pub fn power(args: &[ArrayRef]) -> Result { } } -/// Atan2 SQL function -pub fn atan2(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::atan2 } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::atan2 } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function atan2"), - } -} - /// Log SQL function pub fn log(args: &[ArrayRef]) -> Result { // Support overloaded log(base, x) and log(x) which defaults to log(10, x) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 
342094e9add28..b4c0bc0347d35 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -608,7 +608,7 @@ enum ScalarFunction { Power = 64; // 65 was StructFun // 66 was FromUnixtime - Atan2 = 67; + // 67 Atan2 // 68 was DateBin // 69 was ArrowTypeof // 70 was CurrentDate diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 57539dbdefd7d..48b5781bf8d1c 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22931,7 +22931,6 @@ impl serde::Serialize for ScalarFunction { Self::Translate => "Translate", Self::Coalesce => "Coalesce", Self::Power => "Power", - Self::Atan2 => "Atan2", Self::Cbrt => "Cbrt", Self::Sinh => "Sinh", Self::Cosh => "Cosh", @@ -22976,7 +22975,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Translate", "Coalesce", "Power", - "Atan2", "Cbrt", "Sinh", "Cosh", @@ -23050,7 +23048,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Translate" => Ok(ScalarFunction::Translate), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), - "Atan2" => Ok(ScalarFunction::Atan2), "Cbrt" => Ok(ScalarFunction::Cbrt), "Sinh" => Ok(ScalarFunction::Sinh), "Cosh" => Ok(ScalarFunction::Cosh), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 1523fdc93cd5a..0e82e8be65c42 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2907,7 +2907,7 @@ pub enum ScalarFunction { Power = 64, /// 65 was StructFun /// 66 was FromUnixtime - Atan2 = 67, + /// 67 Atan2 /// 68 was DateBin /// 69 was ArrowTypeof /// 70 was CurrentDate @@ -3004,7 +3004,6 @@ impl ScalarFunction { ScalarFunction::Translate => "Translate", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", - ScalarFunction::Atan2 => "Atan2", ScalarFunction::Cbrt => "Cbrt", ScalarFunction::Sinh => "Sinh", 
ScalarFunction::Cosh => "Cosh", @@ -3043,7 +3042,6 @@ impl ScalarFunction { "Translate" => Some(Self::Translate), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), - "Atan2" => Some(Self::Atan2), "Cbrt" => Some(Self::Cbrt), "Sinh" => Some(Self::Sinh), "Cosh" => Some(Self::Cosh), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3458605464e1f..43b924b73a4ae 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -37,7 +37,7 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - atan2, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, + cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, log, @@ -452,7 +452,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, ScalarFunction::Power => Self::Power, - ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, ScalarFunction::SubstrIndex => Self::SubstrIndex, @@ -1382,10 +1381,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Atan2 => Ok(atan2( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Nanvl => Ok(nanvl( parse_expr(&args[0], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index fbe1b20cacbd8..f443676fa92eb 100644 --- 
a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1443,7 +1443,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, BuiltinScalarFunction::Power => Self::Power, - BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, From e5932810e2f161f5f644e329b199227be9097217 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sun, 31 Mar 2024 10:53:51 +0800 Subject: [PATCH 5/5] chore: move atan2 test --- datafusion/functions/src/math/atan2.rs | 42 +++++++++++++++++++ .../physical-expr/src/math_expressions.rs | 36 ---------------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/datafusion/functions/src/math/atan2.rs b/datafusion/functions/src/math/atan2.rs index ad2a67a5954f5..b090c6c454fd8 100644 --- a/datafusion/functions/src/math/atan2.rs +++ b/datafusion/functions/src/math/atan2.rs @@ -96,3 +96,45 @@ pub fn atan2(args: &[ArrayRef]) -> Result { other => exec_err!("Unsupported data type {other:?} for function atan2"), } } + +#[cfg(test)] +mod test { + use super::*; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + + #[test] + fn test_atan2_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("failed to initialize function atan2"); + let floats = + as_float64_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); + } + + #[test] + fn test_atan2_f32() 
{ + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("failed to initialize function atan2"); + let floats = + as_float32_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); + } +} diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 428308a392964..5339c12f6e939 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -700,42 +700,6 @@ mod tests { assert_eq!(floats.value(3), 625); } - #[test] - fn test_atan2_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float64_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); - } - - #[test] - fn test_atan2_f32() { - let args: Vec = vec![ - Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float32_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); - 
assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); - } - #[test] fn test_log_f64() { let args: Vec = vec![