diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 7503c337517ef..06186f8696af5 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -133,6 +133,11 @@ harness = false name = "gcd" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "lcm" +required-features = ["math_expressions"] + [[bench]] harness = false name = "nanvl" diff --git a/datafusion/functions/benches/lcm.rs b/datafusion/functions/benches/lcm.rs new file mode 100644 index 0000000000000..247c0ec749d15 --- /dev/null +++ b/datafusion/functions/benches/lcm.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::Field; +use arrow::{ + array::{ArrayRef, Int64Array}, + datatypes::DataType, +}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::lcm; +use rand::Rng; +use std::hint::black_box; +use std::sync::Arc; + +fn generate_i64_array(n_rows: usize) -> ArrayRef { + let mut rng = rand::rng(); + let values = (0..n_rows) + .map(|_| rng.random_range(0..1000)) + .collect::>(); + Arc::new(Int64Array::from(values)) as ArrayRef +} + +fn criterion_benchmark(c: &mut Criterion) { + let n_rows = 100000; + let array_a = ColumnarValue::Array(generate_i64_array(n_rows)); + let array_b = ColumnarValue::Array(generate_i64_array(n_rows)); + let udf = lcm(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("lcm both array", |b| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![array_a.clone(), array_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", array_b.data_type(), true).into(), + ], + number_rows: n_rows, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("lcm should work on valid values"), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 35aa9c095a0db..8b92c454d9b4c 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, AsArray, Int64Array, PrimitiveArray}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::try_binary; use arrow::datatypes::{DataType, Int64Type}; use arrow::error::ArrowError; @@ -126,18 +126,23 @@ fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result { } fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option) -> Result { + let prim = arr.as_primitive::(); match scalar { + Some(scalar_value) if scalar_value != 0 && scalar_value != i64::MIN => { + // The gcd result divides both inputs' absolute values. When the + // scalar is neither 0 nor i64::MIN, the gcd's absolute value fits + // in i64, so the cast to i64 below cannot overflow. This allows us + // to use `unary` instead of `try_unary`, which allows LLVM to + // vectorize more effectively. + let sv = scalar_value.unsigned_abs(); + let result: PrimitiveArray = + prim.unary(|val| unsigned_gcd(val.unsigned_abs(), sv) as i64); + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } Some(scalar_value) => { - let result: Result = arr - .as_primitive::() - .iter() - .map(|val| match val { - Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)), - _ => Ok(None), - }) - .collect(); - - result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + let result: PrimitiveArray = + prim.try_unary(|val| compute_gcd(val, scalar_value))?; + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), } @@ -171,7 +176,8 @@ pub fn compute_gcd(x: i64, y: i64) -> Result { let a = x.unsigned_abs(); let b = y.unsigned_abs(); let r = unsigned_gcd(a, b); - // gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64 + // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or + // gcd(i64::MIN, i64::MIN)), which does not fit into i64. r.try_into().map_err(|_| { ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})")) }) diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index c52b3d2c6ed9e..9398e9f8d6e00 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -17,12 +17,14 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::compute::try_binary; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; +use arrow::datatypes::Int64Type; use arrow::error::ArrowError; -use datafusion_common::{Result, arrow_datafusion_err, exec_err}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -91,7 +93,7 @@ impl ScalarUDFImpl for LcmFunc { /// Lcm SQL function fn lcm(args: &[ArrayRef]) -> Result { - let compute_lcm = |x: i64, y: i64| { + let compute_lcm = |x: i64, y: i64| -> Result { if x == 0 || y == 0 { return Ok(0); } @@ -105,55 +107,20 @@ fn lcm(args: &[ArrayRef]) -> Result { .checked_mul(b) .and_then(|v| i64::try_from(v).ok()) .ok_or_else(|| { - arrow_datafusion_err!(ArrowError::ComputeError(format!( + ArrowError::ComputeError(format!( "Signed integer overflow in LCM({x}, {y})" - ))) + )) }) }; match args[0].data_type() { Int64 => { - let arg1 = downcast_named_arg!(&args[0], "x", Int64Array); - let arg2 = downcast_named_arg!(&args[1], "y", Int64Array); - - Ok(arg1 - .iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Ok(Some(compute_lcm(a1, a2)?)), - _ => Ok(None), - }) - .collect::>() - .map(Arc::new)? as ArrayRef) + let arg1 = args[0].as_primitive::(); + let arg2 = args[1].as_primitive::(); + + let result: PrimitiveArray = try_binary(arg1, arg2, compute_lcm)?; + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function lcm"), } } - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{ArrayRef, Int64Array}; - - use datafusion_common::cast::as_int64_array; - - use crate::math::lcm::lcm; - - #[test] - fn test_lcm_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x - Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y - ]; - - let result = lcm(&args).expect("failed to initialize function lcm"); - let ints = as_int64_array(&result).expect("failed to initialize function lcm"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 0); - assert_eq!(ints.value(1), 6); - assert_eq!(ints.value(2), 75); - assert_eq!(ints.value(3), 16); - } -} diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index d571fcd947134..e00edf47c176b 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -857,3 +857,87 @@ query RT SELECT log(2.5, 10.9::double), arrow_typeof(log(2.5, 10.9::double)); ---- 2.606992198152 Float64 + +# lcm with array and scalar + +query I +SELECT lcm(column1, 5) FROM (VALUES (0), (3), (25), (-16)); +---- +0 +15 +25 +80 + +query I +SELECT lcm(6, column1) FROM (VALUES (4), (9), (0)); +---- +12 +18 +0 + +# lcm array and scalar with nulls in the array +query I +SELECT lcm(column1, 5) FROM (VALUES (0), (NULL), (25)); +---- +0 +NULL +25 + +query I +SELECT lcm(6, column1) FROM (VALUES (4), (NULL), (0)); +---- +12 +NULL +0 + +# lcm scalar edge values +query I +SELECT lcm(9223372036854775807, 1); +---- +9223372036854775807 + +query I +SELECT lcm(9223372036854775807, 9223372036854775807); +---- +9223372036854775807 + +# gcd with array and scalar + +query I +SELECT gcd(column1, 12) FROM (VALUES (8), (18), (0), (-36)); +---- +4 +6 +12 +12 + +query I +SELECT gcd(15, column1) FROM (VALUES (10), (25), (0)); +---- +5 +5 +15 + +# gcd array and scalar with nulls in the array +query I +SELECT gcd(column1, 12) FROM (VALUES (8), (NULL), (0), (-36)); +---- +4 +NULL +12 +12 + +query I +SELECT gcd(15, column1) FROM (VALUES (NULL), (25), (0)); +---- +NULL +5 +15 + +# gcd array and scalar=0 +query I +SELECT gcd(column1, 0) FROM (VALUES (7), (-3), (0)); +---- +7 +3 +0