diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index ffe12466dc173..c1dd802140c04 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -22,8 +22,9 @@ use std::sync::Arc; use arrow::datatypes::DataType::Int64; use arrow::datatypes::{DataType, Int64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{Result, exec_err}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -81,7 +82,39 @@ impl ScalarUDFImpl for FactorialFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(factorial, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))); + } + + match scalar { + ScalarValue::Int64(Some(v)) => { + let result = compute_factorial(v)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function factorial", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Int64 => { + let result: Int64Array = array + .as_primitive::() + .try_unary(compute_factorial)?; + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } + other => { + internal_err!("Unexpected data type {other:?} for function factorial") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -113,53 +146,12 @@ const FACTORIALS: [i64; 21] = [ 2432902008176640000, ]; // if return type changes, this constant needs to be updated accordingly -/// Factorial SQL function -fn factorial(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Int64 => { - let result: Int64Array = - args[0].as_primitive::().try_unary(|a| { - if a < 0 { - Ok(1) - } else if a < FACTORIALS.len() as i64 { - Ok(FACTORIALS[a as usize]) - } else { - exec_err!("Overflow happened on FACTORIAL({a})") - } - })?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!("Unsupported data type {other:?} for function factorial."), - } -} - -#[cfg(test)] -mod test { - use super::*; - use datafusion_common::cast::as_int64_array; - - #[test] - fn test_factorial_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 1, 2, 4, 20, -1])), // input - ]; - - let result = factorial(&args).expect("failed to initialize function factorial"); - let ints = - as_int64_array(&result).expect("failed to initialize function factorial"); - - let expected = Int64Array::from(vec![1, 1, 2, 24, 2432902008176640000, 1]); - - assert_eq!(ints, &expected); - } - - #[test] - fn test_overflow() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![21])), // input - ]; - - let result = factorial(&args); - assert!(result.is_err()); +fn compute_factorial(n: i64) -> Result { + if n < 0 { + Ok(1) + } else if n < FACTORIALS.len() as i64 { + Ok(FACTORIALS[n as usize]) + } else { + exec_err!("Overflow happened on FACTORIAL({n})") } }