diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index c0073c03b81..65cf30832e2 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -216,6 +216,7 @@ pub use self::builder::BooleanBuilder; pub use self::builder::DecimalBuilder; pub use self::builder::FixedSizeBinaryBuilder; pub use self::builder::FixedSizeListBuilder; +pub use self::builder::GenericStringBuilder; pub use self::builder::LargeBinaryBuilder; pub use self::builder::LargeListBuilder; pub use self::builder::LargeStringBuilder; diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index a8d24979e04..862f55fe2f2 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -28,6 +28,7 @@ pub mod concat; pub mod filter; pub mod length; pub mod limit; +pub mod regexp; pub mod sort; pub mod substring; pub mod take; diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs new file mode 100644 index 00000000000..446d71d9f4a --- /dev/null +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines kernel to extract substrings based on a regular +//! expression of a \[Large\]StringArray + +use crate::array::{ + ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder, + StringOffsetSizeTrait, +}; +use crate::error::{ArrowError, Result}; +use std::collections::HashMap; + +use std::sync::Arc; + +use regex::Regex; + +/// Extract all groups matched by a regular expression for a given String array. +pub fn regexp_match( + array: &GenericStringArray, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, +) -> Result { + let mut patterns: HashMap = HashMap::new(); + let builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut list_builder = ListBuilder::new(builder); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{}){}", value, pattern), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == *"" => { + list_builder.values().append_value("")?; + list_builder.append(true)?; + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re.clone(), + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {:?}", + e + )) + })?; + patterns.insert(pattern, re.clone()); + re + } + }; + match re.captures(value) { + Some(caps) => { + for m in caps.iter().skip(1) { + if let Some(v) = m { + list_builder.values().append_value(v.as_str())?; + } + } + list_builder.append(true)? + } + None => list_builder.append(false)?, + } + } + _ => list_builder.append(false)?, + } + Ok(()) + }) + .collect::>>()?; + Ok(Arc::new(list_builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::{ListArray, StringArray}; + + #[test] + fn match_single_group() -> Result<()> { + let values = vec![ + Some("abc-005-def"), + Some("X-7-5"), + Some("X545"), + None, + Some("foobarbequebaz"), + Some("foobarbequebaz"), + ]; + let array = StringArray::from(values); + let mut pattern_values = vec![r".*-(\d*)-.*"; 4]; + pattern_values.push(r"(bar)(bequ1e)"); + pattern_values.push(""); + let pattern = StringArray::from(pattern_values); + let actual = regexp_match(&array, &pattern, None)?; + let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.values().append_value("005")?; + expected_builder.append(true)?; + expected_builder.values().append_value("7")?; + expected_builder.append(true)?; + expected_builder.append(false)?; + expected_builder.append(false)?; + expected_builder.append(false)?; + expected_builder.values().append_value("")?; + expected_builder.append(true)?; + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn match_single_group_with_flags() -> Result<()> { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]); + let flags = StringArray::from(vec!["i"; 4]); + let actual = regexp_match(&array, &pattern, Some(&flags))?; + let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.append(false)?; + expected_builder.values().append_value("7")?; + expected_builder.append(true)?; + expected_builder.append(false)?; + expected_builder.append(false)?; + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + Ok(()) + } +} diff --git a/rust/arrow/src/compute/mod.rs b/rust/arrow/src/compute/mod.rs index 9de07388e9c..be1aa277ca4 100644 --- a/rust/arrow/src/compute/mod.rs +++ b/rust/arrow/src/compute/mod.rs @@ -29,6 +29,7 @@ pub use self::kernels::comparison::*; pub use self::kernels::concat::*; pub use self::kernels::filter::*; pub use self::kernels::limit::*; +pub use self::kernels::regexp::*; pub use self::kernels::sort::*; pub use self::kernels::take::*; pub use self::kernels::temporal::*; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 314f5d477b3..991b16058b1 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1090,6 +1090,7 @@ unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(RegexpMatch, regexp_match); unary_scalar_expr!(RegexpReplace, regexp_replace); unary_scalar_expr!(Replace, replace); unary_scalar_expr!(Repeat, repeat); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 0e7e61981b1..f9be1ff9830 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,10 +37,10 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, - strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr, - ExprRewriter, ExpressionVisitor, Literal, Recursion, + octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, + starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, + Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 9dc54a4113f..56365fec1dc 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -198,6 +198,8 @@ pub enum BuiltinScalarFunction { Trim, /// upper Upper, + /// regexp_match + RegexpMatch, } impl fmt::Display for BuiltinScalarFunction { @@ -271,7 +273,7 @@ impl FromStr for BuiltinScalarFunction { "translate" => BuiltinScalarFunction::Translate, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, - + "regexp_match" => BuiltinScalarFunction::RegexpMatch, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -607,6 +609,20 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] { + DataType::LargeUtf8 => { + DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true))) + } + DataType::Utf8 => { + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + } + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The regexp_extract function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos @@ -853,6 +869,28 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { + DataType::Utf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args) + } + DataType::LargeUtf8 => { + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other + ))), + }, BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_regex_expressions_feature_flag!( @@ -1229,6 +1267,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } + BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]), + ]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -1386,7 +1430,7 @@ mod tests { use arrow::{ array::{ Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array, - Int32Array, StringArray, UInt32Array, UInt64Array, + Int32Array, ListArray, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -3646,4 +3690,78 @@ mod tests { "PrimitiveArray\n[\n 1,\n 1,\n]", ) } + + #[test] + #[cfg(feature = "regex_expressions")] + fn test_regexp_match() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + // concat(value, value) + let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"])); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let columns: Vec = vec![col_value]; + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpMatch, + &[col("a"), pattern], + &schema, + )?; + + // type is correct + assert_eq!( + expr.data_type(&schema)?, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + ); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::().unwrap(); + let first_row = result.value(0); + let first_row = first_row.as_any().downcast_ref::().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(first_row.value(0), expected); + + Ok(()) + } + + #[test] + #[cfg(feature = "regex_expressions")] + fn test_regexp_match_all_literals() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + // concat(value, value) + let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpMatch, + &[col_value, pattern], + &schema, + )?; + + // type is correct + assert_eq!( + expr.data_type(&schema)?, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + ); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::().unwrap(); + let first_row = result.value(0); + let first_row = first_row.as_any().downcast_ref::().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(first_row.value(0), expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs index 6482424e105..b526e7259ef 100644 --- a/rust/datafusion/src/physical_plan/regex_expressions.rs +++ b/rust/datafusion/src/physical_plan/regex_expressions.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait}; +use arrow::compute; use hashbrown::HashMap; use regex::Regex; @@ -43,6 +44,20 @@ macro_rules! downcast_string_arg { }}; } +/// extract a specific group from a string column, using a regular expression +pub fn regexp_match(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None) + .map_err(DataFusionError::ArrowError), + 3 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), Some(downcast_string_arg!(args[1], "flags", T))) + .map_err(DataFusionError::ArrowError), + other => Err(DataFusionError::Internal(format!( + "regexp_match was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } +} + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index f0c7acfce3e..b2367758493 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -115,7 +115,7 @@ macro_rules! build_list { for scalar_value in values { match scalar_value { ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(*v).unwrap() + builder.values().append_value(v.clone()).unwrap() } ScalarValue::$SCALAR_TY(None) => { builder.values().append_null().unwrap(); @@ -335,6 +335,10 @@ impl ScalarValue { DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), + DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), + DataType::LargeUtf8 => { + build_list!(LargeStringBuilder, LargeUtf8, values, size) + } _ => panic!("Unexpected DataType for list"), }), ScalarValue::Date32(e) => match e { diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 0dc79e91803..6a287e0b1bb 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2560,6 +2560,17 @@ async fn test_in_list_scalar() -> Result<()> { test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); + test_expression!("regexp_match('foobarbequebaz', '')", "[]"); + test_expression!( + "regexp_match('foobarbequebaz', '(bar)(beque)')", + "[bar, beque]" + ); + test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL"); + test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]"); + test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]"); + test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL"); + test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL"); + test_expression!("regexp_match('aaa-0', NULL)", "NULL"); Ok(()) }