Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/datafusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI
- String functions
- [x] Length
- [x] Concatenate
- [x] Split
- Miscellaneous/Boolean functions
- [x] nullif
- Common date/time functions
Expand Down
8 changes: 8 additions & 0 deletions rust/datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,14 @@ pub fn length(e: Expr) -> Expr {
}
}

/// returns the provided index of a split string
pub fn split(args: Vec<Expr>) -> Expr {
Expr::ScalarFunction {
fun: functions::BuiltinScalarFunction::SplitPart,
args,
}
}

/// returns the concatenation of string expressions
pub fn concat(args: Vec<Expr>) -> Expr {
Expr::ScalarFunction {
Expand Down
4 changes: 2 additions & 2 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ pub use expr::{
abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col,
combine_filters, concat, cos, count, count_distinct, create_udaf, create_udf, exp,
exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max,
md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, sum,
tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion,
md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, split, sqrt,
sum, tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
10 changes: 10 additions & 0 deletions rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ pub enum BuiltinScalarFunction {
Length,
/// concat
Concat,
/// split
SplitPart,
/// lower
Lower,
/// upper
Expand Down Expand Up @@ -181,6 +183,7 @@ impl FromStr for BuiltinScalarFunction {
"char_length" => BuiltinScalarFunction::Length,
"character_length" => BuiltinScalarFunction::Length,
"concat" => BuiltinScalarFunction::Concat,
"split_part" => BuiltinScalarFunction::SplitPart,
"lower" => BuiltinScalarFunction::Lower,
"trim" => BuiltinScalarFunction::Trim,
"ltrim" => BuiltinScalarFunction::Ltrim,
Expand Down Expand Up @@ -238,6 +241,7 @@ pub fn return_type(
));
}
}),
BuiltinScalarFunction::SplitPart => Ok(DataType::Utf8),
BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
BuiltinScalarFunction::Lower => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::LargeUtf8,
Expand Down Expand Up @@ -428,6 +432,9 @@ pub fn create_physical_expr(
BuiltinScalarFunction::Concat => {
|args| Ok(Arc::new(string_expressions::concatenate(args)?))
}
BuiltinScalarFunction::SplitPart => {
|args| Ok(Arc::new(string_expressions::split_part(args)?))
}
BuiltinScalarFunction::Lower => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(string_expressions::lower::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(string_expressions::lower::<i64>(args)?)),
Expand Down Expand Up @@ -499,6 +506,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
// for now, the list is small, as we do not have many built-in functions.
match fun {
BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::SplitPart => {
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64])
}
BuiltinScalarFunction::Upper
| BuiltinScalarFunction::Lower
| BuiltinScalarFunction::Length
Expand Down
54 changes: 53 additions & 1 deletion rust/datafusion/src/physical_plan/string_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::error::{DataFusionError, Result};
use arrow::array::{
Array, ArrayRef, GenericStringArray, StringArray, StringBuilder,
Array, ArrayRef, GenericStringArray, Int64Array, StringArray, StringBuilder,
StringOffsetSizeTrait,
};

Expand All @@ -34,6 +34,58 @@ macro_rules! downcast_vec {
}};
}

/// split string columns, return value at index
pub fn split_part(args: &[ArrayRef]) -> Result<StringArray> {
let haystack = &args[0]
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
DataFusionError::Internal(
"could not cast split_part input to StringArray".to_string(),
)
})?;

let needle = &args[1]
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
DataFusionError::Internal(
"could not cast split_part input to StringArray".to_string(),
)
})?;

let part = &args[2]
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| {
DataFusionError::Internal(
"could not cast split_part input to Int64Array".to_string(),
)
})?;

let mut builder = StringBuilder::new(args.len());

for index in 0..args[0].len() {
if haystack.is_null(index) || needle.is_null(index) || part.is_null(index) {
builder.append_null()?;
continue;
}

let hs = haystack.value(index);
let ndl = needle.value(index);

// rust is 0 indexed, PostgreSQL is 1 indexed
let pnum = part.value(index) - 1;

match hs.split(ndl).nth(pnum as usize) {
Some(i) => builder.append_value(i)?,
None => builder.append_value("")?,
};
}

Ok(builder.finish())
}

/// concatenate string columns together.
pub fn concatenate(args: &[ArrayRef]) -> Result<StringArray> {
// downcast all arguments to strings
Expand Down
25 changes: 25 additions & 0 deletions rust/datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,31 @@ async fn query_large_length() -> Result<()> {
generic_query_length::<LargeStringArray>(DataType::LargeUtf8).await
}

#[tokio::test]
async fn query_split_part() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)]));

let data = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(StringArray::from(vec![
"hello-world",
"a",
"-",
"---",
]))],
)?;

let table = MemTable::try_new(schema, vec![vec![data]])?;

let mut ctx = ExecutionContext::new();
ctx.register_table("test", Box::new(table));
let sql = "SELECT split_part(c1,'-',1) FROM test";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["hello"], vec!["a"], vec![""], vec![""]];
assert_eq!(expected, actual);
Ok(())
}

#[tokio::test]
async fn query_not() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)]));
Expand Down