diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 0fffd84b7047a..2398f8154311d 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -131,7 +131,7 @@ pub enum TypeSignature { Numeric(usize), /// Fixed number of arguments of all the same string types. /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. - /// Null is considerd as `Utf8` by default + /// Null is considered as `Utf8` by default /// Dictionary with string value type is also handled. String(usize), /// Zero argument diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index ce555c36274e9..da4ab2bed49af 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -32,44 +32,26 @@ use std::any::{type_name, Any}; use crate::utils::{downcast_arg, make_scalar_function}; use arrow::compute::cast; +use arrow_array::builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder}; +use arrow_array::cast::AsArray; +use arrow_array::{GenericStringArray, StringViewArray}; use arrow_schema::DataType::{ - Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, -}; -use datafusion_common::cast::{ - as_generic_string_array, as_large_list_array, as_list_array, as_string_array, + Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; +use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::exec_err; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_functions::strings::StringArrayType; use std::sync::{Arc, OnceLock}; -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - macro_rules! call_array_function { ($DATATYPE:expr, false) => { match $DATATYPE { DataType::Utf8 => array_function!(StringArray), + DataType::Utf8View => array_function!(StringViewArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), DataType::Float32 => array_function!(Float32Array), @@ -89,6 +71,7 @@ macro_rules! call_array_function { match $DATATYPE { DataType::List(_) => array_function!(ListArray), DataType::Utf8 => array_function!(StringArray), + DataType::Utf8View => array_function!(StringViewArray), DataType::LargeUtf8 => array_function!(LargeStringArray), DataType::Boolean => array_function!(BooleanArray), DataType::Float32 => array_function!(Float32Array), @@ -106,6 +89,27 @@ macro_rules! call_array_function { }}; } +macro_rules! to_string { + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + for x in arr { + match x { + Some(x) => { + $ARG.push_str(&x.to_string()); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } + } + } + } + Ok($ARG) + }}; +} + // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( ArrayToString, @@ -222,10 +226,7 @@ impl StringToArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![ - TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), - TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), aliases: vec![String::from("string_to_list")], @@ -248,12 +249,12 @@ impl ScalarUDFImpl for StringToArray { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { + Utf8 | Utf8View | LargeUtf8 => { List(Arc::new(Field::new("item", arg_types[0].clone(), true))) } _ => { return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." + "The string_to_array function can only accept Utf8, Utf8View or LargeUtf8." ); } }) @@ -261,10 +262,10 @@ impl ScalarUDFImpl for StringToArray { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - Utf8 => make_scalar_function(string_to_array_inner::)(args), + Utf8 | Utf8View => make_scalar_function(string_to_array_inner::)(args), LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), other => { - exec_err!("unsupported type for string_to_array function as {other}") + exec_err!("unsupported type for string_to_array function as {other:?}") } } } @@ -329,13 +330,22 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { let arr = &args[0]; - let delimiters = as_string_array(&args[1])?; - let delimiters: Vec> = delimiters.iter().collect(); + let delimiters: Vec> = match args[1].data_type() { + Utf8 => args[1].as_string::().iter().collect(), + Utf8View => args[1].as_string_view().iter().collect(), + LargeUtf8 => args[1].as_string::().iter().collect(), + other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") + }; let mut null_string = String::from(""); let mut with_null_string = false; if args.len() == 3 { - null_string = as_string_array(&args[2])?.value(0).to_string(); + null_string = match args[2].data_type() { + Utf8 => args[2].as_string::().value(0).to_string(), + Utf8View => args[2].as_string_view().value(0).to_string(), + LargeUtf8 => args[2].as_string::().value(0).to_string(), + other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") + }; with_null_string = true; } @@ -495,20 +505,173 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { /// String_to_array SQL function /// Splits string at occurrences of delimiter and returns an array of parts /// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { +fn string_to_array_inner(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("string_to_array expects two or three arguments"); } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - )); + match args[0].data_type() { + Utf8 => { + let string_array = args[0].as_string::(); + let builder = StringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); + string_to_array_inner_2::<&GenericStringArray, StringBuilder>(args, string_array, builder) + } + Utf8View => { + let string_array = args[0].as_string_view(); + let builder = StringViewBuilder::with_capacity(string_array.len()); + string_to_array_inner_2::<&StringViewArray, StringViewBuilder>(args, string_array, builder) + } + LargeUtf8 => { + let string_array = args[0].as_string::(); + let builder = LargeStringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); + string_to_array_inner_2::<&GenericStringArray, LargeStringBuilder>(args, string_array, builder) + } + other => exec_err!("unsupported type for first argument to string_to_array function as {other:?}") + } +} + +fn string_to_array_inner_2<'a, StringArrType, StringBuilderType>( + args: &'a [ArrayRef], + string_array: StringArrType, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + match args[1].data_type() { + Utf8 => { + let delimiter_array = args[1].as_string::(); + if args.len() == 2 { + string_to_array_impl::< + StringArrType, + &GenericStringArray, + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::, + StringBuilderType>(args, string_array, delimiter_array, string_builder) + } + } + Utf8View => { + let delimiter_array = args[1].as_string_view(); + + if args.len() == 2 { + string_to_array_impl::< + StringArrType, + &StringViewArray, + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::(args, string_array, delimiter_array, string_builder) + } + } + LargeUtf8 => { + let delimiter_array = args[1].as_string::(); + if args.len() == 2 { + string_to_array_impl::< + StringArrType, + &GenericStringArray, + &StringViewArray, + StringBuilderType, + >(string_array, delimiter_array, None, string_builder) + } else { + string_to_array_inner_3::, + StringBuilderType>(args, string_array, delimiter_array, string_builder) + } + } + other => exec_err!("unsupported type for second argument to string_to_array function as {other:?}") + } +} - match args.len() { - 2 => { +fn string_to_array_inner_3<'a, StringArrType, DelimiterArrType, StringBuilderType>( + args: &'a [ArrayRef], + string_array: StringArrType, + delimiter_array: DelimiterArrType, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + match args[2].data_type() { + Utf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &GenericStringArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + Utf8View => { + let null_type_array = Some(args[2].as_string_view()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &StringViewArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + LargeUtf8 => { + let null_type_array = Some(args[2].as_string::()); + string_to_array_impl::< + StringArrType, + DelimiterArrType, + &GenericStringArray, + StringBuilderType, + >( + string_array, + delimiter_array, + null_type_array, + string_builder, + ) + } + other => { + exec_err!("unsupported type for string_to_array function as {other:?}") + } + } +} + +fn string_to_array_impl< + 'a, + StringArrType, + DelimiterArrType, + NullValueArrType, + StringBuilderType, +>( + string_array: StringArrType, + delimiter_array: DelimiterArrType, + null_value_array: Option, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + NullValueArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + let mut list_builder = ListBuilder::new(string_builder); + + match null_value_array { + None => { string_array.iter().zip(delimiter_array.iter()).for_each( |(string, delimiter)| { match (string, delimiter) { @@ -524,63 +687,90 @@ pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c); + list_builder.values().append_value(c.as_str()); }); list_builder.append(true); } _ => list_builder.append(false), // null value } }, - ); + ) } - - 3 => { - let null_value_array = as_generic_string_array::(&args[2])?; - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { + Some(null_value_array) => string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { list_builder.values().append_null(); } else { - list_builder.values().append_value(string); + list_builder.values().append_value(s); } - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(c); - } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value + }); + list_builder.append(true); } - }); - } - _ => { - return exec_err!( - "Expect string_to_array function to take two or three parameters" - ) - } - } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c.as_str()); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }), + }; let list_array = list_builder.finish(); Ok(Arc::new(list_array) as ArrayRef) } + +trait StringArrayBuilderType: ArrayBuilder { + fn append_value(&mut self, val: &str); + + fn append_null(&mut self); +} + +impl StringArrayBuilderType for StringBuilder { + fn append_value(&mut self, val: &str) { + StringBuilder::append_value(self, val); + } + + fn append_null(&mut self) { + StringBuilder::append_null(self); + } +} + +impl StringArrayBuilderType for StringViewBuilder { + fn append_value(&mut self, val: &str) { + StringViewBuilder::append_value(self, val) + } + + fn append_null(&mut self) { + StringViewBuilder::append_null(self) + } +} + +impl StringArrayBuilderType for LargeStringBuilder { + fn append_value(&mut self, val: &str) { + LargeStringBuilder::append_value(self, val); + } + + fn append_null(&mut self) { + LargeStringBuilder::append_null(self); + } +} diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1e60699a1f653..da3a53dc07c3a 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3973,6 +3973,30 @@ ORDER BY column1; 3 [bar] bar NULL [baz] baz +# verify make_array does force to Utf8View +query T +SELECT arrow_typeof(make_array(arrow_cast('a', 'Utf8View'), 'b', 'c', 'd')); +---- +List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# expect a,b,c,d. make_array forces all types to be of a common type (see above) +query T +SELECT array_to_string(make_array(arrow_cast('a', 'Utf8View'), 'b', 'c', 'd'), ','); +---- +a,b,c,d + +# array_to_string using largeutf8 for second arg +query TTT +select array_to_string(['h', 'e', 'l', 'l', 'o'], arrow_cast(',', 'LargeUtf8')), array_to_string([1, 2, 3, 4, 5], arrow_cast('-', 'LargeUtf8')), array_to_string([1.0, 2.0, 3.0], arrow_cast('|', 'LargeUtf8')); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + +# array_to_string using utf8view for second arg +query TTT +select array_to_string(['h', 'e', 'l', 'l', 'o'], arrow_cast(',', 'Utf8View')), array_to_string([1, 2, 3, 4, 5], arrow_cast('-', 'Utf8View')), array_to_string([1.0, 2.0, 3.0], arrow_cast('|', 'Utf8View')); +---- +h,e,l,l,o 1-2-3-4-5 1|2|3 + statement ok drop table table1; @@ -6916,6 +6940,79 @@ select string_to_array(e, ',') from values; [adipiscing] NULL +# karge string tests for string_to_array + +# string_to_array scalar function +query ? +SELECT string_to_array(arrow_cast('abcxxxdef', 'LargeUtf8'), 'xxx') +---- +[abc, def] + +# string_to_array scalar function +query ? +SELECT string_to_array(arrow_cast('abcxxxdef', 'LargeUtf8'), arrow_cast('xxx', 'LargeUtf8')) +---- +[abc, def] + +query ? +SELECT string_to_array(arrow_cast('abc', 'LargeUtf8'), NULL) +---- +[a, b, c] + +query ? +select string_to_array(arrow_cast(e, 'LargeUtf8'), ',') from values; +---- +[Lorem] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +query ? +select string_to_array(arrow_cast(e, 'LargeUtf8'), ',', arrow_cast('Lorem', 'LargeUtf8')) from values; +---- +[] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +# string view tests for string_to_array + +# string_to_array scalar function +query ? +SELECT string_to_array(arrow_cast('abcxxxdef', 'Utf8View'), 'xxx') +---- +[abc, def] + +query ? +SELECT string_to_array(arrow_cast('abc', 'Utf8View'), NULL) +---- +[a, b, c] + +query ? +select string_to_array(arrow_cast(e, 'Utf8View'), ',') from values; +---- +[Lorem] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +# test string_to_array aliases + query ? select string_to_list(e, 'm') from values; ---- diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 5a08f3f5447a5..98ba8181397cb 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1023,6 +1023,27 @@ logical_plan 01)Projection: digest(test.column1_utf8view, Utf8View("md5")) AS c 02)--TableScan: test projection=[column1_utf8view] +## Ensure no unexpected casts for string_to_array +query TT +EXPLAIN SELECT + string_to_array(column1_utf8view, ',') AS c +FROM test; +---- +logical_plan +01)Projection: string_to_array(test.column1_utf8view, Utf8View(",")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no unexpected casts for array_to_string +query TT +EXPLAIN SELECT + array_to_string(string_to_array(column1_utf8view, NULL), ',') AS c +FROM test; +---- +logical_plan +01)Projection: array_to_string(string_to_array(test.column1_utf8view, Utf8View(NULL)), Utf8(",")) AS c +02)--TableScan: test projection=[column1_utf8view] + + ## Ensure no casts for binary operators # `~` operator (regex match) query TT