diff --git a/datafusion/spark/src/function/array/repeat.rs b/datafusion/spark/src/function/array/repeat.rs index ea9758ee9891d..c4e5e449c7535 100644 --- a/datafusion/spark/src/function/array/repeat.rs +++ b/datafusion/spark/src/function/array/repeat.rs @@ -28,7 +28,7 @@ use crate::function::null_utils::{ NullMaskResolution, apply_null_mask, compute_null_mask, }; -/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL inputs: in spark if any input is NULL, the result is NULL. +/// Spark-compatible `array_repeat` expression. The difference from DataFusion's `array_repeat` is the handling of a NULL count: in Spark, if the count is NULL, the result is NULL. /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkArrayRepeat { @@ -88,7 +88,7 @@ impl ScalarUDFImpl for SparkArrayRepeat { } /// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL -/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs. +/// if the count argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs. fn spark_array_repeat(args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args: arg_values, @@ -99,15 +99,14 @@ fn spark_array_repeat(args: ScalarFunctionArgs) -> Result { } = args; let return_type = return_field.data_type().clone(); - // Step 1: Check for NULL mask in incoming args - let null_mask = compute_null_mask(&arg_values, number_rows)?; + // A NULL element should be repeated into the array, not cause a NULL result.
+ let null_mask = compute_null_mask(&arg_values[1..], number_rows)?; - // If any argument is null then return NULL immediately + // If count is null then return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)); } - // Step 2: Delegate to DataFusion's array_repeat let array_repeat_func = ArrayRepeat::new(); let func_args = ScalarFunctionArgs { args: arg_values, @@ -118,6 +117,5 @@ fn spark_array_repeat(args: ScalarFunctionArgs) -> Result { }; let result = array_repeat_func.invoke_with_args(func_args)?; - // Step 3: Apply NULL mask to result apply_null_mask(result, null_mask, &return_type) } diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt index 19181aae0fc55..923e349140976 100644 --- a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt +++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt @@ -59,7 +59,17 @@ SELECT array_repeat(['123'], 2); query ? SELECT array_repeat(NULL, 2); ---- -NULL +[NULL, NULL] + +query ? +SELECT array_repeat(NULL, 1); +---- +[NULL] + +query ? +SELECT array_repeat(NULL, 0); +---- +[] query ? SELECT array_repeat([NULL], 2); @@ -88,7 +98,7 @@ FROM VALUES [123, 123] [] [] -NULL +[NULL] NULL