diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 304739b42f5f..0fdefbafa157 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -24,6 +24,7 @@ use arrow::util::bench_util::{ }; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -80,6 +81,44 @@ fn invoke_repeat_with_args( } fn criterion_benchmark(c: &mut Criterion) { + let repeat_fn = string::repeat(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("repeat/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("repeat/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8View, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 2ca5e190c6e0..65f320c4f9f1 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -18,16 +18,17 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::utf8_to_str_type; use arrow::array::{ - ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -99,7 +100,63 @@ impl ScalarUDFImpl for RepeatFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(repeat, vec![])(&args.args) + let return_type = args.return_field.data_type().clone(); + let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; + + // Early return if either argument is a scalar null + if let ColumnarValue::Scalar(s) = &string_arg + && s.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + if let ColumnarValue::Scalar(c) = &count_arg + && c.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + + match (&string_arg, &count_arg) { + ( + ColumnarValue::Scalar(string_scalar), + ColumnarValue::Scalar(count_scalar), + ) => { + let count = match count_scalar { + ScalarValue::Int64(Some(n)) => *n, + _ => { + return internal_err!( + "Unexpected data type {:?} for repeat count", + count_scalar.data_type() + ); + } + }; + + let result = match string_scalar { + ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { + ScalarValue::Utf8(Some(compute_repeat( + s, + count, + i32::MAX as usize, + )?)) + } + ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some( + compute_repeat(s, count, i64::MAX as usize)?, + )), + _ => { + return internal_err!( + "Unexpected data type {:?} for function repeat", + string_scalar.data_type() + ); + } + }; + + Ok(ColumnarValue::Scalar(result)) + } + _ => { + let string_array = string_arg.to_array(args.number_rows)?; + let count_array = count_arg.to_array(args.number_rows)?; + Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -107,13 +164,30 @@ impl ScalarUDFImpl for RepeatFunc { } } +/// Computes repeat for a single string value with max size check +#[inline] +fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { + if count <= 0 { + return Ok(String::new()); + } + let result_len = s.len().saturating_mul(count as usize); + if result_len > max_size { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + result_len + ); + } + Ok(s.repeat(count as usize)) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' -fn repeat(args: &[ArrayRef]) -> Result { - let number_array = as_int64_array(&args[1])?; - match args[0].data_type() { +fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { + let number_array = as_int64_array(count_array)?; + match string_array.data_type() { Utf8View => { - let string_view_array = args[0].as_string_view(); + let string_view_array = string_array.as_string_view(); repeat_impl::( &string_view_array, number_array, @@ -121,17 +195,17 @@ fn repeat(args: &[ArrayRef]) -> Result { ) } Utf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i32::MAX as usize, ) } LargeUtf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i64::MAX as usize, ) @@ -150,7 +224,7 @@ fn repeat_impl<'a, T, S>( ) -> Result where T: OffsetSizeTrait, - S: StringArrayType<'a>, + S: StringArrayType<'a> + 'a, { let mut total_capacity = 0; let mut max_item_capacity = 0; @@ -181,37 +255,55 @@ where // Reusable buffer to avoid allocations in string.repeat() let mut buffer = Vec::::with_capacity(max_item_capacity); - string_array - .iter() - .zip(number_array.iter()) - .for_each(|(string, number)| { + // Helper function to repeat a string into a buffer using doubling strategy + // count must be > 0 + #[inline] + fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: usize) { + buffer.clear(); + if !string.is_empty() { + let src = string.as_bytes(); + // Initial copy + buffer.extend_from_slice(src); + // Doubling strategy: copy what we have so far until we reach the target + while buffer.len() < src.len() * count { + let copy_len = buffer.len().min(src.len() * count - buffer.len()); + // SAFETY: we're copying valid UTF-8 bytes that we already verified + buffer.extend_from_within(..copy_len); + } + } + } + + // Fast path: no nulls in either array + if string_array.null_count() == 0 && number_array.null_count() == 0 { + for i in 0..string_array.len() { + // SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value + let string = unsafe { string_array.value_unchecked(i) }; + let count = number_array.value(i); + if count > 0 { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str + builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); + } else { + builder.append_value(""); + } + } + } else { + // Slow path: handle nulls + for (string, number) in string_array.iter().zip(number_array.iter()) { match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - buffer.clear(); - let count = number as usize; - if count > 0 && !string.is_empty() { - let src = string.as_bytes(); - // Initial copy - buffer.extend_from_slice(src); - // Doubling strategy: copy what we have so far until we reach the target - while buffer.len() < src.len() * count { - let copy_len = - buffer.len().min(src.len() * count - buffer.len()); - // SAFETY: we're copying valid UTF-8 bytes that we already verified - buffer.extend_from_within(..copy_len); - } - } - // SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str + (Some(string), Some(count)) if count > 0 => { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str builder .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); } (Some(_), Some(_)) => builder.append_value(""), _ => builder.append_null(), } - }); - let array = builder.finish(); + } + } - Ok(Arc::new(array) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)]