From 43741103bc8710f66011502c258d8301f169aaab Mon Sep 17 00:00:00 2001
From: Renato Marroquin
Date: Thu, 30 Dec 2021 15:34:33 -0500
Subject: [PATCH 1/3] Add factorial function

---
 datafusion/Cargo.toml                         |  1 +
 datafusion/src/physical_plan/functions.rs     |  6 ++++
 .../src/physical_plan/math_expressions.rs     | 28 +++++++++++++++
 datafusion/tests/sql/functions.rs             | 34 +++++++++++++++++++
 4 files changed, 69 insertions(+)

diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index b9192826120e4..9547ab2bf5c0b 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -77,6 +77,7 @@ rand = "0.8"
 avro-rs = { version = "0.13", features = ["snappy"], optional = true }
 num-traits = { version = "0.2", optional = true }
 pyo3 = { version = "0.14", optional = true }
+statrs = "0.15"
 
 [dev-dependencies]
 criterion = "0.3"
diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index df073b62c5b78..d20875978c127 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -189,6 +189,8 @@ pub enum BuiltinScalarFunction {
     Digest,
     /// exp
     Exp,
+    /// Factorial
+    Factorial,
     /// floor
     Floor,
     /// ln, Natural logarithm
@@ -328,6 +330,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Log => Volatility::Immutable,
             BuiltinScalarFunction::Log10 => Volatility::Immutable,
             BuiltinScalarFunction::Log2 => Volatility::Immutable,
+            BuiltinScalarFunction::Factorial => Volatility::Immutable,
             BuiltinScalarFunction::Round => Volatility::Immutable,
             BuiltinScalarFunction::Signum => Volatility::Immutable,
             BuiltinScalarFunction::Sin => Volatility::Immutable,
@@ -406,6 +409,7 @@ impl FromStr for BuiltinScalarFunction {
             "ceil" => BuiltinScalarFunction::Ceil,
             "cos" => BuiltinScalarFunction::Cos,
             "exp" => BuiltinScalarFunction::Exp,
+            "factorial" => BuiltinScalarFunction::Factorial,
             "floor" => BuiltinScalarFunction::Floor,
             "ln" => BuiltinScalarFunction::Ln,
             "log" => BuiltinScalarFunction::Log,
@@ -546,6 +550,7 @@ pub fn return_type(
         BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
         BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"),
         BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"),
+        BuiltinScalarFunction::Factorial => Ok(DataType::Float64),
         BuiltinScalarFunction::NullIf => {
             // NULLIF has two args and they might get coerced, get a preview of this
             let coerced_types = data_types(input_expr_types, &signature(fun));
@@ -734,6 +739,7 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil),
         BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos),
         BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp),
+        BuiltinScalarFunction::Factorial => Arc::new(math_expressions::factorial),
         BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor),
         BuiltinScalarFunction::Log => Arc::new(math_expressions::log10),
         BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln),
diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs
index eabacfc6eb183..fcb306b156955 100644
--- a/datafusion/src/physical_plan/math_expressions.rs
+++ b/datafusion/src/physical_plan/math_expressions.rs
@@ -21,6 +21,7 @@ use crate::error::{DataFusionError, Result};
 use arrow::array::{Float32Array, Float64Array};
 use arrow::datatypes::DataType;
 use rand::{thread_rng, Rng};
+use statrs::function::factorial;
 use std::iter;
 use std::sync::Arc;
 
@@ -102,6 +103,33 @@ math_unary_function!("ln", ln);
 math_unary_function!("log2", log2);
 math_unary_function!("log10", log10);
 
