From 631af34248d0678697be9e2083796aa5b6c97ded Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 24 Jan 2026 21:47:51 +0530 Subject: [PATCH 1/3] perf: Optimize repeat function for scalar and array fast --- datafusion/functions/benches/repeat.rs | 39 +++++++ datafusion/functions/src/string/repeat.rs | 123 +++++++++++++++++----- 2 files changed, 136 insertions(+), 26 deletions(-) diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 304739b42f5fc..0fdefbafa1579 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -24,6 +24,7 @@ use arrow::util::bench_util::{ }; use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -80,6 +81,44 @@ fn invoke_repeat_with_args( } fn criterion_benchmark(c: &mut Criterion) { + let repeat_fn = string::repeat(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("repeat/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("repeat/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8View, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 2ca5e190c6e02..16f04ea96ee44 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::utf8_to_str_type; use arrow::array::{ ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, @@ -27,7 +27,8 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{NativeType, logical_int64, logical_string}; -use datafusion_common::{DataFusionError, Result, exec_err}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -99,7 +100,61 @@ impl ScalarUDFImpl for RepeatFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(repeat, vec![])(&args.args) + let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; + + match (&string_arg, &count_arg) { + ( + ColumnarValue::Scalar(string_scalar), + ColumnarValue::Scalar(count_scalar), + ) => { + if string_scalar.is_null() || count_scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + + let count = match count_scalar { + ScalarValue::Int64(Some(n)) => *n, + _ => { + return internal_err!( + "Unexpected data type {:?} for repeat count", + count_scalar.data_type() + ); + } + }; + + let repeated = match string_scalar { + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => { + if count <= 0 { + String::new() + } else { + let result_len = s.len().saturating_mul(count as usize); + if result_len > i32::MAX as usize { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + i32::MAX, + result_len + ); + } + s.repeat(count as usize) + } + } + _ => { + return internal_err!( + "Unexpected data type {:?} for function repeat", + string_scalar.data_type() + ); + } + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(repeated)))) + } + _ => { + let string_array = string_arg.to_array(args.number_rows)?; + let count_array = count_arg.to_array(args.number_rows)?; + Ok(ColumnarValue::Array(repeat(&[string_array, count_array])?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -150,8 +205,9 @@ fn repeat_impl<'a, T, S>( ) -> Result where T: OffsetSizeTrait, - S: StringArrayType<'a>, + S: StringArrayType<'a> + 'a, { + use arrow::array::Array; let mut total_capacity = 0; let mut max_item_capacity = 0; string_array.iter().zip(number_array.iter()).try_for_each( @@ -181,37 +237,52 @@ where // Reusable buffer to avoid allocations in string.repeat() let mut buffer = Vec::::with_capacity(max_item_capacity); - string_array - .iter() - .zip(number_array.iter()) - .for_each(|(string, number)| { + // Helper function to repeat a string into a buffer using doubling strategy + #[inline] + fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: i64) { + buffer.clear(); + if count > 0 && !string.is_empty() { + let count = count as usize; + let src = string.as_bytes(); + buffer.extend_from_slice(src); + while buffer.len() < src.len() * count { + let copy_len = buffer.len().min(src.len() * count - buffer.len()); + buffer.extend_from_within(..copy_len); + } + } + } + + // no nulls in either array + if string_array.null_count() == 0 && number_array.null_count() == 0 { + for i in 0..string_array.len() { + // SAFETY: null_count() == 0 guarantees no nulls + let string = unsafe { string_array.value_unchecked(i) }; + let count = number_array.value(i); + if count >= 0 { + repeat_to_buffer(&mut buffer, string, count); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str + builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); + } else { + builder.append_value(""); + } + } + } else { + // handle nulls + for (string, number) in string_array.iter().zip(number_array.iter()) { match (string, number) { (Some(string), Some(number)) if number >= 0 => { - buffer.clear(); - let count = number as usize; - if count > 0 && !string.is_empty() { - let src = string.as_bytes(); - // Initial copy - buffer.extend_from_slice(src); - // Doubling strategy: copy what we have so far until we reach the target - while buffer.len() < src.len() * count { - let copy_len = - buffer.len().min(src.len() * count - buffer.len()); - // SAFETY: we're copying valid UTF-8 bytes that we already verified - buffer.extend_from_within(..copy_len); - } - } - // SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str + repeat_to_buffer(&mut buffer, string, number); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str builder .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); } (Some(_), Some(_)) => builder.append_value(""), _ => builder.append_null(), } - }); - let array = builder.finish(); + } + } - Ok(Arc::new(array) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] From 5e768c4778e04ea17935551141c67e2016a7747c Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sun, 25 Jan 2026 19:16:37 +0530 Subject: [PATCH 2/3] several refactors --- datafusion/functions/src/string/repeat.rs | 113 ++++++++++++++-------- 1 file changed, 70 insertions(+), 43 deletions(-) diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 16f04ea96ee44..dae77a796168a 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::utils::utf8_to_str_type; use arrow::array::{ - ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; @@ -102,15 +102,37 @@ impl ScalarUDFImpl for RepeatFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; + // Helper to create null result with correct type (follows utf8_to_str_type) + let null_result = |dt: &DataType| -> ColumnarValue { + let scalar = if matches!(dt, LargeUtf8) { + ScalarValue::LargeUtf8(None) + } else { + ScalarValue::Utf8(None) + }; + ColumnarValue::Scalar(scalar) + }; + + // Early return if either argument is a scalar null + if let ColumnarValue::Scalar(s) = &string_arg + && s.is_null() + { + return Ok(null_result(&s.data_type())); + } + if let ColumnarValue::Scalar(c) = &count_arg + && c.is_null() + { + let dt = match &string_arg { + ColumnarValue::Scalar(s) => s.data_type(), + ColumnarValue::Array(a) => a.data_type().clone(), + }; + return Ok(null_result(&dt)); + } + match (&string_arg, &count_arg) { ( ColumnarValue::Scalar(string_scalar), ColumnarValue::Scalar(count_scalar), ) => { - if string_scalar.is_null() || count_scalar.is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); - } - let count = match count_scalar { ScalarValue::Int64(Some(n)) => *n, _ => { @@ -121,23 +143,12 @@ impl ScalarUDFImpl for RepeatFunc { } }; - let repeated = match string_scalar { - ScalarValue::Utf8(Some(s)) - | ScalarValue::LargeUtf8(Some(s)) - | ScalarValue::Utf8View(Some(s)) => { - if count <= 0 { - String::new() - } else { - let result_len = s.len().saturating_mul(count as usize); - if result_len > i32::MAX as usize { - return exec_err!( - "string size overflow on repeat, max size is {}, but got {}", - i32::MAX, - result_len - ); - } - s.repeat(count as usize) - } + let result = match string_scalar { + ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { + ScalarValue::Utf8(Some(compute_repeat(s, count)?)) + } + ScalarValue::LargeUtf8(Some(s)) => { + ScalarValue::LargeUtf8(Some(compute_repeat(s, count)?)) } _ => { return internal_err!( @@ -147,12 +158,12 @@ impl ScalarUDFImpl for RepeatFunc { } }; - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(repeated)))) + Ok(ColumnarValue::Scalar(result)) } _ => { let string_array = string_arg.to_array(args.number_rows)?; let count_array = count_arg.to_array(args.number_rows)?; - Ok(ColumnarValue::Array(repeat(&[string_array, count_array])?)) + Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?)) } } } @@ -162,13 +173,30 @@ impl ScalarUDFImpl for RepeatFunc { } } +/// Computes repeat for a single string value +#[inline] +fn compute_repeat(s: &str, count: i64) -> Result { + if count <= 0 { + return Ok(String::new()); + } + let result_len = s.len().saturating_mul(count as usize); + if result_len > i32::MAX as usize { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + i32::MAX, + result_len + ); + } + Ok(s.repeat(count as usize)) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' -fn repeat(args: &[ArrayRef]) -> Result { - let number_array = as_int64_array(&args[1])?; - match args[0].data_type() { +fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { + let number_array = as_int64_array(count_array)?; + match string_array.data_type() { Utf8View => { - let string_view_array = args[0].as_string_view(); + let string_view_array = string_array.as_string_view(); repeat_impl::( &string_view_array, number_array, @@ -176,17 +204,17 @@ fn repeat(args: &[ArrayRef]) -> Result { ) } Utf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i32::MAX as usize, ) } LargeUtf8 => { - let string_array = args[0].as_string::(); + let string_arr = string_array.as_string::(); repeat_impl::>( - &string_array, + &string_arr, number_array, i64::MAX as usize, ) @@ -207,7 +235,6 @@ where T: OffsetSizeTrait, S: StringArrayType<'a> + 'a, { - use arrow::array::Array; let mut total_capacity = 0; let mut max_item_capacity = 0; string_array.iter().zip(number_array.iter()).try_for_each( @@ -238,11 +265,11 @@ where let mut buffer = Vec::::with_capacity(max_item_capacity); // Helper function to repeat a string into a buffer using doubling strategy + // count must be > 0 #[inline] - fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: i64) { + fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: usize) { buffer.clear(); - if count > 0 && !string.is_empty() { - let count = count as usize; + if !string.is_empty() { let src = string.as_bytes(); buffer.extend_from_slice(src); while buffer.len() < src.len() * count { @@ -252,14 +279,14 @@ where } } - // no nulls in either array + // Fast path: no nulls in either array if string_array.null_count() == 0 && number_array.null_count() == 0 { for i in 0..string_array.len() { - // SAFETY: null_count() == 0 guarantees no nulls + // SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value let string = unsafe { string_array.value_unchecked(i) }; let count = number_array.value(i); - if count >= 0 { - repeat_to_buffer(&mut buffer, string, count); + if count > 0 { + repeat_to_buffer(&mut buffer, string, count as usize); // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); } else { @@ -267,11 +294,11 @@ where } } } else { - // handle nulls + // Slow path: handle nulls for (string, number) in string_array.iter().zip(number_array.iter()) { match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - repeat_to_buffer(&mut buffer, string, number); + (Some(string), Some(count)) if count > 0 => { + repeat_to_buffer(&mut buffer, string, count as usize); // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str builder .append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); From e54a07e44ca79424173a7fd7fed71197320f13b8 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Mon, 26 Jan 2026 10:37:30 +0530 Subject: [PATCH 3/3] get return type from scalarfunctionargs --- datafusion/functions/src/string/repeat.rs | 42 ++++++++++------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index dae77a796168a..65f320c4f9f13 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -100,32 +100,19 @@ impl ScalarUDFImpl for RepeatFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let return_type = args.return_field.data_type().clone(); let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; - // Helper to create null result with correct type (follows utf8_to_str_type) - let null_result = |dt: &DataType| -> ColumnarValue { - let scalar = if matches!(dt, LargeUtf8) { - ScalarValue::LargeUtf8(None) - } else { - ScalarValue::Utf8(None) - }; - ColumnarValue::Scalar(scalar) - }; - // Early return if either argument is a scalar null if let ColumnarValue::Scalar(s) = &string_arg && s.is_null() { - return Ok(null_result(&s.data_type())); + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); } if let ColumnarValue::Scalar(c) = &count_arg && c.is_null() { - let dt = match &string_arg { - ColumnarValue::Scalar(s) => s.data_type(), - ColumnarValue::Array(a) => a.data_type().clone(), - }; - return Ok(null_result(&dt)); + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); } match (&string_arg, &count_arg) { @@ -145,11 +132,15 @@ impl ScalarUDFImpl for RepeatFunc { let result = match string_scalar { ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => { - ScalarValue::Utf8(Some(compute_repeat(s, count)?)) - } - ScalarValue::LargeUtf8(Some(s)) => { - ScalarValue::LargeUtf8(Some(compute_repeat(s, count)?)) + ScalarValue::Utf8(Some(compute_repeat( + s, + count, + i32::MAX as usize, + )?)) } + ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some( + compute_repeat(s, count, i64::MAX as usize)?, + )), _ => { return internal_err!( "Unexpected data type {:?} for function repeat", @@ -173,17 +164,17 @@ impl ScalarUDFImpl for RepeatFunc { } } -/// Computes repeat for a single string value +/// Computes repeat for a single string value with max size check #[inline] -fn compute_repeat(s: &str, count: i64) -> Result { +fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { if count <= 0 { return Ok(String::new()); } let result_len = s.len().saturating_mul(count as usize); - if result_len > i32::MAX as usize { + if result_len > max_size { return exec_err!( "string size overflow on repeat, max size is {}, but got {}", - i32::MAX, + max_size, result_len ); } @@ -271,9 +262,12 @@ where buffer.clear(); if !string.is_empty() { let src = string.as_bytes(); + // Initial copy buffer.extend_from_slice(src); + // Doubling strategy: copy what we have so far until we reach the target while buffer.len() < src.len() * count { let copy_len = buffer.len().min(src.len() * count - buffer.len()); + // SAFETY: we're copying valid UTF-8 bytes that we already verified buffer.extend_from_within(..copy_len); } }