diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index b9192826120e4..9547ab2bf5c0b 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -77,6 +77,7 @@ rand = "0.8"
 avro-rs = { version = "0.13", features = ["snappy"], optional = true }
 num-traits = { version = "0.2", optional = true }
 pyo3 = { version = "0.14", optional = true }
+statrs = "0.15"
 
 [dev-dependencies]
 criterion = "0.3"
diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index df073b62c5b78..d20875978c127 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -189,6 +189,8 @@ pub enum BuiltinScalarFunction {
     Digest,
     /// exp
     Exp,
+    /// Factorial
+    Factorial,
     /// floor
     Floor,
     /// ln, Natural logarithm
@@ -328,6 +330,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Log => Volatility::Immutable,
             BuiltinScalarFunction::Log10 => Volatility::Immutable,
             BuiltinScalarFunction::Log2 => Volatility::Immutable,
+            BuiltinScalarFunction::Factorial => Volatility::Immutable,
             BuiltinScalarFunction::Round => Volatility::Immutable,
             BuiltinScalarFunction::Signum => Volatility::Immutable,
             BuiltinScalarFunction::Sin => Volatility::Immutable,
@@ -406,6 +409,7 @@ impl FromStr for BuiltinScalarFunction {
             "ceil" => BuiltinScalarFunction::Ceil,
             "cos" => BuiltinScalarFunction::Cos,
             "exp" => BuiltinScalarFunction::Exp,
+            "factorial" => BuiltinScalarFunction::Factorial,
             "floor" => BuiltinScalarFunction::Floor,
             "ln" => BuiltinScalarFunction::Ln,
             "log" => BuiltinScalarFunction::Log,
@@ -546,6 +550,7 @@ pub fn return_type(
         BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
         BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"),
         BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"),
+        BuiltinScalarFunction::Factorial => Ok(DataType::Float64),
         BuiltinScalarFunction::NullIf => {
             // NULLIF has two args and they might get coerced, get a preview of this
             let coerced_types = data_types(input_expr_types, &signature(fun));
@@ -734,6 +739,7 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil),
         BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos),
         BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp),
+        BuiltinScalarFunction::Factorial => Arc::new(math_expressions::factorial),
         BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor),
         BuiltinScalarFunction::Log => Arc::new(math_expressions::log10),
         BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln),
diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs
index eabacfc6eb183..b986475699b96 100644
--- a/datafusion/src/physical_plan/math_expressions.rs
+++ b/datafusion/src/physical_plan/math_expressions.rs
@@ -21,6 +21,7 @@ use crate::error::{DataFusionError, Result};
 use arrow::array::{Float32Array, Float64Array};
 use arrow::datatypes::DataType;
 use rand::{thread_rng, Rng};
+use statrs::function::factorial;
 use std::iter;
 use std::sync::Arc;
 
@@ -102,6 +103,63 @@ math_unary_function!("ln", ln);
 math_unary_function!("log2", log2);
 math_unary_function!("log10", log10);
 
+/// Computes `x!` for the `factorial` SQL function.
+///
+/// The factorial is undefined for negative and `NaN` input, so those map to
+/// `NaN`; a bare `as u64` cast would saturate them to 0 and report `0! = 1`.
+/// Any fractional part is truncated toward zero by the cast.
+fn factorial_f64(x: f64) -> f64 {
+    if x.is_nan() || x < 0.0 {
+        f64::NAN
+    } else {
+        factorial::factorial(x as u64)
+    }
+}
+
+/// factorial SQL function
+///
+/// Accepts a `Float64` or `Float32` array or scalar and always returns
+/// `Float64` values; nulls stay null.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::Internal` for any other argument type.
+pub fn factorial(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    // Imported locally so this function does not depend on the file-level imports.
+    use crate::scalar::ScalarValue;
+
+    match &args[0] {
+        ColumnarValue::Array(array) => {
+            let values = array.as_any();
+            let res: Float64Array =
+                if let Some(input) = values.downcast_ref::<Float64Array>() {
+                    arrow::compute::kernels::arity::unary(input, factorial_f64)
+                } else if let Some(input) = values.downcast_ref::<Float32Array>() {
+                    arrow::compute::kernels::arity::unary(input, |x| {
+                        factorial_f64(f64::from(x))
+                    })
+                } else {
+                    return Err(DataFusionError::Internal(
+                        "Invalid data type for factorial function".to_string(),
+                    ));
+                };
+            Ok(ColumnarValue::Array(Arc::new(res)))
+        }
+        // Scalar inputs get the same treatment as arrays instead of an error.
+        ColumnarValue::Scalar(ScalarValue::Float64(value)) => {
+            let res = value.map(factorial_f64);
+            Ok(ColumnarValue::Scalar(ScalarValue::Float64(res)))
+        }
+        ColumnarValue::Scalar(ScalarValue::Float32(value)) => {
+            let res = value.map(|x| factorial_f64(f64::from(x)));
+            Ok(ColumnarValue::Scalar(ScalarValue::Float64(res)))
+        }
+        ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
+            "Invalid data type for factorial function".to_string(),
+        )),
+    }
+}
+
 /// random SQL function
 pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     let len: usize = match &args[0] {
diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs
index 224f8ba1c0087..61600660f46da 100644
--- a/datafusion/tests/sql/functions.rs
+++ b/datafusion/tests/sql/functions.rs
@@ -17,6 +17,41 @@
 
 use super::*;
 
+#[tokio::test]
+async fn factorial() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Float64, true)]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Float64Array::from(vec![
+            Some(4.0),
+            Some(0.0),
+            Some(1.5), // fractional part is truncated, so this is 1!
+            Some(-1.0), // undefined for negative input, expected NaN
+            Some(10000000000.0), // too large for f64, expected inf
+        ]))],
+    )?;
+    let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Arc::new(table))?;
+    let sql = "SELECT factorial(c1) FROM test";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| factorial(test.c1) |",
+        "+--------------------+",
+        "| 24                 |",
+        "| 1                  |",
+        "| 1                  |",
+        "| NaN                |",
+        "| inf                |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 /// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double)))
 #[tokio::test]
 async fn sqrt_f32_vs_f64() -> Result<()> {