diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index a565006a2577d..2d62c5132e887 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -21,10 +21,9 @@ use std::any::Any; use std::sync::Arc; use crate::make_array::make_array_inner; -use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; +use crate::utils::{check_datatypes, make_scalar_function}; use arrow::array::{ - Array, ArrayData, ArrayRef, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, + Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; @@ -330,7 +329,7 @@ impl ScalarUDFImpl for ArrayConcat { &self, args: datafusion_expr::ScalarFunctionArgs, ) -> Result { - make_scalar_function(array_concat_inner)(&args.args) + make_scalar_function(datafusion_functions::utils::concat_arrays)(&args.args) } fn aliases(&self) -> &[String] { @@ -352,106 +351,9 @@ impl ScalarUDFImpl for ArrayConcat { } } -fn array_concat_inner(args: &[ArrayRef]) -> Result { - if args.is_empty() { - return exec_err!("array_concat expects at least one argument"); - } - - let mut all_null = true; - let mut large_list = false; - for arg in args { - match arg.data_type() { - DataType::Null => continue, - DataType::LargeList(_) => large_list = true, - _ => (), - } - if arg.null_count() < arg.len() { - all_null = false; - } - } - - if all_null { - // Return a null array with the same type as the first non-null-type argument - let return_type = args - .iter() - .map(|arg| arg.data_type()) - .find_or_first(|d| !d.is_null()) - .unwrap(); // Safe because args is non-empty - - Ok(arrow::array::make_array(ArrayData::new_null( - return_type, - args[0].len(), - ))) - } else if large_list { - concat_internal::(args) - } else { - concat_internal::(args) - } -} - -fn concat_internal(args: 
&[ArrayRef]) -> Result { - let args = align_array_dimensions::(args.to_vec())?; - - let list_arrays = args - .iter() - .map(|arg| as_generic_list_array::(arg)) - .collect::>>()?; - // Assume number of rows is the same for all arrays - let row_count = list_arrays[0].len(); - - let mut array_lengths = vec![]; - let mut arrays = vec![]; - let mut valid = NullBufferBuilder::new(row_count); - for i in 0..row_count { - let nulls = list_arrays - .iter() - .map(|arr| arr.is_null(i)) - .collect::>(); - - // If all the arrays are null, the concatenated array is null - let is_null = nulls.iter().all(|&x| x); - if is_null { - array_lengths.push(0); - valid.append_null(); - } else { - // Get all the arrays on i-th row - let values = list_arrays - .iter() - .map(|arr| arr.value(i)) - .collect::>(); - - let elements = values - .iter() - .map(|a| a.as_ref()) - .collect::>(); - - // Concatenated array on i-th row - let concatenated_array = arrow::compute::concat(elements.as_slice())?; - array_lengths.push(concatenated_array.len()); - arrays.push(concatenated_array); - valid.append_non_null(); - } - } - // Assume all arrays have the same data type - let data_type = list_arrays[0].value_type(); - - let elements = arrays - .iter() - .map(|a| a.as_ref()) - .collect::>(); - - let list_arr = GenericListArray::::new( - Arc::new(Field::new_list_field(data_type, true)), - OffsetBuffer::from_lengths(array_lengths), - Arc::new(arrow::compute::concat(elements.as_slice())?), - valid.finish(), - ); - - Ok(Arc::new(list_arr)) -} - // Kernel functions +/// Array_append SQL function fn array_append_inner(args: &[ArrayRef]) -> Result { let [array, values] = take_function_args("array_append", args)?; match array.data_type() { @@ -462,6 +364,7 @@ fn array_append_inner(args: &[ArrayRef]) -> Result { } } +/// Array_prepend SQL function fn array_prepend_inner(args: &[ArrayRef]) -> Result { let [values, array] = take_function_args("array_prepend", args)?; match array.data_type() { @@ -492,8 +395,8 @@ 
where }; let res = match list_array.value_type() { - DataType::List(_) => concat_internal::(args)?, - DataType::LargeList(_) => concat_internal::(args)?, + DataType::List(_) => datafusion_functions::utils::concat_arrays(args)?, + DataType::LargeList(_) => datafusion_functions::utils::concat_arrays(args)?, data_type => { return generic_append_and_prepend::( list_array, diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 464301b6ffcf0..c90ad401caf55 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -17,14 +17,9 @@ //! array function utils -use std::sync::Arc; +use arrow::datatypes::{DataType, Fields}; -use arrow::datatypes::{DataType, Field, Fields}; - -use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, -}; -use arrow::buffer::OffsetBuffer; +use arrow::array::{Array, ArrayRef, BooleanArray, Scalar, UInt32Array}; use datafusion_common::cast::{ as_fixed_size_list_array, as_large_list_array, as_list_array, }; @@ -82,44 +77,6 @@ where } } -pub(crate) fn align_array_dimensions( - args: Vec, -) -> Result> { - let args_ndim = args - .iter() - .map(|arg| datafusion_common::utils::list_ndims(arg.data_type())) - .collect::>(); - let max_ndim = args_ndim.iter().max().unwrap_or(&0); - - // Align the dimensions of the arrays - let aligned_args: Result> = args - .into_iter() - .zip(args_ndim.iter()) - .map(|(array, ndim)| { - if ndim < max_ndim { - let mut aligned_array = Arc::clone(&array); - for _ in 0..(max_ndim - ndim) { - let data_type = aligned_array.data_type().to_owned(); - let array_lengths = vec![1; aligned_array.len()]; - let offsets = OffsetBuffer::::from_lengths(array_lengths); - - aligned_array = Arc::new(GenericListArray::::try_new( - Arc::new(Field::new_list_field(data_type, true)), - offsets, - aligned_array, - None, - )?) 
- } - Ok(aligned_array) - } else { - Ok(Arc::clone(&array)) - } - }) - .collect(); - - aligned_args -} - /// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. /// /// # Arguments @@ -267,60 +224,3 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { _ => internal_err!("Expected a Map type, got {data_type}"), } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::ListArray; - use arrow::datatypes::Int64Type; - use datafusion_common::utils::SingleRowListArrayBuilder; - - /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` - #[test] - fn test_align_array_dimensions() { - let array1d_1: ArrayRef = - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(2), Some(3)]), - Some(vec![Some(4), Some(5)]), - ])); - let array1d_2: ArrayRef = - Arc::new(ListArray::from_iter_primitive::(vec![ - Some(vec![Some(6), Some(7), Some(8)]), - ])); - - let array2d_1: ArrayRef = Arc::new( - SingleRowListArrayBuilder::new(Arc::clone(&array1d_1)).build_list_array(), - ); - let array2d_2 = Arc::new( - SingleRowListArrayBuilder::new(Arc::clone(&array1d_2)).build_list_array(), - ); - - let res = align_array_dimensions::(vec![ - array1d_1.to_owned(), - array2d_2.to_owned(), - ]) - .unwrap(); - - let expected = as_list_array(&array2d_1).unwrap(); - let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type()); - assert_ne!(as_list_array(&res[0]).unwrap(), expected); - assert_eq!( - datafusion_common::utils::list_ndims(res[0].data_type()), - expected_dim - ); - - let array3d_1: ArrayRef = - Arc::new(SingleRowListArrayBuilder::new(array2d_1).build_list_array()); - let array3d_2: ArrayRef = - Arc::new(SingleRowListArrayBuilder::new(array2d_2).build_list_array()); - let res = align_array_dimensions::(vec![array1d_1, array3d_2]).unwrap(); - - let expected = as_list_array(&array3d_1).unwrap(); - let 
expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); - assert_ne!(as_list_array(&res[0]).unwrap(), expected); - assert_eq!( - datafusion_common::utils::list_ndims(res[0].data_type()), - expected_dim - ); - } -} diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 3b53660463d44..c873444b1104b 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -35,8 +35,8 @@ use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), - description = "Concatenates multiple strings together.", - syntax_example = "concat(str[, ..., str_n])", + description = "Concatenates multiple strings or arrays together.", + syntax_example = "concat(str[, ..., str_n]) or concat(array[, ..., array_n])", sql_example = r#"```sql > select concat('data', 'f', 'us', 'ion'); +-------------------------------------------------------+ @@ -44,11 +44,17 @@ use datafusion_macros::user_doc; +-------------------------------------------------------+ | datafusion | +-------------------------------------------------------+ +> select concat(make_array(1, 2), make_array(3, 4)); ++------------------------------------------+ +| concat(make_array(1, 2), make_array(3, 4)) | ++------------------------------------------+ +| [1, 2, 3, 4] | ++------------------------------------------+ ```"#, - standard_argument(name = "str", prefix = "String"), + standard_argument(name = "str_or_array", prefix = "String or Array"), argument( - name = "str_n", - description = "Subsequent string expressions to concatenate." + name = "str_or_array_n", + description = "Subsequent string or array expressions to concatenate. Cannot mix strings and arrays." 
), related_udf(name = "concat_ws") )] @@ -65,13 +71,62 @@ impl Default for ConcatFunc { impl ConcatFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::variadic( - vec![Utf8View, Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), + } + } + + /// Get the string type with highest precedence: Utf8View > LargeUtf8 > Utf8 + /// + /// Utf8View is preferred for performance (zero-copy views), + /// LargeUtf8 supports larger strings (i64 offsets), + /// Utf8 is the fallback standard string type + fn get_string_type_precedence(&self, arg_types: &[DataType]) -> DataType { + use DataType::*; + + for data_type in arg_types { + if data_type == &Utf8View { + return Utf8View; + } + } + + for data_type in arg_types { + if data_type == &LargeUtf8 { + return LargeUtf8; + } } + + Utf8 + } + + /// Concatenate array arguments + fn concat_arrays(&self, args: &[ColumnarValue]) -> Result { + if args.is_empty() { + return plan_err!("concat requires at least one argument"); + } + + // Convert ColumnarValue arguments to ArrayRef + let arrays = ColumnarValue::values_to_arrays(args)?; + + // Check if all arrays are null - concat errors in this case + // This matches PostgreSQL behavior where concat() with all NULL values returns an error + let mut all_null = true; + for arg in &arrays { + if arg.data_type() == &DataType::Null { + continue; + } + if arg.null_count() < arg.len() { + all_null = false; + } + } + + if all_null { + return plan_err!("No valid arrays to concatenate"); + } + + // Delegate to shared array concatenation + let result = crate::utils::concat_arrays(&arrays)?; + Ok(ColumnarValue::Array(result)) } } @@ -88,37 +143,91 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + use DataType::*; + + if arg_types.is_empty() { + return plan_err!("concat requires at least one argument"); + } + + let has_arrays = arg_types + .iter() + 
.any(|dt| matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _))); + let has_non_arrays = arg_types + .iter() + .any(|dt| !matches!(dt, List(_) | LargeList(_) | FixedSizeList(_, _) | Null)); + + if has_arrays && has_non_arrays { + return plan_err!( + "Cannot mix array and non-array arguments in concat function." + ); + } + + if has_arrays { + return Ok(arg_types.to_vec()); + } + + let target_type = self.get_string_type_precedence(arg_types); + + // Only coerce types that need coercion, keep string types as-is + let coerced_types = arg_types + .iter() + .map(|data_type| match data_type { + Utf8View | Utf8 | LargeUtf8 => data_type.clone(), + _ => target_type.clone(), + }) + .collect(); + Ok(coerced_types) + } + fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; - let mut dt = &Utf8; - arg_types.iter().for_each(|data_type| { - if data_type == &Utf8View { - dt = data_type; - } - if data_type == &LargeUtf8 && dt != &Utf8View { - dt = data_type; - } - }); - Ok(dt.to_owned()) + if arg_types.is_empty() { + return plan_err!("concat requires at least one argument"); + } + + // After coercion, all arguments have the same type category, so check only the first + if let List(field) | LargeList(field) | FixedSizeList(field, _) = &arg_types[0] { + return Ok(List(Arc::new(arrow::datatypes::Field::new( + "item", + field.data_type().clone(), + true, + )))); + } + + // For non-array arguments, return string type based on precedence + let dt = self.get_string_type_precedence(arg_types); + Ok(dt) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + use DataType::*; let ScalarFunctionArgs { args, .. 
} = args; - let mut return_datatype = DataType::Utf8; - args.iter().for_each(|col| { - if col.data_type() == DataType::Utf8View { - return_datatype = col.data_type(); - } - if col.data_type() == DataType::LargeUtf8 - && return_datatype != DataType::Utf8View - { - return_datatype = col.data_type(); - } - }); + if args.is_empty() { + return plan_err!("concat requires at least one argument"); + } + + // After coercion, all arguments have the same type category, so check only the first + let is_array = match &args[0] { + ColumnarValue::Array(array) => matches!( + array.data_type(), + List(_) | LargeList(_) | FixedSizeList(_, _) + ), + ColumnarValue::Scalar(scalar) => matches!( + scalar.data_type(), + List(_) | LargeList(_) | FixedSizeList(_, _) + ), + }; + if is_array { + return self.concat_arrays(&args); + } + + let data_types: Vec = args.iter().map(|col| col.data_type()).collect(); + let return_datatype = self.get_string_type_precedence(&data_types); let array_len = args .iter() @@ -128,7 +237,7 @@ impl ScalarUDFImpl for ConcatFunc { }) .next(); - // Scalar + // Scalar case if array_len.is_none() { let mut result = String::new(); for arg in args { @@ -139,21 +248,22 @@ impl ScalarUDFImpl for ConcatFunc { match scalar.try_as_str() { Some(Some(v)) => result.push_str(v), Some(None) => {} // null literal - None => plan_err!( - "Concat function does not support scalar type {}", - scalar - )?, + None => { + if scalar.is_null() { + // Skip null values + } else { + result.push_str(&format!("{scalar}")); + } + } } } return match return_datatype { - DataType::Utf8View => { + Utf8View => { Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) } - DataType::Utf8 => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) - } - DataType::LargeUtf8 => { + Utf8 => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))), + LargeUtf8 => { Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) } other => { @@ -162,7 +272,7 @@ impl ScalarUDFImpl for 
ConcatFunc { }; } - // Array + // Array case let len = array_len.unwrap(); let mut data_size = 0; let mut columns = Vec::with_capacity(args.len()); @@ -179,7 +289,7 @@ impl ScalarUDFImpl for ConcatFunc { } ColumnarValue::Array(array) => { match array.data_type() { - DataType::Utf8 => { + Utf8 => { let string_array = as_string_array(array)?; data_size += string_array.values().len(); @@ -189,19 +299,21 @@ impl ScalarUDFImpl for ConcatFunc { ColumnarValueRef::NonNullableArray(string_array) }; columns.push(column); - }, - DataType::LargeUtf8 => { + } + LargeUtf8 => { let string_array = as_largestring_array(array); data_size += string_array.values().len(); let column = if array.is_nullable() { ColumnarValueRef::NullableLargeStringArray(string_array) } else { - ColumnarValueRef::NonNullableLargeStringArray(string_array) + ColumnarValueRef::NonNullableLargeStringArray( + string_array, + ) }; columns.push(column); - }, - DataType::Utf8View => { + } + Utf8View => { let string_array = as_string_view_array(array)?; data_size += string_array.len(); @@ -211,18 +323,18 @@ impl ScalarUDFImpl for ConcatFunc { ColumnarValueRef::NonNullableStringViewArray(string_array) }; columns.push(column); - }, + } other => { return plan_err!("Input was {other} which is not a supported datatype for concat function") } }; } - _ => unreachable!("concat"), + _ => return plan_err!("Unsupported argument type: {}", arg.data_type()), } } match return_datatype { - DataType::Utf8 => { + Utf8 => { let mut builder = StringArrayBuilder::with_capacity(len, data_size); for i in 0..len { columns @@ -234,7 +346,7 @@ impl ScalarUDFImpl for ConcatFunc { let string_array = builder.finish(None); Ok(ColumnarValue::Array(Arc::new(string_array))) } - DataType::Utf8View => { + Utf8View => { let mut builder = StringViewArrayBuilder::with_capacity(len, data_size); for i in 0..len { columns @@ -246,7 +358,7 @@ impl ScalarUDFImpl for ConcatFunc { let string_array = builder.finish(); 
Ok(ColumnarValue::Array(Arc::new(string_array))) } - DataType::LargeUtf8 => { + LargeUtf8 => { let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size); for i in 0..len { columns @@ -258,7 +370,7 @@ impl ScalarUDFImpl for ConcatFunc { let string_array = builder.finish(None); Ok(ColumnarValue::Array(Arc::new(string_array))) } - _ => unreachable!(), + _ => plan_err!("Unsupported return datatype: {return_datatype}"), } } @@ -288,6 +400,11 @@ impl ScalarUDFImpl for ConcatFunc { } pub(crate) fn simplify_concat(args: Vec) -> Result { + use DataType::*; + + if args.is_empty() { + return plan_err!("concat requires at least one argument"); + } let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); @@ -302,30 +419,55 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { ConcatFunc::new().return_type(&data_types) }?; - for arg in args.clone() { + for arg in args.iter() { match arg { Expr::Literal(ScalarValue::Utf8(None), _) => {} - Expr::Literal(ScalarValue::LargeUtf8(None), _) => { - } - Expr::Literal(ScalarValue::Utf8View(None), _) => { } + Expr::Literal(ScalarValue::LargeUtf8(None), _) => {} + Expr::Literal(ScalarValue::Utf8View(None), _) => {} // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { - contiguous_scalar += &v; + contiguous_scalar += v; } Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { - contiguous_scalar += &v; + contiguous_scalar += v; } Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { - contiguous_scalar += &v; + contiguous_scalar += v; } - Expr::Literal(x, _) => { - return internal_err!( - "The scalar {x} should be casted to string type during the type coercion." 
- ) + Expr::Literal(scalar_val, _) => { + // Convert non-string, non-array literals to their string representation + // Skip array literals - they should be handled at runtime + if matches!( + scalar_val.data_type(), + List(_) | LargeList(_) | FixedSizeList(_, _) + ) { + if !contiguous_scalar.is_empty() { + match return_type { + Utf8 => new_args.push(lit(contiguous_scalar)), + LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8( + Some(contiguous_scalar), + ))), + Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some( + contiguous_scalar, + )))), + _ => return Ok(ExprSimplifyResult::Original(args)), + } + contiguous_scalar = "".to_string(); + } + new_args.push(arg.clone()); + } else { + // Convert non-string, non-array literals to their string representation + // This is needed during simplification phase which happens before coercion + // Skip NULL values (concat ignores NULLs) + if !scalar_val.is_null() { + let string_repr = format!("{scalar_val}"); + contiguous_scalar += &string_repr; + } + } } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. 
@@ -333,28 +475,30 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { arg => { if !contiguous_scalar.is_empty() { match return_type { - DataType::Utf8 => new_args.push(lit(contiguous_scalar)), - DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), - DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), - _ => unreachable!(), + Utf8 => new_args.push(lit(contiguous_scalar)), + LargeUtf8 => new_args + .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + Utf8View => new_args + .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => return Ok(ExprSimplifyResult::Original(args)), } contiguous_scalar = "".to_string(); } - new_args.push(arg); + new_args.push(arg.clone()); } } } if !contiguous_scalar.is_empty() { match return_type { - DataType::Utf8 => new_args.push(lit(contiguous_scalar)), - DataType::LargeUtf8 => { + Utf8 => new_args.push(lit(contiguous_scalar)), + LargeUtf8 => { new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) } - DataType::Utf8View => { + Utf8View => { new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) } - _ => unreachable!(), + _ => return Ok(ExprSimplifyResult::Original(args)), } } @@ -479,7 +623,7 @@ mod tests { ] .into_iter() .map(Arc::new) - .collect::>(); + .collect(); let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], @@ -501,4 +645,120 @@ mod tests { } Ok(()) } + + #[test] + fn test_concat_with_integers() -> Result<()> { + use datafusion_common::config::ConfigOptions; + + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abc".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(123))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), // NULL + ColumnarValue::Scalar(ScalarValue::Int64(Some(456))), + ]; + + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("b", Int64, true), + Field::new("c", Utf8, true), + Field::new("d", Int64, true), + ] + .into_iter() + 
.map(Arc::new) + .collect(); + + let func_args = ScalarFunctionArgs { + args, + arg_fields, + number_rows: 1, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatFunc::new().invoke_with_args(func_args)?; + + // Expected result should be "abc123456" + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + assert_eq!(s, "abc123456"); + } + _ => panic!("Expected scalar UTF8 result, got {result:?}"), + } + + Ok(()) + } + + #[test] + fn test_array_concatenation_comprehensive() -> Result<()> { + use arrow::array::{Int32Array, ListArray}; + use arrow::datatypes::{Field, Int32Type}; + use datafusion_common::config::ConfigOptions; + + // Test basic array concatenation + let arr1 = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + ])); + let arr2 = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(3), Some(4)]), + ])); + + let args = vec![ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; + + let arg_fields = vec![ + Field::new("a", List(Arc::new(Field::new("item", Int32, true))), true), + Field::new("b", List(Arc::new(Field::new("item", Int32, true))), true), + ] + .into_iter() + .map(Arc::new) + .collect(); + + let func_args = ScalarFunctionArgs { + args, + arg_fields, + number_rows: 1, + return_field: Field::new( + "result", + List(Arc::new(Field::new("item", Int32, true))), + true, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatFunc::new().invoke_with_args(func_args)?; + + match result { + ColumnarValue::Array(array) => { + let list_array = array.as_any().downcast_ref::().unwrap(); + let concatenated = list_array.value(0); + let int_array = + concatenated.as_any().downcast_ref::().unwrap(); + + assert_eq!(int_array.len(), 4); + assert_eq!(int_array.value(0), 1); + assert_eq!(int_array.value(1), 2); + assert_eq!(int_array.value(2), 3); + assert_eq!(int_array.value(3), 4); 
+ } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn test_mixed_type_error() -> Result<()> { + use arrow::datatypes::Field; + + // Test that coerce_types properly rejects mixed array/non-array types + let func = ConcatFunc::new(); + let arg_types = vec![List(Arc::new(Field::new("item", Int32, true))), Utf8]; + + let result = func.coerce_types(&arg_types); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("Cannot mix array and non-array arguments")); + + Ok(()) + } } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 5a2368a38ef9d..7a8182e007e10 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -15,13 +15,20 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::array::{ + Array, ArrayData, ArrayRef, ArrowPrimitiveType, AsArray, GenericListArray, + NullBufferBuilder, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::buffer::OffsetBuffer; use arrow::compute::try_binary; -use arrow::datatypes::{DataType, DecimalType}; +use arrow::datatypes::{DataType, DecimalType, Field}; use arrow::error::ArrowError; -use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::list_ndims; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; +use itertools::Itertools; use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given @@ -378,3 +385,201 @@ pub mod test { } } } + +/// Concatenates arrays +pub fn concat_arrays(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("concat_arrays expects at least one argument"); + } + + let mut 
all_null = true; + let mut large_list = false; + for arg in args { + match arg.data_type() { + DataType::Null => continue, + DataType::LargeList(_) => large_list = true, + _ => (), + } + if arg.null_count() < arg.len() { + all_null = false; + } + } + + if all_null { + // Return a null array with the same type as the first non-null-type argument + // Safe because args is non-empty + let return_type = args + .iter() + .map(|arg| arg.data_type()) + .find_or_first(|d| !d.is_null()) + .unwrap(); + + return Ok(arrow::array::make_array(ArrayData::new_null( + return_type, + args[0].len(), + ))); + } + + if large_list { + concat_arrays_internal::(args) + } else { + concat_arrays_internal::(args) + } +} + +fn concat_arrays_internal(args: &[ArrayRef]) -> Result { + let args = align_array_dimensions::(args.to_vec())?; + + let list_arrays = args + .iter() + .map(|arg| as_generic_list_array::(arg)) + .collect::>>()?; + + // Assume number of rows is the same for all arrays + let row_count = list_arrays[0].len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = NullBufferBuilder::new(row_count); + for i in 0..row_count { + let nulls = list_arrays + .iter() + .map(|arr| arr.is_null(i)) + .collect::>(); + + // If all the arrays are null, the concatenated array is null + let is_null = nulls.iter().all(|&x| x); + if is_null { + array_lengths.push(0); + valid.append_null(); + } else { + // Get all the arrays on i-th row + let values = list_arrays + .iter() + .map(|arr| arr.value(i)) + .collect::>(); + + let elements = values + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + // Concatenated array on i-th row + let concatenated_array = arrow::compute::concat(elements.as_slice())?; + array_lengths.push(concatenated_array.len()); + arrays.push(concatenated_array); + valid.append_non_null(); + } + } + // Assume all arrays have the same data type + let data_type = list_arrays[0].value_type(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + 
.collect::>(); + + let list_arr = GenericListArray::::new( + Arc::new(Field::new_list_field(data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(arrow::compute::concat(elements.as_slice())?), + valid.finish(), + ); + + Ok(Arc::new(list_arr)) +} + +/// Aligns array dimensions +fn align_array_dimensions( + args: Vec, +) -> Result> { + let args_ndim = args + .iter() + .map(|arg| list_ndims(arg.data_type())) + .collect::>(); + let max_ndim = args_ndim.iter().max().unwrap_or(&0); + + // Align the dimensions of the arrays + let aligned_args: Result> = args + .into_iter() + .zip(args_ndim.iter()) + .map(|(array, ndim)| { + if ndim < max_ndim { + let mut aligned_array = Arc::clone(&array); + for _ in 0..(max_ndim - ndim) { + let data_type = aligned_array.data_type().to_owned(); + let array_lengths = vec![1; aligned_array.len()]; + let offsets = OffsetBuffer::::from_lengths(array_lengths); + + let field = Arc::new(Field::new("item", data_type, true)); + let aligned_array_inner = + GenericListArray::::new(field, offsets, aligned_array, None); + aligned_array = Arc::new(aligned_array_inner); + } + Ok(aligned_array) + } else { + Ok(array) + } + }) + .collect(); + + aligned_args +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::ListArray; + use arrow::datatypes::Int64Type; + use datafusion_common::cast::as_list_array; + use datafusion_common::utils::{list_ndims, SingleRowListArrayBuilder}; + + /// Test for align_array_dimensions function + #[test] + fn test_align_array_dimensions() { + let array1d_1: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ])); + let array1d_2: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(6), Some(7), Some(8)]), + ])); + + let array2d_1: ArrayRef = Arc::new( + SingleRowListArrayBuilder::new(Arc::clone(&array1d_1)).build_list_array(), + ); + let array2d_2 = Arc::new( + 
SingleRowListArrayBuilder::new(Arc::clone(&array1d_2)).build_list_array(), + ); + + let res = align_array_dimensions::(vec![ + array1d_1.to_owned(), + array2d_2.to_owned(), + ]) + .unwrap(); + + let expected = as_list_array(&array2d_1).unwrap(); + let expected_dim = list_ndims(array2d_1.data_type()); + assert_ne!(as_list_array(&res[0]).unwrap(), expected); + assert_eq!( + list_ndims(res[0].data_type()), + expected_dim + ); + + let array3d_1: ArrayRef = + Arc::new(SingleRowListArrayBuilder::new(array2d_1).build_list_array()); + let array3d_2: ArrayRef = + Arc::new(SingleRowListArrayBuilder::new(array2d_2).build_list_array()); + let res = align_array_dimensions::(vec![array1d_1, array3d_2]).unwrap(); + + let expected = as_list_array(&array3d_1).unwrap(); + let expected_dim = list_ndims(array3d_1.data_type()); + assert_ne!(as_list_array(&res[0]).unwrap(), expected); + assert_eq!( + list_ndims(res[0].data_type()), + expected_dim + ); + } +} diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 0dcc58d5bb8ed..04aab7e3fd821 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -71,8 +71,11 @@ impl ScalarUDFImpl for SparkConcat { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Delegate to the underlying ConcatFunc for return type determination + // This allows proper handling of array concatenation + let concat_func = ConcatFunc::new(); + concat_func.return_type(arg_types) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -80,8 +83,10 @@ impl ScalarUDFImpl for SparkConcat { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - // Accept any string types, including zero arguments - Ok(arg_types.to_vec()) + // Delegate to the underlying ConcatFunc for type coercion + // This allows proper handling of array vs 
string type validation + let concat_func = ConcatFunc::new(); + concat_func.coerce_types(arg_types) } } @@ -119,7 +124,43 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // If all scalars and any is NULL, return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + // First check if we're dealing with array types by delegating to ConcatFunc + let concat_func = ConcatFunc::new(); + let return_type = concat_func.return_type( + &arg_values + .iter() + .map(|arg| arg.data_type()) + .collect::>(), + )?; + + // Return appropriate null value based on return type + return Ok(ColumnarValue::Scalar(match return_type { + DataType::List(_) => { + let null_array = arrow::array::new_null_array(&return_type, 1); + let list_array = null_array + .as_any() + .downcast_ref::() + .unwrap(); + ScalarValue::List(Arc::new(list_array.clone())) + } + DataType::LargeList(_) => { + let null_array = arrow::array::new_null_array(&return_type, 1); + let list_array = null_array + .as_any() + .downcast_ref::() + .unwrap(); + ScalarValue::LargeList(Arc::new(list_array.clone())) + } + DataType::FixedSizeList(_, _) => { + let null_array = arrow::array::new_null_array(&return_type, 1); + let list_array = null_array + .as_any() + .downcast_ref::() + .unwrap(); + ScalarValue::FixedSizeList(Arc::new(list_array.clone())) + } + _ => ScalarValue::Utf8(None), + })); } // Step 2: Delegate to DataFusion's concat @@ -198,8 +239,36 @@ fn apply_null_mask( ) -> Result { match (result, null_mask) { // Scalar with ReturnNull mask means return NULL - (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + (ColumnarValue::Scalar(scalar), NullMaskResolution::ReturnNull) => { + // Return null value with the appropriate type + let null_scalar = match scalar.data_type() { + DataType::List(_) => { + let null_array = arrow::array::new_null_array(&scalar.data_type(), 1); 
+ let list_array = null_array + .as_any() + .downcast_ref::() + .unwrap(); + ScalarValue::List(Arc::new(list_array.clone())) + } + DataType::LargeList(_) => { + let null_array = arrow::array::new_null_array(&scalar.data_type(), 1); + let list_array = null_array + .as_any() + .downcast_ref::() + .unwrap(); + ScalarValue::LargeList(Arc::new(list_array.clone())) + } + DataType::FixedSizeList(_, _) => { + let null_array = arrow::array::new_null_array(&scalar.data_type(), 1); + let list_array = null_array + .as_any() + .downcast_ref::() + .unwrap(); + ScalarValue::FixedSizeList(Arc::new(list_array.clone())) + } + _ => ScalarValue::Utf8(None), + }; + Ok(ColumnarValue::Scalar(null_scalar)) } // Scalar without mask, return as-is (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), diff --git a/datafusion/sqllogictest/test_files/concat_arrays.slt b/datafusion/sqllogictest/test_files/concat_arrays.slt new file mode 100644 index 0000000000000..da8417a8d20f6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/concat_arrays.slt @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Tests for array concatenation via concat() function +# The concat() function uses runtime delegation to array concatenation when called with array arguments + +# Basic array concatenation +query ? +SELECT concat(make_array(1,2,3), make_array(4,5)); +---- +[1, 2, 3, 4, 5] + +# Multiple array concatenation +query ? +SELECT concat(make_array(1,2), make_array(3,4), make_array(5,6)); +---- +[1, 2, 3, 4, 5, 6] + +# Array concatenation with nulls +query ? +SELECT concat(make_array(1, NULL, 3), make_array(4)); +---- +[1, NULL, 3, 4] + +# Empty array edge cases +# Note: These cases fail due to Arrow's limitation with concatenating different array types (Null vs Int64) +# This is consistent with limitations in the underlying array processing library +statement error Arrow error: Invalid argument error: It is not possible to concatenate arrays of different data types +SELECT concat(make_array(), make_array(1,2)); + +statement error Arrow error: Invalid argument error: It is not possible to concatenate arrays of different data types +SELECT concat(make_array(1,2), make_array()); + +# Test string arrays +query ? +SELECT concat(make_array('a', 'b'), make_array('c', 'd')); +---- +[a, b, c, d] + +# Test multi-row array concatenation +statement ok +CREATE TABLE array_table ( + id INT, + arr1 INT[], + arr2 INT[] +) AS VALUES + (1, make_array(1,2), make_array(3,4)), + (2, make_array(10,20), make_array(30,40)); + +query I? +SELECT id, concat(arr1, arr2) FROM array_table ORDER BY id; +---- +1 [1, 2, 3, 4] +2 [10, 20, 30, 40] + +# Mixed type rejection - should produce clear error +statement error Cannot mix array and non-array arguments in concat function +SELECT concat(make_array(1), 'x'); + +# Test single array concatenation - should work +query ? +SELECT concat(CAST(make_array(1,2) AS INT[])); +---- +[1, 2] + +# Test null array handling +query ? +SELECT concat(CAST(NULL AS BIGINT[]), make_array(1,2)); +---- +[1, 2] + +query ? 
+SELECT concat(make_array(1,2), CAST(NULL AS BIGINT[])); +---- +[1, 2] + +# Test all null arrays - expect error for now since no valid element type can be determined +statement error No valid arrays to concatenate +SELECT concat(CAST(NULL AS INT[]), CAST(NULL AS INT[])); + +# Test large arrays (performance) +query I +SELECT array_length(concat(range(0, 1000), range(1000, 2000))); +---- +2000 + +# Test different numeric types in arrays +query ? +SELECT concat(make_array(1::bigint, 2::bigint), make_array(3::bigint, 4::bigint)); +---- +[1, 2, 3, 4] + +# Test mixed types error with better error message +statement error Cannot mix array and non-array arguments in concat function +SELECT concat(make_array(1), 'hello'); + +# Test boolean arrays +query ? +SELECT concat(make_array(true, false), make_array(false, true)); +---- +[true, false, false, true] + +# Test float arrays +query ? +SELECT concat(make_array(1.5, 2.5), make_array(3.5, 4.5)); +---- +[1.5, 2.5, 3.5, 4.5] + +# Clean up +statement ok +DROP TABLE array_table; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index e15163cf6ec74..6474e8d8c2c45 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -826,12 +826,10 @@ datafusion public string_agg 1 IN expression String NULL false 1 datafusion public string_agg 2 IN delimiter String NULL false 1 datafusion public string_agg 1 OUT NULL String NULL false 1 -# test variable length arguments +# test variable length arguments - concat function with variadic_any signature may not expose parameters query TTTBI rowsort select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat'; ---- -concat String IN true 0 -concat String OUT false 0 # test ceorcion signature query TTITI rowsort diff --git 
a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt index 258cb829d7d4b..26003dab873b5 100644 --- a/datafusion/sqllogictest/test_files/spark/string/concat.slt +++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt @@ -46,3 +46,39 @@ SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, ---- abc NULL + +# Test array concatenation +query ? +SELECT concat([1, 2], [3, 4]); +---- +[1, 2, 3, 4] + +query ? +SELECT concat([1, 2], [3, 4], [5, 6]); +---- +[1, 2, 3, 4, 5, 6] + +# Test array concatenation with nulls - Spark returns NULL if any argument is NULL +query ? +SELECT concat([1, 2], NULL, [3, 4]); +---- +NULL + +# Test array concatenation with empty arrays - Arrow limitation with Null vs Int64 types +statement error Arrow error: Invalid argument error: It is not possible to concatenate arrays of different data types +SELECT concat([], [1, 2]); + +statement error Arrow error: Invalid argument error: It is not possible to concatenate arrays of different data types +SELECT concat([1, 2], []); + +# Test concatenation of all null arrays - Spark returns NULL +query ? +SELECT concat(CAST(NULL AS INT[]), CAST(NULL AS INT[])); +---- +NULL + +# Test string arrays +query ? +SELECT concat(['a', 'b'], ['c', 'd']); +---- +[a, b, c, d] diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 7c88d1fd9c3eb..5c89603fe16b1 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1332,16 +1332,16 @@ chr(expression) ### `concat` -Concatenates multiple strings together. +Concatenates multiple strings or arrays together. ```sql -concat(str[, ..., str_n]) +concat(str[, ..., str_n]) or concat(array[, ..., array_n]) ``` #### Arguments -- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. 
-- **str_n**: Subsequent string expressions to concatenate. +- **str_or_array**: String or Array expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_or_array_n**: Subsequent string or array expressions to concatenate. Cannot mix strings and arrays. #### Example @@ -1352,6 +1352,12 @@ concat(str[, ..., str_n]) +-------------------------------------------------------+ | datafusion | +-------------------------------------------------------+ +> select concat(make_array(1, 2), make_array(3, 4)); ++--------------------------------------------+ +| concat(make_array(1, 2), make_array(3, 4)) | ++--------------------------------------------+ +| [1, 2, 3, 4]                               | ++--------------------------------------------+ ``` **Related functions**: