Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ pub use self::builder::BooleanBuilder;
pub use self::builder::DecimalBuilder;
pub use self::builder::FixedSizeBinaryBuilder;
pub use self::builder::FixedSizeListBuilder;
pub use self::builder::GenericStringBuilder;
pub use self::builder::LargeBinaryBuilder;
pub use self::builder::LargeListBuilder;
pub use self::builder::LargeStringBuilder;
Expand Down
1 change: 1 addition & 0 deletions rust/arrow/src/compute/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod concat;
pub mod filter;
pub mod length;
pub mod limit;
pub mod regexp;
pub mod sort;
pub mod substring;
pub mod take;
Expand Down
160 changes: 160 additions & 0 deletions rust/arrow/src/compute/kernels/regexp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines kernel to extract substrings based on a regular
//! expression of a \[Large\]StringArray

use crate::array::{
ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder,
StringOffsetSizeTrait,
};
use crate::error::{ArrowError, Result};
use std::collections::HashMap;

use std::sync::Arc;

use regex::Regex;

/// Extract all groups matched by a regular expression for a given String array.
pub fn regexp_match<OffsetSize: StringOffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
regex_array: &GenericStringArray<OffsetSize>,
flags_array: Option<&GenericStringArray<OffsetSize>>,
) -> Result<ArrayRef> {
let mut patterns: HashMap<String, Regex> = HashMap::new();
let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::new(0);
let mut list_builder = ListBuilder::new(builder);

let complete_pattern = match flags_array {
Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
|(pattern, flags)| {
pattern.map(|pattern| match flags {
Some(value) => format!("(?{}){}", value, pattern),
None => pattern.to_string(),
})
},
)) as Box<dyn Iterator<Item = Option<String>>>,
None => Box::new(
regex_array
.iter()
.map(|pattern| pattern.map(|pattern| pattern.to_string())),
),
};
array
.iter()
.zip(complete_pattern)
.map(|(value, pattern)| {
match (value, pattern) {
// Required for Postgres compatibility:
// SELECT regexp_match('foobarbequebaz', ''); = {""}
(Some(_), Some(pattern)) if pattern == *"" => {
list_builder.values().append_value("")?;
list_builder.append(true)?;
}
(Some(value), Some(pattern)) => {
let existing_pattern = patterns.get(&pattern);
let re = match existing_pattern {
Some(re) => re.clone(),
None => {
let re = Regex::new(pattern.as_str()).map_err(|e| {
ArrowError::ComputeError(format!(
"Regular expression did not compile: {:?}",
e
))
})?;
patterns.insert(pattern, re.clone());
re
}
};
match re.captures(value) {
Some(caps) => {
for m in caps.iter().skip(1) {
if let Some(v) = m {
list_builder.values().append_value(v.as_str())?;
}
}
list_builder.append(true)?
}
None => list_builder.append(false)?,
}
}
_ => list_builder.append(false)?,
}
Ok(())
})
.collect::<Result<Vec<()>>>()?;
Ok(Arc::new(list_builder.finish()))
}

#[cfg(test)]
mod tests {
use super::*;
use crate::array::{ListArray, StringArray};

#[test]
fn match_single_group() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a test case for and (regexp_match('foobarbequebaz', '(bar)(bequ1e)') above):

SELECT regexp_match('foobarbequebaz', ''); = {""}

Some of these behaviors from Postgres don't really make sense to me.

Copy link
Contributor Author

@sweb sweb Mar 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seddonm1 First: I am very impressed that you know of this case.

My original implementation returned an empty List, without an item. Do you know whether Postgres actually returns a quoted empty string? I am asking because

SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); => {bar,beque}

so I am not sure what to make of the quotes, since strings are not returned with quotes or is this just a special case when the string is empty?

Regardless, I added special case for the empty string pattern

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know about this case. I just put a few scenarios into a Postgres instance running locally (via docker). Your implementation does make sense.

let values = vec![
Some("abc-005-def"),
Some("X-7-5"),
Some("X545"),
None,
Some("foobarbequebaz"),
Some("foobarbequebaz"),
];
let array = StringArray::from(values);
let mut pattern_values = vec![r".*-(\d*)-.*"; 4];
pattern_values.push(r"(bar)(bequ1e)");
pattern_values.push("");
let pattern = StringArray::from(pattern_values);
let actual = regexp_match(&array, &pattern, None)?;
let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new(0);
let mut expected_builder = ListBuilder::new(elem_builder);
expected_builder.values().append_value("005")?;
expected_builder.append(true)?;
expected_builder.values().append_value("7")?;
expected_builder.append(true)?;
expected_builder.append(false)?;
expected_builder.append(false)?;
expected_builder.append(false)?;
expected_builder.values().append_value("")?;
expected_builder.append(true)?;
let expected = expected_builder.finish();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);
Ok(())
}

