diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index f3dae22866c2..d280b9c1c002 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -89,10 +89,21 @@ impl ScalarUDFImpl for SparkConcat { ) } fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + use DataType::*; + // Spark semantics: concat returns NULL if ANY input is NULL let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - Ok(Arc::new(Field::new("concat", DataType::Utf8, nullable))) + // Determine return type: Utf8View > LargeUtf8 > Utf8 + let mut dt = &Utf8; + for field in args.arg_fields { + let data_type = field.data_type(); + if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) { + dt = data_type; + } + } + + Ok(Arc::new(Field::new("concat", dt.clone(), nullable))) } } @@ -110,9 +121,18 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // Handle zero-argument case: return empty string if arg_values.is_empty() { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - Some(String::new()), - ))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::new(), + )))), + DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8( + Some(String::new()), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + Some(String::new()), + ))), + }; } // Step 1: Check for NULL mask in incoming args @@ -120,7 +140,14 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // If all scalars and any is NULL, return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))), + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } // Step 2: Delegate to DataFusion's concat @@ -181,6 +208,7 @@ mod tests { ); Ok(()) } + #[test] fn test_spark_concat_return_field_non_nullable() -> Result<()> { let func = SparkConcat::new(); diff --git a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt index 258cb829d7d4..97e7b57f7d06 100644 --- a/datafusion/sqllogictest/test_files/spark/string/concat.slt +++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt @@ -20,6 +20,12 @@ SELECT concat('Spark', 'SQL'); ---- SparkSQL +# Test two Utf8View inputs: value and return type +query TT +SELECT concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View')), arrow_typeof(concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View'))); +---- +SparkSQL Utf8View + query T SELECT concat('Spark', 'SQL', NULL); ---- @@ -46,3 +52,21 @@ SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, ---- abc NULL + +# Test mixed types: Utf8View + Utf8 +query TT +SELECT concat(arrow_cast('hello', 'Utf8View'), ' world'), arrow_typeof(concat(arrow_cast('hello', 'Utf8View'), ' world')); +---- +hello world Utf8View + +# Test Utf8 + LargeUtf8 => return type LargeUtf8 +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'))); +---- +ab LargeUtf8 + +# Test all three types mixed together +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View'))); +---- +abc Utf8View