Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions datafusion/functions/benches/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use arrow::util::bench_util::{
};
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
use datafusion_common::DataFusionError;
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::string;
Expand Down Expand Up @@ -80,6 +81,44 @@ fn invoke_repeat_with_args(
}

fn criterion_benchmark(c: &mut Criterion) {
let repeat_fn = string::repeat();
let config_options = Arc::new(ConfigOptions::default());

// Scalar benchmarks (outside loop)
c.bench_function("repeat/scalar_utf8", |b| {
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
],
arg_fields: vec![
Field::new("a", DataType::Utf8, false).into(),
Field::new("b", DataType::Int64, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
};
b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap()))
});

c.bench_function("repeat/scalar_utf8view", |b| {
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
],
arg_fields: vec![
Field::new("a", DataType::Utf8View, false).into(),
Field::new("b", DataType::Int64, false).into(),
],
number_rows: 1,
return_field: Field::new("f", DataType::Utf8, true).into(),
config_options: Arc::clone(&config_options),
};
b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap()))
});

for size in [1024, 4096] {
// REPEAT 3 TIMES
let repeat_times = 3;
Expand Down
164 changes: 128 additions & 36 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
use std::any::Any;
use std::sync::Arc;

use crate::utils::{make_scalar_function, utf8_to_str_type};
use crate::utils::utf8_to_str_type;
use arrow::array::{
ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
OffsetSizeTrait, StringArrayType, StringViewArray,
};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
use datafusion_common::cast::as_int64_array;
use datafusion_common::types::{NativeType, logical_int64, logical_string};
use datafusion_common::{DataFusionError, Result, exec_err};
use datafusion_common::utils::take_function_args;
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
Expand Down Expand Up @@ -99,39 +100,112 @@ impl ScalarUDFImpl for RepeatFunc {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(repeat, vec![])(&args.args)
let return_type = args.return_field.data_type().clone();
let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;

// Early return if either argument is a scalar null
if let ColumnarValue::Scalar(s) = &string_arg
&& s.is_null()
{
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
}
if let ColumnarValue::Scalar(c) = &count_arg
&& c.is_null()
{
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
}

match (&string_arg, &count_arg) {
(
ColumnarValue::Scalar(string_scalar),
ColumnarValue::Scalar(count_scalar),
) => {
let count = match count_scalar {
ScalarValue::Int64(Some(n)) => *n,
_ => {
return internal_err!(
"Unexpected data type {:?} for repeat count",
count_scalar.data_type()
);
}
};

let result = match string_scalar {
ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
ScalarValue::Utf8(Some(compute_repeat(
s,
count,
i32::MAX as usize,
)?))
}
ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
compute_repeat(s, count, i64::MAX as usize)?,
)),
_ => {
return internal_err!(
"Unexpected data type {:?} for function repeat",
string_scalar.data_type()
);
}
};

Ok(ColumnarValue::Scalar(result))
}
_ => {
let string_array = string_arg.to_array(args.number_rows)?;
let count_array = count_arg.to_array(args.number_rows)?;
Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
}
}
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

/// Computes repeat for a single string value with max size check
#[inline]
fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
if count <= 0 {
return Ok(String::new());
}
let result_len = s.len().saturating_mul(count as usize);
if result_len > max_size {
return exec_err!(
"string size overflow on repeat, max size is {}, but got {}",
max_size,
result_len
);
}
Ok(s.repeat(count as usize))
}

/// Repeats string the specified number of times.
/// repeat('Pg', 4) = 'PgPgPgPg'
fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
let number_array = as_int64_array(&args[1])?;
match args[0].data_type() {
fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
let number_array = as_int64_array(count_array)?;
match string_array.data_type() {
Utf8View => {
let string_view_array = args[0].as_string_view();
let string_view_array = string_array.as_string_view();
repeat_impl::<i32, &StringViewArray>(
&string_view_array,
number_array,
i32::MAX as usize,
)
}
Utf8 => {
let string_array = args[0].as_string::<i32>();
let string_arr = string_array.as_string::<i32>();
repeat_impl::<i32, &GenericStringArray<i32>>(
&string_array,
&string_arr,
number_array,
i32::MAX as usize,
)
}
LargeUtf8 => {
let string_array = args[0].as_string::<i64>();
let string_arr = string_array.as_string::<i64>();
repeat_impl::<i64, &GenericStringArray<i64>>(
&string_array,
&string_arr,
number_array,
i64::MAX as usize,
)
Expand All @@ -150,7 +224,7 @@ fn repeat_impl<'a, T, S>(
) -> Result<ArrayRef>
where
T: OffsetSizeTrait,
S: StringArrayType<'a>,
S: StringArrayType<'a> + 'a,
{
let mut total_capacity = 0;
let mut max_item_capacity = 0;
Expand Down Expand Up @@ -181,37 +255,55 @@ where
// Reusable buffer to avoid allocations in string.repeat()
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);

string_array
.iter()
.zip(number_array.iter())
.for_each(|(string, number)| {
// Helper function to repeat a string into a buffer using doubling strategy
// count must be > 0
#[inline]
fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
buffer.clear();
if !string.is_empty() {
let src = string.as_bytes();
// Initial copy
buffer.extend_from_slice(src);
// Doubling strategy: copy what we have so far until we reach the target
while buffer.len() < src.len() * count {
let copy_len = buffer.len().min(src.len() * count - buffer.len());
// SAFETY: we're copying valid UTF-8 bytes that we already verified
buffer.extend_from_within(..copy_len);
}
}
}

// Fast path: no nulls in either array
if string_array.null_count() == 0 && number_array.null_count() == 0 {
for i in 0..string_array.len() {
// SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value
let string = unsafe { string_array.value_unchecked(i) };
let count = number_array.value(i);
if count > 0 {
repeat_to_buffer(&mut buffer, string, count as usize);
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
} else {
builder.append_value("");
}
}
} else {
// Slow path: handle nulls
for (string, number) in string_array.iter().zip(number_array.iter()) {
match (string, number) {
(Some(string), Some(number)) if number >= 0 => {
buffer.clear();
let count = number as usize;
if count > 0 && !string.is_empty() {
let src = string.as_bytes();
// Initial copy
buffer.extend_from_slice(src);
// Doubling strategy: copy what we have so far until we reach the target
while buffer.len() < src.len() * count {
let copy_len =
buffer.len().min(src.len() * count - buffer.len());
// SAFETY: we're copying valid UTF-8 bytes that we already verified
Comment thread
Jefffrey marked this conversation as resolved.
buffer.extend_from_within(..copy_len);
}
}
// SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str
(Some(string), Some(count)) if count > 0 => {
repeat_to_buffer(&mut buffer, string, count as usize);
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
builder
.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
}
(Some(_), Some(_)) => builder.append_value(""),
_ => builder.append_null(),
}
});
let array = builder.finish();
}
}

Ok(Arc::new(array) as ArrayRef)
Ok(Arc::new(builder.finish()) as ArrayRef)
}

#[cfg(test)]
Expand Down