+/// factorial SQL function
+pub fn factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    match &args[0] {
+        ColumnarValue::Array(array) => {
+            let x1 = array.as_any().downcast_ref::<Float64Array>();
+            match x1 {
+                Some(array) => {
+                    let res: Float64Array =
+                        arrow::compute::kernels::arity::unary(array, |x| {
+                            factorial::factorial(x as u64)
+                        });
+                    let arc1 = Arc::new(res);
+                    Ok(ColumnarValue::Array(arc1))
+                }
+                _ => Err(DataFusionError::Internal(
+                    format!("Invalid data type for ",),
+                )),
+            }
+        }
+        _ => {
+            return Err(DataFusionError::Internal(
+                "Expect factorial function to take some params".to_string(),
+            ))
+        }
+    }
+}
+
 /// random SQL function
 pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     let len: usize = match &args[0] {
diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs
index 224f8ba1c0087..99aef419e1869 100644
--- a/datafusion/tests/sql/functions.rs
+++ b/datafusion/tests/sql/functions.rs
@@ -17,6 +17,40 @@
 
 use super::*;
 
+#[tokio::test]
+async fn factorial() -> Result<()> {
+    let schema = Arc::new(
+        Schema::new(vec![
+            Field::new("c1", DataType::Float64, true)]
+    ));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Float64Array::from(vec![
+            Some(4.0),
+            Some(0.0),
+            Some(5.0),
+        ]))],
+    )?;
+    let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Arc::new(table))?;
+    let sql = "SELECT factorial(c1) FROM test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| factorial(test.c1) |",
+        "+--------------------+",
+        "| 24                 |",
+        "| 1                  |",
+        "| 120                |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 /// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double)))
 #[tokio::test]
 async fn sqrt_f32_vs_f64() -> Result<()> {

From 0d82eda1e36fb8d71975c8097cd3612246a55283 Mon Sep 17 00:00:00 2001
From: Renato Marroquin
Date: Thu, 30 Dec 2021 15:45:05 -0500
Subject: [PATCH 2/3] Add factorial function

---
 datafusion/src/physical_plan/math_expressions.rs | 10 ++++------
 datafusion/tests/sql/functions.rs                |  5 +----
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs
index fcb306b156955..b986475699b96 100644
--- a/datafusion/src/physical_plan/math_expressions.rs
+++ b/datafusion/src/physical_plan/math_expressions.rs
@@ -118,15 +118,13 @@ pub fn factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
                     Ok(ColumnarValue::Array(arc1))
                 }
                 _ => Err(DataFusionError::Internal(
-                    format!("Invalid data type for ",),
+                    "Invalid data type for factorial function".to_string(),
                 )),
             }
         }
-        _ => {
-            return Err(DataFusionError::Internal(
-                "Expect factorial function to take some params".to_string(),
-            ))
-        }
+        _ => Err(DataFusionError::Internal(
+            "Expect factorial function to take some params".to_string(),
+        )),
     }
 }
 
diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs
index 99aef419e1869..a7778d1550280 100644
--- a/datafusion/tests/sql/functions.rs
+++ b/datafusion/tests/sql/functions.rs
@@ -19,10 +19,7 @@ use super::*;
 
 #[tokio::test]
 async fn factorial() -> Result<()> {
-    let schema = Arc::new(
-        Schema::new(vec![
-            Field::new("c1", DataType::Float64, true)]
-    ));
+    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)]));
 
     let data = RecordBatch::try_new(
         schema.clone(),

From 3223d434c981444c476555c88a09d13bbe30fde7 Mon Sep 17 00:00:00 2001
From: Renato Marroquin
Date: Thu, 30 Dec 2021 23:57:06 -0500
Subject: [PATCH 3/3] Add factorial function

---
 datafusion/tests/sql/functions.rs | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs
index a7778d1550280..61600660f46da 100644
--- a/datafusion/tests/sql/functions.rs
+++ b/datafusion/tests/sql/functions.rs
@@ -26,7 +26,9 @@ async fn factorial() -> Result<()> {
         vec![Arc::new(Float64Array::from(vec![
             Some(4.0),
             Some(0.0),
-            Some(5.0),
+            Some(1.5),
+            Some(-1.0),
+            Some(10000000000.0),
         ]))],
     )?;
     let table = MemTable::try_new(schema, vec![vec![data]])?;
@@ -41,7 +43,9 @@ async fn factorial() -> Result<()> {
         "+--------------------+",
         "| 24                 |",
         "| 1                  |",
-        "| 120                |",
+        "| 1                  |",
+        "| 1                  |",
+        "| inf                |",
         "+--------------------+",
     ];
     assert_batches_eq!(expected, &actual);