From cd597b4cff59659ece1a1f10f71bd5ec0cea1552 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Wed, 15 Oct 2025 13:06:02 -0700 Subject: [PATCH 1/2] feat: support Spark `concat` string function (#18063) * chore: Extend backtrace coverage * fmt * part2 * feedback * clippy * feat: support Spark `concat` * clippy * comments * test * doc (cherry picked from commit 264030cca76d0bdb4d8809f252b422e72624a345) --- .../spark/src/function/string/concat.rs | 306 ++++++++++++++++++ datafusion/spark/src/function/string/mod.rs | 9 +- .../test_files/spark/string/concat.slt | 48 +++ 3 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 datafusion/spark/src/function/string/concat.rs create mode 100644 datafusion/sqllogictest/test_files/spark/string/concat.slt diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs new file mode 100644 index 0000000000000..0e981e7c37224 --- /dev/null +++ b/datafusion/spark/src/function/string/concat.rs @@ -0,0 +1,306 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayBuilder}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::string::concat::ConcatFunc; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `concat` expression +/// +/// +/// Concatenates multiple input strings into a single string. +/// Returns NULL if any input is NULL. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkConcat { + signature: Signature, +} + +impl Default for SparkConcat { + fn default() -> Self { + Self::new() + } +} + +impl SparkConcat { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::UserDefined, TypeSignature::Nullary], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkConcat { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "concat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_concat(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + // Accept any string types, including zero arguments + Ok(arg_types.to_vec()) + } +} + +/// Concatenates strings, returning NULL if any input is NULL +/// This is a Spark-specific wrapper around DataFusion's concat that returns NULL +/// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs. +fn spark_concat(args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + } = args; + + // Handle zero-argument case: return empty string + if arg_values.is_empty() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + Some(String::new()), + ))); + } + + // Step 1: Check for NULL mask in incoming args + let null_mask = compute_null_mask(&arg_values, number_rows)?; + + // If all scalars and any is NULL, return NULL immediately + if null_mask.is_none() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + + // Step 2: Delegate to DataFusion's concat + let concat_func = ConcatFunc::new(); + let func_args = ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + }; + let result = concat_func.invoke_with_args(func_args)?; + + // Step 3: Apply NULL mask to result + apply_null_mask(result, null_mask) +} + +/// Compute NULL mask for the arguments +/// Returns None if all scalars and any is NULL, or a Vector of +/// boolean representing the null mask for incoming arrays +fn compute_null_mask( + args: &[ColumnarValue], + number_rows: usize, +) -> Result>> { + // Check if all arguments are scalars + let all_scalars = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + + if all_scalars { + // For scalars, check if any is NULL + for arg in args { + if let ColumnarValue::Scalar(scalar) = arg { + if scalar.is_null() { + // Return None to indicate all values should be NULL + return Ok(None); + } + } + } + // No NULLs in scalars + Ok(Some(vec![])) + } else { + // For arrays, compute NULL mask for each row + let array_len = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(number_rows); + + // Convert all scalars to arrays for uniform processing + let arrays: Result> = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), + }) + .collect(); + let arrays = arrays?; + + // Compute NULL mask + let mut null_mask = vec![false; array_len]; + for array in &arrays { + for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) { + if array.is_null(i) { + *null_flag = true; + } + } + } + + Ok(Some(null_mask)) + } +} + +/// Apply NULL mask to the result +fn apply_null_mask( + result: ColumnarValue, + null_mask: Option>, +) -> Result { + match (result, null_mask) { + // Scalar with NULL mask means return NULL + (ColumnarValue::Scalar(_), None) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + // Scalar without NULL mask, return as-is + (scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar), + // Array with NULL mask + (ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => { + let array_len = array.len(); + let return_type = array.data_type(); + + let mut builder: Box = match return_type { + DataType::Utf8 => { + let string_array = array + .as_any() + .downcast_ref::() + .unwrap(); + let mut builder = + arrow::array::StringBuilder::with_capacity(array_len, 0); + for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { + if is_null || string_array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(string_array.value(i)); + } + } + Box::new(builder) + } + DataType::LargeUtf8 => { + let string_array = array + .as_any() + .downcast_ref::() + .unwrap(); + let mut builder = + arrow::array::LargeStringBuilder::with_capacity(array_len, 0); + for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { + if is_null || string_array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(string_array.value(i)); + } + } + Box::new(builder) + } + DataType::Utf8View => { + let string_array = array + .as_any() + .downcast_ref::() + .unwrap(); + let mut builder = + arrow::array::StringViewBuilder::with_capacity(array_len); + for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { + if is_null || string_array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(string_array.value(i)); + } + } + Box::new(builder) + } + _ => { + return datafusion_common::exec_err!( + "Unsupported return type for concat: {:?}", + return_type + ); + } + }; + + Ok(ColumnarValue::Array(builder.finish())) + } + // Array without NULL mask, return as-is + (array @ ColumnarValue::Array(_), _) => Ok(array), + // Shouldn't happen + (scalar, _) => Ok(scalar), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::utils::test::test_scalar_function; + use arrow::array::StringArray; + use arrow::datatypes::DataType; + use datafusion_common::Result; + + #[test] + fn test_concat_basic() -> Result<()> { + test_scalar_function!( + SparkConcat::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))), + ], + Ok(Some("SparkSQL")), + &str, + DataType::Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_concat_with_null() -> Result<()> { + test_scalar_function!( + SparkConcat::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + DataType::Utf8, + StringArray + ); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index e83b696bc1ba9..a0c76cfabeaf1 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -17,6 +17,7 @@ pub mod ascii; pub mod char; +pub mod concat; pub mod ilike; pub mod like; pub mod luhn_check; @@ -27,6 +28,7 @@ use std::sync::Arc; make_udf_function!(ascii::SparkAscii, ascii); make_udf_function!(char::CharFunc, char); +make_udf_function!(concat::SparkConcat, concat); make_udf_function!(ilike::SparkILike, ilike); make_udf_function!(like::SparkLike, like); make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); @@ -44,6 +46,11 @@ pub mod expr_fn { "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).", arg1 )); + export_functions!(( + concat, + "Concatenates multiple input strings into a single string. Returns NULL if any input is NULL.", + args + )); export_functions!(( ilike, "Returns true if str matches pattern (case insensitive).", @@ -62,5 +69,5 @@ pub mod expr_fn { } pub fn functions() -> Vec> { - vec![ascii(), char(), ilike(), like(), luhn_check()] + vec![ascii(), char(), concat(), ilike(), like(), luhn_check()] } diff --git a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt new file mode 100644 index 0000000000000..0b796a54a69e8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT concat('Spark', 'SQL'); +---- +SparkSQL + +query T +SELECT concat('Spark', 'SQL', NULL); +---- +NULL + +query T +SELECT concat('', '1', '', '2'); +---- +12 + +query T +SELECT concat(); +---- +(empty) + +query T +SELECT concat(''); +---- +(empty) + + +query T +SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, 'b', 'c') order by 1 nulls last; +---- +abc +NULL \ No newline at end of file From 4e2597bef0ccc2ceead68af8fe00c958ed95a2b6 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Thu, 16 Oct 2025 12:36:01 -0700 Subject: [PATCH 2/2] chore: use `NullBuffer::union` for Spark `concat` (#18087) ## Which issue does this PR close? - Closes #. Followup on https://github.com/apache/datafusion/pull/18063#pullrequestreview-3341818221 ## Rationale for this change Use cheaper `NullBuffer::union` to apply null mask instead of iterator approach ## What changes are included in this PR? ## Are these changes tested? ## Are there any user-facing changes? (cherry picked from commit 337378ab81f6c7dab7da9000124c554d3b7ee568) --- .../spark/src/function/string/concat.rs | 141 +++++++----------- 1 file changed, 52 insertions(+), 89 deletions(-) diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 0e981e7c37224..0dcc58d5bb8ed 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayBuilder}; +use arrow::array::Array; +use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ @@ -31,6 +32,10 @@ use std::sync::Arc; /// /// Concatenates multiple input strings into a single string. /// Returns NULL if any input is NULL. +/// +/// Differences with DataFusion concat: +/// - Support 0 arguments +/// - Return NULL if any input is NULL #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkConcat { signature: Signature, @@ -80,6 +85,16 @@ impl ScalarUDFImpl for SparkConcat { } } +/// Represents the null state for Spark concat +enum NullMaskResolution { + /// Return NULL as the result (e.g., scalar inputs with at least one NULL) + ReturnNull, + /// No null mask needed (e.g., all scalar inputs are non-NULL) + NoMask, + /// Null mask to apply for arrays + Apply(NullBuffer), +} + /// Concatenates strings, returning NULL if any input is NULL /// This is a Spark-specific wrapper around DataFusion's concat that returns NULL /// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs. @@ -103,7 +118,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { let null_mask = compute_null_mask(&arg_values, number_rows)?; // If all scalars and any is NULL, return NULL immediately - if null_mask.is_none() { + if matches!(null_mask, NullMaskResolution::ReturnNull) { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } @@ -122,13 +137,11 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { apply_null_mask(result, null_mask) } -/// Compute NULL mask for the arguments -/// Returns None if all scalars and any is NULL, or a Vector of -/// boolean representing the null mask for incoming arrays +/// Compute NULL mask for the arguments using NullBuffer::union fn compute_null_mask( args: &[ColumnarValue], number_rows: usize, -) -> Result>> { +) -> Result { // Check if all arguments are scalars let all_scalars = args .iter() @@ -139,15 +152,14 @@ fn compute_null_mask( for arg in args { if let ColumnarValue::Scalar(scalar) = arg { if scalar.is_null() { - // Return None to indicate all values should be NULL - return Ok(None); + return Ok(NullMaskResolution::ReturnNull); } } } // No NULLs in scalars - Ok(Some(vec![])) + Ok(NullMaskResolution::NoMask) } else { - // For arrays, compute NULL mask for each row + // For arrays, compute NULL mask for each row using NullBuffer::union let array_len = args .iter() .find_map(|arg| match arg { @@ -166,99 +178,50 @@ fn compute_null_mask( .collect(); let arrays = arrays?; - // Compute NULL mask - let mut null_mask = vec![false; array_len]; - for array in &arrays { - for (i, null_flag) in null_mask.iter_mut().enumerate().take(array_len) { - if array.is_null(i) { - *null_flag = true; - } - } - } + // Use NullBuffer::union to combine all null buffers + let combined_nulls = arrays + .iter() + .map(|arr| arr.nulls()) + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); - Ok(Some(null_mask)) + match combined_nulls { + Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), + None => Ok(NullMaskResolution::NoMask), + } } } -/// Apply NULL mask to the result +/// Apply NULL mask to the result using NullBuffer::union fn apply_null_mask( result: ColumnarValue, - null_mask: Option>, + null_mask: NullMaskResolution, ) -> Result { match (result, null_mask) { - // Scalar with NULL mask means return NULL - (ColumnarValue::Scalar(_), None) => { + // Scalar with ReturnNull mask means return NULL + (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - // Scalar without NULL mask, return as-is - (scalar @ ColumnarValue::Scalar(_), Some(mask)) if mask.is_empty() => Ok(scalar), - // Array with NULL mask - (ColumnarValue::Array(array), Some(null_mask)) if !null_mask.is_empty() => { - let array_len = array.len(); - let return_type = array.data_type(); + // Scalar without mask, return as-is + (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), + // Array with NULL mask - use NullBuffer::union to combine nulls + (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { + // Combine the result's existing nulls with our computed null mask + let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); - let mut builder: Box = match return_type { - DataType::Utf8 => { - let string_array = array - .as_any() - .downcast_ref::() - .unwrap(); - let mut builder = - arrow::array::StringBuilder::with_capacity(array_len, 0); - for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { - if is_null || string_array.is_null(i) { - builder.append_null(); - } else { - builder.append_value(string_array.value(i)); - } - } - Box::new(builder) - } - DataType::LargeUtf8 => { - let string_array = array - .as_any() - .downcast_ref::() - .unwrap(); - let mut builder = - arrow::array::LargeStringBuilder::with_capacity(array_len, 0); - for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { - if is_null || string_array.is_null(i) { - builder.append_null(); - } else { - builder.append_value(string_array.value(i)); - } - } - Box::new(builder) - } - DataType::Utf8View => { - let string_array = array - .as_any() - .downcast_ref::() - .unwrap(); - let mut builder = - arrow::array::StringViewBuilder::with_capacity(array_len); - for (i, &is_null) in null_mask.iter().enumerate().take(array_len) { - if is_null || string_array.is_null(i) { - builder.append_null(); - } else { - builder.append_value(string_array.value(i)); - } - } - Box::new(builder) - } - _ => { - return datafusion_common::exec_err!( - "Unsupported return type for concat: {:?}", - return_type - ); - } - }; + // Create new array with combined nulls + let new_array = array + .into_data() + .into_builder() + .nulls(combined_nulls) + .build()?; - Ok(ColumnarValue::Array(builder.finish())) + Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( + new_array, + )))) } // Array without NULL mask, return as-is - (array @ ColumnarValue::Array(_), _) => Ok(array), - // Shouldn't happen + (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), + // Edge cases that shouldn't happen in practice (scalar, _) => Ok(scalar), } }