#[test]
fn match_single_group_with_flags() -> Result<()> {
let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None];
let array = StringArray::from(values);
let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]);
let flags = StringArray::from(vec!["i"; 4]);
let actual = regexp_match(&array, &pattern, Some(&flags))?;
let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new(0);
let mut expected_builder = ListBuilder::new(elem_builder);
expected_builder.append(false)?;
expected_builder.values().append_value("7")?;
expected_builder.append(true)?;
expected_builder.append(false)?;
expected_builder.append(false)?;
let expected = expected_builder.finish();
let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(&expected, result);
Ok(())
}
}
1 change: 1 addition & 0 deletions rust/arrow/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub use self::kernels::comparison::*;
pub use self::kernels::concat::*;
pub use self::kernels::filter::*;
pub use self::kernels::limit::*;
pub use self::kernels::regexp::*;
pub use self::kernels::sort::*;
pub use self::kernels::take::*;
pub use self::kernels::temporal::*;
Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ unary_scalar_expr!(Lpad, lpad);
unary_scalar_expr!(Ltrim, ltrim);
unary_scalar_expr!(MD5, md5);
unary_scalar_expr!(OctetLength, octet_length);
unary_scalar_expr!(RegexpMatch, regexp_match);
unary_scalar_expr!(RegexpReplace, regexp_replace);
unary_scalar_expr!(Replace, replace);
unary_scalar_expr!(Repeat, repeat);
Expand Down
8 changes: 4 additions & 4 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ pub use expr::{
ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count,
count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
octet_length, or, regexp_replace, repeat, replace, reverse, right, round, rpad,
rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with,
strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr,
ExprRewriter, ExpressionVisitor, Literal, Recursion,
octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right,
round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when,
Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
122 changes: 120 additions & 2 deletions rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ pub enum BuiltinScalarFunction {
Trim,
/// upper
Upper,
/// regexp_match
RegexpMatch,
}

impl fmt::Display for BuiltinScalarFunction {
Expand Down Expand Up @@ -271,7 +273,7 @@ impl FromStr for BuiltinScalarFunction {
"translate" => BuiltinScalarFunction::Translate,
"trim" => BuiltinScalarFunction::Trim,
"upper" => BuiltinScalarFunction::Upper,

"regexp_match" => BuiltinScalarFunction::RegexpMatch,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -607,6 +609,20 @@ pub fn return_type(
));
}
}),
BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] {
DataType::LargeUtf8 => {
DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true)))
}
DataType::Utf8 => {
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
}
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The regexp_extract function can only accept strings.".to_string(),
));
}
}),

BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
Expand Down Expand Up @@ -853,6 +869,28 @@ pub fn create_physical_expr(
_ => unreachable!(),
},
},
BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_match,
i32,
"regexp_match"
);
make_scalar_function(func)(args)
}
DataType::LargeUtf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
regexp_match,
i64,
"regexp_match"
);
make_scalar_function(func)(args)
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function regexp_match",
other
))),
},
BuiltinScalarFunction::RegexpReplace => |args| match args[0].data_type() {
DataType::Utf8 => {
let func = invoke_if_regex_expressions_feature_flag!(
Expand Down Expand Up @@ -1229,6 +1267,12 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::NullIf => {
Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec())
}
BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![
Signature::Exact(vec![DataType::Utf8, DataType::Utf8]),
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
]),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
Expand Down Expand Up @@ -1386,7 +1430,7 @@ mod tests {
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array,
Int32Array, StringArray, UInt32Array, UInt64Array,
Int32Array, ListArray, StringArray, UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
Expand Down Expand Up @@ -3646,4 +3690,78 @@ mod tests {
"PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
)
}

#[test]
#[cfg(feature = "regex_expressions")]
fn test_regexp_match() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

// concat(value, value)
let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"]));
let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
let columns: Vec<ArrayRef> = vec![col_value];
let expr = create_physical_expr(
&BuiltinScalarFunction::RegexpMatch,
&[col("a"), pattern],
&schema,
)?;

// type is correct
assert_eq!(
expr.data_type(&schema)?,
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
);

// evaluate works
let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());

// downcast works
let result = result.as_any().downcast_ref::<ListArray>().unwrap();
let first_row = result.value(0);
let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap();

// value is correct
let expected = "555".to_string();
assert_eq!(first_row.value(0), expected);

Ok(())
}

#[test]
#[cfg(feature = "regex_expressions")]
fn test_regexp_match_all_literals() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

// concat(value, value)
let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string())));
let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string())));
let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
let expr = create_physical_expr(
&BuiltinScalarFunction::RegexpMatch,
&[col_value, pattern],
&schema,
)?;

// type is correct
assert_eq!(
expr.data_type(&schema)?,
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
);

// evaluate works
let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());

// downcast works
let result = result.as_any().downcast_ref::<ListArray>().unwrap();
let first_row = result.value(0);
let first_row = first_row.as_any().downcast_ref::<StringArray>().unwrap();

// value is correct
let expected = "555".to_string();
assert_eq!(first_row.value(0), expected);

Ok(())
}
}
Loading