Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions datafusion/spark/src/function/array/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::function::null_utils::{
NullMaskResolution, apply_null_mask, compute_null_mask,
};

/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL inputs: in spark if any input is NULL, the result is NULL.
/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL count: in Spark if the count is NULL, the result is NULL.
/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArrayRepeat {
Expand Down Expand Up @@ -88,7 +88,7 @@ impl ScalarUDFImpl for SparkArrayRepeat {
}

/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL
/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
/// if the count argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs.
fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs {
args: arg_values,
Expand All @@ -99,15 +99,14 @@ fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
} = args;
let return_type = return_field.data_type().clone();

// Step 1: Check for NULL mask in incoming args
let null_mask = compute_null_mask(&arg_values, number_rows)?;
// A NULL element should be repeated into the array, not cause a NULL result.
let null_mask = compute_null_mask(&arg_values[1..], number_rows)?;

// If any argument is null then return NULL immediately
// If count is null then return NULL immediately
if matches!(null_mask, NullMaskResolution::ReturnNull) {
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
}

// Step 2: Delegate to DataFusion's array_repeat
let array_repeat_func = ArrayRepeat::new();
let func_args = ScalarFunctionArgs {
args: arg_values,
Expand All @@ -118,6 +117,5 @@ fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
};
let result = array_repeat_func.invoke_with_args(func_args)?;

// Step 3: Apply NULL mask to result
apply_null_mask(result, null_mask, &return_type)
}
14 changes: 12 additions & 2 deletions datafusion/sqllogictest/test_files/spark/array/array_repeat.slt
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,17 @@ SELECT array_repeat(['123'], 2);
query ?
SELECT array_repeat(NULL, 2);
----
NULL
[NULL, NULL]

query ?
SELECT array_repeat(NULL, 1);
----
[NULL]

query ?
SELECT array_repeat(NULL, 0);
----
[]

query ?
SELECT array_repeat([NULL], 2);
Expand Down Expand Up @@ -88,7 +98,7 @@ FROM VALUES
[123, 123]
[]
[]
NULL
[NULL]
NULL


Expand Down
Loading