From 0bd980303f96e01390ea3b8b908c404c7064ec2b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Apr 2023 14:40:06 -0700 Subject: [PATCH 1/2] Use ScalarValue for single input on math expression --- .../physical-expr/src/math_expressions.rs | 217 ++++++++++++++---- 1 file changed, 169 insertions(+), 48 deletions(-) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 2527e858e281c..4a1ae1bd66cbb 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -21,6 +21,7 @@ use arrow::array::ArrayRef; use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; +use datafusion_common::ScalarValue::Float32; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use rand::{thread_rng, Rng}; @@ -100,6 +101,19 @@ macro_rules! downcast_arg { }}; } +macro_rules! make_function_scalar_inputs { + ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); + + arg.iter() + .map(|a| match a { + Some(a) => Some($FUNC(a)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; +} + macro_rules! make_function_inputs2 { ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); @@ -170,43 +184,86 @@ pub fn round(args: &[ArrayRef]) -> Result { ))); } - let mut decimal_places = - &(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef); + let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); if args.len() == 2 { - decimal_places = &args[1]; + decimal_places = ColumnarValue::Array(args[1].clone()); } match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float64Array, - Int64Array, - { - |value: f64, decimal_places: i64| { - (value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round() - / 10.0_f64.powi(decimal_places.try_into().unwrap()) - } + DataType::Float64 => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places = decimal_places.try_into().unwrap(); + + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float64Array, + { + |value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } + } + )) as ArrayRef) } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float32Array, - Int64Array, - { - |value: f32, decimal_places: i64| { - (value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round() - / 10.0_f32.powi(decimal_places.try_into().unwrap()) + ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float64Array, + Int64Array, + { + |value: f64, decimal_places: i64| { + (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) + .round() + / 10.0_f64.powi(decimal_places.try_into().unwrap()) + } } + )) as ArrayRef), + _ => Err(DataFusionError::Internal( + "round function requires a scalar or array for decimal_places" + .to_string(), + )), + }, + + DataType::Float32 => match decimal_places { + ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { + let decimal_places = decimal_places.try_into().unwrap(); + + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float32Array, + { + |value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } + } + )) as ArrayRef) } - )) as ArrayRef), + ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int64Array, + { + |value: f32, decimal_places: i64| { + (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) + .round() + / 10.0_f32.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + _ => Err(DataFusionError::Internal( + "round function requires a scalar or array for decimal_places" + .to_string(), + )), + }, other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function round" @@ -272,30 +329,64 @@ pub fn atan2(args: &[ArrayRef]) -> Result { pub fn log(args: &[ArrayRef]) -> Result { // Support overloaded log(base, x) and log(x) which defaults to log(10, x) // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) - let mut base = &(Arc::new(Float32Array::from_value(10.0, args[0].len())) as ArrayRef); + let mut base = ColumnarValue::Scalar(Float32(Some(10.0))); + let mut x = &args[0]; if args.len() == 2 { x = &args[1]; - base = &args[0]; + base = ColumnarValue::Array(args[0].clone()); } match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )) as ArrayRef), + DataType::Float64 => match base { + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { + let base = base as f64; + return Ok( + Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { + |value: f64| f64::log(value, base) + })) as ArrayRef, + ); + } + ColumnarValue::Array(base) => { + return Ok(Arc::new(make_function_inputs2!( + x, + base, + "x", + "base", + Float64Array, + { f64::log } + )) as ArrayRef); + } + _ => { + return Err(DataFusionError::Internal( + "log function requires a scalar or array for base".to_string(), + )) + } + }, + + DataType::Float32 => match base { + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { + return Ok( + Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { + |value: f32| f32::log(value, base) + })) as ArrayRef, + ); + } + ColumnarValue::Array(base) => { + return Ok(Arc::new(make_function_inputs2!( + x, + base, + "x", + "base", + Float32Array, + { f32::log } + )) as ArrayRef); + } + _ => { + return Err(DataFusionError::Internal( + "log function requires a scalar or array for base".to_string(), + )) + } + }, other => Err(DataFusionError::Internal(format!( "Unsupported data type {other:?} for function log" @@ -466,4 +557,34 @@ mod tests { assert_eq!(floats, &expected); } + + #[test] + fn test_round_f32_one_input() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float32_array(&result).expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64_one_input() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = + as_float64_array(&result).expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]); + + assert_eq!(floats, &expected); + } } From e4944ad29d021193cbede7ae23dddd20be762651 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 5 Apr 2023 20:30:56 -0700 Subject: [PATCH 2/2] Fix clippy --- .../physical-expr/src/math_expressions.rs | 69 ++++++++----------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 4a1ae1bd66cbb..0be352529e4a5 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -340,52 +340,43 @@ pub fn log(args: &[ArrayRef]) -> Result { DataType::Float64 => match base { ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { let base = base as f64; - return Ok( + Ok( Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { |value: f64| f64::log(value, base) })) as ArrayRef, - ); - } - ColumnarValue::Array(base) => { - return Ok(Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )) as ArrayRef); - } - _ => { - return Err(DataFusionError::Internal( - "log function requires a scalar or array for base".to_string(), - )) + ) } + ColumnarValue::Array(base) => Ok(Arc::new(make_function_inputs2!( + x, + base, + "x", + "base", + Float64Array, + { f64::log } + )) as ArrayRef), + _ => Err(DataFusionError::Internal( + "log function requires a scalar or array for base".to_string(), + )), }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - return Ok( - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) as ArrayRef, - ); - } - ColumnarValue::Array(base) => { - return Ok(Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )) as ArrayRef); - } - _ => { - return Err(DataFusionError::Internal( - "log function requires a scalar or array for base".to_string(), - )) - } + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Ok(Arc::new( + make_function_scalar_inputs!(x, "x", Float32Array, { + |value: f32| f32::log(value, base) + }), + ) + as ArrayRef), + ColumnarValue::Array(base) => Ok(Arc::new(make_function_inputs2!( + x, + base, + "x", + "base", + Float32Array, + { f32::log } + )) as ArrayRef), + _ => Err(DataFusionError::Internal( + "log function requires a scalar or array for base".to_string(), + )), }, other => Err(DataFusionError::Internal(format!(