diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 7a122506e67..d5d1b3ffd45 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -59,6 +59,7 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - String functions - [x] Length - [x] Concatenate + - [x] Split - Miscellaneous/Boolean functions - [x] nullif - Common date/time functions diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 63915381a7b..3cb716c2553 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -848,6 +848,14 @@ pub fn length(e: Expr) -> Expr { } } +/// returns the provided index of a split string +pub fn split(args: Vec) -> Expr { + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::SplitPart, + args, + } +} + /// returns the concatenation of string expressions pub fn concat(args: Vec) -> Expr { Expr::ScalarFunction { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index fbad5e26606..42af096ea1f 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,8 +37,8 @@ pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, combine_filters, concat, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max, - md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, - tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, + md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, split, sqrt, + sum, tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index c9c2cde2a9e..a37438dd973 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -119,6 +119,8 @@ pub enum BuiltinScalarFunction { Length, /// concat Concat, + /// split + SplitPart, /// lower Lower, /// upper @@ -181,6 +183,7 @@ impl FromStr for BuiltinScalarFunction { "char_length" => BuiltinScalarFunction::Length, "character_length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, + "split_part" => BuiltinScalarFunction::SplitPart, "lower" => BuiltinScalarFunction::Lower, "trim" => BuiltinScalarFunction::Trim, "ltrim" => BuiltinScalarFunction::Ltrim, @@ -238,6 +241,7 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::SplitPart => Ok(DataType::Utf8), BuiltinScalarFunction::Concat => Ok(DataType::Utf8), BuiltinScalarFunction::Lower => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, @@ -428,6 +432,9 @@ pub fn create_physical_expr( BuiltinScalarFunction::Concat => { |args| Ok(Arc::new(string_expressions::concatenate(args)?)) } + BuiltinScalarFunction::SplitPart => { + |args| Ok(Arc::new(string_expressions::split_part(args)?)) + } BuiltinScalarFunction::Lower => |args| match args[0].data_type() { DataType::Utf8 => Ok(Arc::new(string_expressions::lower::(args)?)), DataType::LargeUtf8 => Ok(Arc::new(string_expressions::lower::(args)?)), @@ -499,6 +506,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { // for now, the list is small, as we do not have many built-in functions. match fun { BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), + BuiltinScalarFunction::SplitPart => { + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64]) + } BuiltinScalarFunction::Upper | BuiltinScalarFunction::Lower | BuiltinScalarFunction::Length diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index c633aa874f4..b3641f815a8 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -19,7 +19,7 @@ use crate::error::{DataFusionError, Result}; use arrow::array::{ - Array, ArrayRef, GenericStringArray, StringArray, StringBuilder, + Array, ArrayRef, GenericStringArray, Int64Array, StringArray, StringBuilder, StringOffsetSizeTrait, }; @@ -34,6 +34,58 @@ macro_rules! downcast_vec { }}; } +/// split string columns, return value at index +pub fn split_part(args: &[ArrayRef]) -> Result { + let haystack = &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast split_part input to StringArray".to_string(), + ) + })?; + + let needle = &args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast split_part input to StringArray".to_string(), + ) + })?; + + let part = &args[2] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "could not cast split_part input to Int64Array".to_string(), + ) + })?; + + let mut builder = StringBuilder::new(args.len()); + + for index in 0..args[0].len() { + if haystack.is_null(index) || needle.is_null(index) || part.is_null(index) { + builder.append_null()?; + continue; + } + + let hs = haystack.value(index); + let ndl = needle.value(index); + + // rust is 0 indexed, PostgreSQL is 1 indexed + let pnum = part.value(index) - 1; + + match hs.split(ndl).nth(pnum as usize) { + Some(i) => builder.append_value(i)?, + None => builder.append_value("")?, + }; + } + + Ok(builder.finish()) +} + /// concatenate string columns together. pub fn concatenate(args: &[ArrayRef]) -> Result { // downcast all arguments to strings diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 0c2ff6c863f..63bddb5b6e5 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -1554,6 +1554,31 @@ async fn query_large_length() -> Result<()> { generic_query_length::(DataType::LargeUtf8).await } +#[tokio::test] +async fn query_split_part() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec![ + "hello-world", + "a", + "-", + "---", + ]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + let sql = "SELECT split_part(c1,'-',1) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["hello"], vec!["a"], vec![""], vec![""]]; + assert_eq!(expected, actual); + Ok(()) +} + #[tokio::test] async fn query_not() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)]));