From c38a5f84d4ce14d67b494c0d0543a916d949116a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Fri, 5 Mar 2021 17:29:22 +0100 Subject: [PATCH 01/17] feat: regexp_extract and regexp_match --- rust/arrow/src/array/mod.rs | 1 + rust/arrow/src/compute/kernels/mod.rs | 1 + rust/arrow/src/compute/kernels/regexp.rs | 195 ++++++++++++++++++ rust/arrow/src/compute/mod.rs | 1 + rust/datafusion/src/logical_plan/expr.rs | 2 + rust/datafusion/src/logical_plan/mod.rs | 2 +- .../datafusion/src/physical_plan/functions.rs | 85 +++++++- .../src/physical_plan/string_expressions.rs | 19 ++ rust/datafusion/tests/sql.rs | 40 ++++ 9 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 rust/arrow/src/compute/kernels/regexp.rs diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index c0073c03b81..65cf30832e2 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -216,6 +216,7 @@ pub use self::builder::BooleanBuilder; pub use self::builder::DecimalBuilder; pub use self::builder::FixedSizeBinaryBuilder; pub use self::builder::FixedSizeListBuilder; +pub use self::builder::GenericStringBuilder; pub use self::builder::LargeBinaryBuilder; pub use self::builder::LargeListBuilder; pub use self::builder::LargeStringBuilder; diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index a8d24979e04..862f55fe2f2 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -28,6 +28,7 @@ pub mod concat; pub mod filter; pub mod length; pub mod limit; +pub mod regexp; pub mod sort; pub mod substring; pub mod take; diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs new file mode 100644 index 00000000000..93a0ec2d77c --- /dev/null +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines kernel to extract substrings based on a regular +//! expression of a \[Large\]StringArray + +use crate::array::{ + Array, ArrayRef, GenericStringArray, GenericStringBuilder, + LargeStringArray, ListBuilder, StringArray, StringOffsetSizeTrait, +}; +use crate::datatypes::DataType; +use crate::error::{ArrowError, Result}; + +use std::sync::Arc; + +use regex::Regex; + +fn generic_regexp_extract( + array: &GenericStringArray, + re: &Regex, + idx: usize, +) -> Result { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(0); + + for maybe_value in array.iter() { + match maybe_value { + Some(value) => match re.captures(value) { + Some(caps) => { + let m = caps.get(idx).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Regexp has no group with index {}", + idx + )) + })?; + builder.append_value(m.as_str())? + } + None => builder.append_null()?, + }, + None => builder.append_null()?, + } + } + Ok(Arc::new(builder.finish())) +} + +fn generic_regexp_match( + array: &GenericStringArray, + re: &Regex, +) -> Result { + let builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut list_builder = ListBuilder::new(builder); + + for maybe_value in array.iter() { + match maybe_value { + Some(value) => match re.captures(value) { + Some(caps) => { + for m in caps.iter().skip(1) { + if let Some(v) = m { + list_builder.values().append_value(v.as_str())?; + } + } + list_builder.append(true)? + } + None => { + list_builder.values().append_value("")?; + list_builder.append(true)? + } + }, + None => list_builder.append(false)?, + } + } + Ok(Arc::new(list_builder.finish())) +} + +/// Extracts a specific group matched by a regular expression for a given String array. +/// Group index 0 returns the whole match, index 1 returns the first group and so on. Please +/// refer to regex crate for details on pattern specifics. +pub fn regexp_extract(array: &Array, pattern: &str, idx: usize) -> Result { + let re = Regex::new(pattern).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {:?}", e)) + })?; + match array.data_type() { + DataType::LargeUtf8 => generic_regexp_extract( + array + .as_any() + .downcast_ref::() + .expect("A large string is expected"), + &re, + idx, + ), + DataType::Utf8 => generic_regexp_extract( + array + .as_any() + .downcast_ref::() + .expect("A string is expected"), + &re, + idx, + ), + _ => Err(ArrowError::ComputeError(format!( + "regexp_extract does not support type {:?}", + array.data_type() + ))), + } +} + +/// Extract all groups matched by a regular expression for a given String array. +pub fn regexp_match(array: &Array, pattern: &str) -> Result { + let re = Regex::new(pattern).map_err(|e| { + ArrowError::ComputeError(format!("Regular expression did not compile: {:?}", e)) + })?; + match array.data_type() { + DataType::LargeUtf8 => generic_regexp_match( + array + .as_any() + .downcast_ref::() + .expect("A large string is expected"), + &re, + ), + DataType::Utf8 => generic_regexp_match( + array + .as_any() + .downcast_ref::() + .expect("A string is expected"), + &re, + ), + _ => Err(ArrowError::ComputeError(format!( + "regexp_match does not support type {:?}", + array.data_type() + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::ListArray; + + #[test] + fn extract_single_group() -> Result<()> { + let values = vec!["abc-005-def", "X-7-5", "X545"]; + let array = StringArray::from(values); + let pattern = r".*-(\d*)-.*"; + let actual = regexp_extract(&array, pattern, 1)?; + let expected = StringArray::from(vec![Some("005"), Some("7"), None]); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn no_matches() -> Result<()> { + let values = vec!["abc", "X::50::00", "X545"]; + let array = StringArray::from(values); + let pattern = r".*-(\d*)-.*"; + let actual = regexp_extract(&array, pattern, 1)?; + let expected = StringArray::from(vec![None, None, None]); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn match_single_group() -> Result<()> { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = r".*-(\d*)-.*"; + let actual = regexp_match(&array, pattern)?; + let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.values().append_value("005")?; + expected_builder.append(true)?; + expected_builder.values().append_value("7")?; + expected_builder.append(true)?; + expected_builder.values().append_value("")?; + expected_builder.append(true)?; + expected_builder.append(false)?; + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + Ok(()) + } +} diff --git a/rust/arrow/src/compute/mod.rs b/rust/arrow/src/compute/mod.rs index 9de07388e9c..be1aa277ca4 100644 --- a/rust/arrow/src/compute/mod.rs +++ b/rust/arrow/src/compute/mod.rs @@ -29,6 +29,7 @@ pub use self::kernels::comparison::*; pub use self::kernels::concat::*; pub use self::kernels::filter::*; pub use self::kernels::limit::*; +pub use self::kernels::regexp::*; pub use self::kernels::sort::*; pub use self::kernels::take::*; pub use self::kernels::temporal::*; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 5b0876a79e0..c84252a8d06 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1090,6 +1090,8 @@ unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); +unary_scalar_expr!(RegexpExtract, regexp_extract); +unary_scalar_expr!(RegexpMatch, regexp_match); unary_scalar_expr!(Repeat, repeat); unary_scalar_expr!(Reverse, reverse); unary_scalar_expr!(Right, right); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index ab787ef82f4..11cc909fbf2 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,7 +37,7 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, repeat, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, + octet_length, or, regexp_extract, regexp_match, repeat, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, to_hex, trim, trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index ae8d128fc30..a2e192ac953 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -186,6 +186,9 @@ pub enum BuiltinScalarFunction { Trim, /// upper Upper, + /// regex_extract, + RegexpExtract, + RegexpMatch, } impl fmt::Display for BuiltinScalarFunction { @@ -253,7 +256,8 @@ impl FromStr for BuiltinScalarFunction { "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, - + "regexp_extract" => BuiltinScalarFunction::RegexpExtract, + "regexp_match" => BuiltinScalarFunction::RegexpMatch, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -538,6 +542,34 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::RegexpExtract => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::LargeUtf8, + DataType::Utf8 => DataType::Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The regexp_extract function can only accept strings.".to_string(), + )); + } + }), + BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] { + DataType::LargeUtf8 => DataType::List(Box::new(Field::new( + "item", + DataType::LargeUtf8, + true, + ))), + DataType::Utf8 => DataType::List(Box::new(Field::new( + "item", + DataType::Utf8, + true, + ))), + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The regexp_extract function can only accept strings.".to_string(), + )); + } + }), BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos @@ -726,6 +758,22 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, + BuiltinScalarFunction::RegexpExtract => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::regexp_extract)(args), + DataType::LargeUtf8 => make_scalar_function(string_expressions::regexp_extract)(args), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_extract", + other, + ))), + }, + BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { + DataType::Utf8 => make_scalar_function(string_expressions::regexp_match)(args), + DataType::LargeUtf8 => make_scalar_function(string_expressions::regexp_match)(args), + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + other, + ))), + }, BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::repeat::)(args) @@ -950,6 +998,8 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } + BuiltinScalarFunction::RegexpExtract => Signature::Any(3), + BuiltinScalarFunction::RegexpMatch => Signature::Any(2), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -2782,4 +2832,37 @@ mod tests { "PrimitiveArray\n[\n 1,\n 1,\n]", ) } + + #[test] + fn test_regexp_extract() -> Result<()> { + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + // concat(value, value) + let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let idx = lit(ScalarValue::Int64(Some(1))); + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpExtract, + &[col_value, pattern, idx], + &schema, + )?; + + // type is correct + assert_eq!(expr.data_type(&schema)?, DataType::Utf8); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(result.value(0).to_string(), expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index bc0e7633379..04f30f909ee 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -35,6 +35,7 @@ use arrow::{ Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, PrimitiveArray, StringArray, StringOffsetSizeTrait, }, + compute, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; use unicode_segmentation::UnicodeSegmentation; @@ -602,6 +603,24 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { } } +/// extract a specific group from a string column, using a regular expression +pub fn regexp_extract(args: &[ArrayRef]) -> Result { + let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); + let pattern = pattern_expr.value(0); + let idx_expr = args[2].as_any().downcast_ref::().unwrap(); + let idx = idx_expr.value(0) as usize; + compute::regexp_extract(args[0].as_ref(), pattern, idx) + .map_err(DataFusionError::ArrowError) +} + +/// extract a specific group from a string column, using a regular expression +pub fn regexp_match(args: &[ArrayRef]) -> Result { + let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); + let pattern = pattern_expr.value(0); + compute::regexp_match(args[0].as_ref(), pattern) + .map_err(DataFusionError::ArrowError) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index c8e198cb13c..a153a70fe61 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2421,3 +2421,43 @@ async fn inner_join_qualified_names() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn query_regexp_extract() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec!["aaa-0", "bb-1", "aa"]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + let sql = r"SELECT regexp_extract(c1, '.*-(\d)', 1) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["1"], vec!["NULL"]]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn query_regexp_match() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(StringArray::from(vec!["aaa-0", "bb-1", "aa"]))], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + let sql = r"SELECT regexp_match(c1, '.*-(\d)') FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["[0]"], vec!["[1]"], vec!["[]"]]; + assert_eq!(expected, actual); + Ok(()) +} From 0065a10ab73dc660f98d5d5773ee2c902ee413b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Fri, 5 Mar 2021 17:33:11 +0100 Subject: [PATCH 02/17] cleanups --- rust/arrow/src/compute/kernels/regexp.rs | 16 ++-------------- rust/datafusion/src/physical_plan/functions.rs | 1 + 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index 93a0ec2d77c..9fbde38e41b 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -19,8 +19,8 @@ //! expression of a \[Large\]StringArray use crate::array::{ - Array, ArrayRef, GenericStringArray, GenericStringBuilder, - LargeStringArray, ListBuilder, StringArray, StringOffsetSizeTrait, + Array, ArrayRef, GenericStringArray, GenericStringBuilder, LargeStringArray, + ListBuilder, StringArray, StringOffsetSizeTrait, }; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; @@ -160,18 +160,6 @@ mod tests { Ok(()) } - #[test] - fn no_matches() -> Result<()> { - let values = vec!["abc", "X::50::00", "X545"]; - let array = StringArray::from(values); - let pattern = r".*-(\d*)-.*"; - let actual = regexp_extract(&array, pattern, 1)?; - let expected = StringArray::from(vec![None, None, None]); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); - Ok(()) - } - #[test] fn match_single_group() -> Result<()> { let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index a2e192ac953..9e2f2843c71 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -188,6 +188,7 @@ pub enum BuiltinScalarFunction { Upper, /// regex_extract, RegexpExtract, + /// regexp_match RegexpMatch, } From e9c4c5df4995c6d391e73582730cc162b80b8bb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Mon, 8 Mar 2021 18:31:33 +0100 Subject: [PATCH 03/17] fix: clean up after rebase --- rust/datafusion/tests/sql.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index a153a70fe61..a083b2862eb 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2434,7 +2434,7 @@ async fn query_regexp_extract() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = r"SELECT regexp_extract(c1, '.*-(\d)', 1) FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0"], vec!["1"], vec!["NULL"]]; @@ -2454,7 +2454,7 @@ async fn query_regexp_match() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Box::new(table)); + ctx.register_table("test", Arc::new(table)); let sql = r"SELECT regexp_match(c1, '.*-(\d)') FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["[0]"], vec!["[1]"], vec!["[]"]]; From 2213696abf95901e90ad8dde5d27c23ba6aff3b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Mon, 8 Mar 2021 18:55:01 +0100 Subject: [PATCH 04/17] fix: formatting --- rust/datafusion/src/logical_plan/mod.rs | 6 +-- .../datafusion/src/physical_plan/functions.rs | 40 ++++++++++--------- .../src/physical_plan/string_expressions.rs | 3 +- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 11cc909fbf2..bebc8db76ff 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,9 +37,9 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, regexp_extract, regexp_match, repeat, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, - sha512, signum, sin, sqrt, substr, sum, tan, to_hex, trim, trunc, upper, when, Expr, - ExprRewriter, ExpressionVisitor, Literal, Recursion, + octet_length, or, regexp_extract, regexp_match, repeat, reverse, right, round, rpad, + rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, to_hex, + trim, trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 9e2f2843c71..2224148971a 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -549,28 +549,24 @@ pub fn return_type( _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( - "The regexp_extract function can only accept strings.".to_string(), + "The regexp_extract function can only accept strings.".to_string(), )); } - }), - BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::List(Box::new(Field::new( - "item", - DataType::LargeUtf8, - true, - ))), - DataType::Utf8 => DataType::List(Box::new(Field::new( - "item", - DataType::Utf8, - true, - ))), + }), + BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] { + DataType::LargeUtf8 => { + DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true))) + } + DataType::Utf8 => { + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + } _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( "The regexp_extract function can only accept strings.".to_string(), )); } - }), + }), BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos @@ -760,16 +756,24 @@ pub fn create_physical_expr( }, }, BuiltinScalarFunction::RegexpExtract => |args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function(string_expressions::regexp_extract)(args), - DataType::LargeUtf8 => make_scalar_function(string_expressions::regexp_extract)(args), + DataType::Utf8 => { + make_scalar_function(string_expressions::regexp_extract)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::regexp_extract)(args) + } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function regexp_extract", other, ))), }, BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function(string_expressions::regexp_match)(args), - DataType::LargeUtf8 => make_scalar_function(string_expressions::regexp_match)(args), + DataType::Utf8 => { + make_scalar_function(string_expressions::regexp_match)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::regexp_match)(args) + } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function regexp_match", other, diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 04f30f909ee..64c630b9d56 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -617,8 +617,7 @@ pub fn regexp_extract(args: &[ArrayRef]) -> Result { pub fn regexp_match(args: &[ArrayRef]) -> Result { let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); let pattern = pattern_expr.value(0); - compute::regexp_match(args[0].as_ref(), pattern) - .map_err(DataFusionError::ArrowError) + compute::regexp_match(args[0].as_ref(), pattern).map_err(DataFusionError::ArrowError) } /// Repeats string the specified number of times. From a66b414618f915b541f70d249b0d354e74a7b8c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Mon, 8 Mar 2021 19:07:41 +0100 Subject: [PATCH 05/17] fix: correct signature --- rust/datafusion/src/physical_plan/functions.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 2224148971a..3f0c9f4a5a8 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -1003,8 +1003,14 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } - BuiltinScalarFunction::RegexpExtract => Signature::Any(3), - BuiltinScalarFunction::RegexpMatch => Signature::Any(2), + BuiltinScalarFunction::RegexpExtract => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int64]), + ]), + BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![ + Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + ]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). From 5f98c0e7f28a84f09e138f79ae4e9d9fccf14095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Tue, 9 Mar 2021 22:52:57 +0100 Subject: [PATCH 06/17] refactor: make usage of literal explicit --- .../datafusion/src/physical_plan/functions.rs | 70 ++++++++++++++++--- .../src/physical_plan/string_expressions.rs | 8 +-- 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 3f0c9f4a5a8..6f333c7185d 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -767,17 +767,30 @@ pub fn create_physical_expr( other, ))), }, - BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(string_expressions::regexp_match)(args) + BuiltinScalarFunction::RegexpMatch => |args| match (&args[0], &args[1]) { + (c, ColumnarValue::Scalar(ScalarValue::Utf8(Some(pattern)))) + if c.data_type() == DataType::Utf8 + || c.data_type() == DataType::LargeUtf8 => + { + let arr = match c { + ColumnarValue::Array(col) => col.clone(), + ColumnarValue::Scalar(_) => c.clone().into_array(1), + }; + + Ok(ColumnarValue::Array(string_expressions::regexp_match( + &arr, + pattern.as_str(), + )?)) } - DataType::LargeUtf8 => { - make_scalar_function(string_expressions::regexp_match)(args) + (c, ColumnarValue::Scalar(ScalarValue::Utf8(Some(_)))) => { + Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function regexp_match", + c.data_type(), + ))) } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_match", - other, - ))), + (_, _) => Err(DataFusionError::Internal( + "regexp_match expects a literal string as second argument".to_string(), + )), }, BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { DataType::Utf8 => { @@ -1168,7 +1181,7 @@ mod tests { use arrow::{ array::{ Array, ArrayRef, BinaryArray, FixedSizeListArray, Float64Array, Int32Array, - StringArray, UInt32Array, UInt64Array, + ListArray, StringArray, UInt32Array, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -2876,4 +2889,41 @@ mod tests { Ok(()) } + + #[test] + fn test_regexp_match() -> Result<()> { + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + // concat(value, value) + let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpMatch, + &[col_value, pattern], + &schema, + )?; + + // type is correct + assert_eq!( + expr.data_type(&schema)?, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + ); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::().unwrap(); + let first_row = result.value(0); + let first_row = first_row.as_any().downcast_ref::().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(first_row.value(0), expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index 64c630b9d56..bcde59c66fa 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -614,10 +614,10 @@ pub fn regexp_extract(args: &[ArrayRef]) -> Result { } /// extract a specific group from a string column, using a regular expression -pub fn regexp_match(args: &[ArrayRef]) -> Result { - let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); - let pattern = pattern_expr.value(0); - compute::regexp_match(args[0].as_ref(), pattern).map_err(DataFusionError::ArrowError) +pub fn regexp_match(col: &ArrayRef, pattern: &str) -> Result { + // let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); + // let pattern = pattern_expr.value(0); + compute::regexp_match(col.as_ref(), pattern).map_err(DataFusionError::ArrowError) } /// Repeats string the specified number of times. From dfe4f9f8931edc03cd8f8fd28d1121faaaf8b9a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Wed, 10 Mar 2021 17:23:45 +0100 Subject: [PATCH 07/17] refactor: support regex pattern as own column --- rust/arrow/src/compute/kernels/regexp.rs | 135 ++++++------------ rust/datafusion/src/logical_plan/expr.rs | 1 - rust/datafusion/src/logical_plan/mod.rs | 6 +- .../datafusion/src/physical_plan/functions.rs | 92 ++---------- .../src/physical_plan/string_expressions.rs | 15 +- rust/datafusion/tests/sql.rs | 20 --- 6 files changed, 57 insertions(+), 212 deletions(-) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index 9fbde38e41b..2be01980a77 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -24,117 +24,80 @@ use crate::array::{ }; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; +use std::collections::HashMap; use std::sync::Arc; use regex::Regex; -fn generic_regexp_extract( - array: &GenericStringArray, - re: &Regex, - idx: usize, -) -> Result { - let mut builder: GenericStringBuilder = GenericStringBuilder::new(0); - - for maybe_value in array.iter() { - match maybe_value { - Some(value) => match re.captures(value) { - Some(caps) => { - let m = caps.get(idx).ok_or_else(|| { - ArrowError::ComputeError(format!( - "Regexp has no group with index {}", - idx - )) - })?; - builder.append_value(m.as_str())? - } - None => builder.append_null()?, - }, - None => builder.append_null()?, - } - } - Ok(Arc::new(builder.finish())) -} - fn generic_regexp_match( array: &GenericStringArray, - re: &Regex, + regex_array: &StringArray, ) -> Result { + let mut patterns: HashMap<&str, Regex> = HashMap::new(); let builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut list_builder = ListBuilder::new(builder); - for maybe_value in array.iter() { - match maybe_value { - Some(value) => match re.captures(value) { - Some(caps) => { - for m in caps.iter().skip(1) { - if let Some(v) = m { - list_builder.values().append_value(v.as_str())?; + for (maybe_value, maybe_pattern) in array.iter().zip(regex_array) { + match (maybe_value, maybe_pattern) { + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(pattern); + let re = match existing_pattern { + Some(re) => re.clone(), + None => { + let re = Regex::new(pattern).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {:?}", + e + )) + })?; + patterns.insert(pattern, re.clone()); + re + } + }; + match re.captures(value) { + Some(caps) => { + for m in caps.iter().skip(1) { + if let Some(v) = m { + list_builder.values().append_value(v.as_str())?; + } } + list_builder.append(true)? + } + None => { + list_builder.values().append_value("")?; + list_builder.append(true)? } - list_builder.append(true)? - } - None => { - list_builder.values().append_value("")?; - list_builder.append(true)? } - }, - None => list_builder.append(false)?, + } + _ => list_builder.append(false)?, } } Ok(Arc::new(list_builder.finish())) } -/// Extracts a specific group matched by a regular expression for a given String array. -/// Group index 0 returns the whole match, index 1 returns the first group and so on. Please -/// refer to regex crate for details on pattern specifics. -pub fn regexp_extract(array: &Array, pattern: &str, idx: usize) -> Result { - let re = Regex::new(pattern).map_err(|e| { - ArrowError::ComputeError(format!("Regular expression did not compile: {:?}", e)) - })?; +/// Extract all groups matched by a regular expression for a given String array. +pub fn regexp_match(array: &Array, pattern: &Array) -> Result { match array.data_type() { - DataType::LargeUtf8 => generic_regexp_extract( + DataType::LargeUtf8 => generic_regexp_match( array .as_any() .downcast_ref::() .expect("A large string is expected"), - &re, - idx, - ), - DataType::Utf8 => generic_regexp_extract( - array + pattern .as_any() .downcast_ref::() .expect("A string is expected"), - &re, - idx, - ), - _ => Err(ArrowError::ComputeError(format!( - "regexp_extract does not support type {:?}", - array.data_type() - ))), - } -} - -/// Extract all groups matched by a regular expression for a given String array. -pub fn regexp_match(array: &Array, pattern: &str) -> Result { - let re = Regex::new(pattern).map_err(|e| { - ArrowError::ComputeError(format!("Regular expression did not compile: {:?}", e)) - })?; - match array.data_type() { - DataType::LargeUtf8 => generic_regexp_match( - array - .as_any() - .downcast_ref::() - .expect("A large string is expected"), - &re, ), DataType::Utf8 => generic_regexp_match( array .as_any() .downcast_ref::() .expect("A string is expected"), - &re, + pattern + .as_any() + .downcast_ref::() + .expect("A string is expected"), ), _ => Err(ArrowError::ComputeError(format!( "regexp_match does not support type {:?}", @@ -148,24 +111,12 @@ mod tests { use super::*; use crate::array::ListArray; - #[test] - fn extract_single_group() -> Result<()> { - let values = vec!["abc-005-def", "X-7-5", "X545"]; - let array = StringArray::from(values); - let pattern = r".*-(\d*)-.*"; - let actual = regexp_extract(&array, pattern, 1)?; - let expected = StringArray::from(vec![Some("005"), Some("7"), None]); - let result = actual.as_any().downcast_ref::().unwrap(); - assert_eq!(&expected, result); - Ok(()) - } - #[test] fn match_single_group() -> Result<()> { let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; let array = StringArray::from(values); - let pattern = r".*-(\d*)-.*"; - let actual = regexp_match(&array, pattern)?; + let pattern = StringArray::from(vec![r".*-(\d*)-.*"; 4]); + let actual = regexp_match(&array, &pattern)?; let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut expected_builder = ListBuilder::new(elem_builder); expected_builder.values().append_value("005")?; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index c84252a8d06..6a0876a123c 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -1090,7 +1090,6 @@ unary_scalar_expr!(Lpad, lpad); unary_scalar_expr!(Ltrim, ltrim); unary_scalar_expr!(MD5, md5); unary_scalar_expr!(OctetLength, octet_length); -unary_scalar_expr!(RegexpExtract, regexp_extract); unary_scalar_expr!(RegexpMatch, regexp_match); unary_scalar_expr!(Repeat, repeat); unary_scalar_expr!(Reverse, reverse); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index bebc8db76ff..18bac4b3dcd 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,9 +37,9 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, regexp_extract, regexp_match, repeat, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, to_hex, - trim, trunc, upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, + octet_length, or, regexp_match, repeat, reverse, right, round, rpad, rtrim, sha224, + sha256, sha384, sha512, signum, sin, sqrt, substr, sum, tan, to_hex, trim, trunc, + upper, when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 6f333c7185d..3a0f3a7131f 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -186,8 +186,6 @@ pub enum BuiltinScalarFunction { Trim, /// upper Upper, - /// regex_extract, - RegexpExtract, /// regexp_match RegexpMatch, } @@ -257,7 +255,6 @@ impl FromStr for BuiltinScalarFunction { "to_timestamp" => BuiltinScalarFunction::ToTimestamp, "trim" => BuiltinScalarFunction::Trim, "upper" => BuiltinScalarFunction::Upper, - "regexp_extract" => BuiltinScalarFunction::RegexpExtract, "regexp_match" => BuiltinScalarFunction::RegexpMatch, _ => { return Err(DataFusionError::Plan(format!( @@ -543,16 +540,6 @@ pub fn return_type( )); } }), - BuiltinScalarFunction::RegexpExtract => Ok(match arg_types[0] { - DataType::LargeUtf8 => DataType::LargeUtf8, - DataType::Utf8 => DataType::Utf8, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The regexp_extract function can only accept strings.".to_string(), - )); - } - }), BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] { DataType::LargeUtf8 => { DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true))) @@ -755,43 +742,18 @@ pub fn create_physical_expr( _ => unreachable!(), }, }, - BuiltinScalarFunction::RegexpExtract => |args| match args[0].data_type() { + BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::regexp_extract)(args) + make_scalar_function(string_expressions::regexp_match)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::regexp_extract)(args) + make_scalar_function(string_expressions::regexp_match)(args) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_extract", + "Unsupported data type {:?} for function repeat", other, ))), }, - BuiltinScalarFunction::RegexpMatch => |args| match (&args[0], &args[1]) { - (c, ColumnarValue::Scalar(ScalarValue::Utf8(Some(pattern)))) - if c.data_type() == DataType::Utf8 - || c.data_type() == DataType::LargeUtf8 => - { - let arr = match c { - ColumnarValue::Array(col) => col.clone(), - ColumnarValue::Scalar(_) => c.clone().into_array(1), - }; - - Ok(ColumnarValue::Array(string_expressions::regexp_match( - &arr, - pattern.as_str(), - )?)) - } - (c, ColumnarValue::Scalar(ScalarValue::Utf8(Some(_)))) => { - Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function regexp_match", - c.data_type(), - ))) - } - (_, _) => Err(DataFusionError::Internal( - "regexp_match expects a literal string as second argument".to_string(), - )), - }, BuiltinScalarFunction::Repeat => |args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::repeat::)(args) @@ -1016,10 +978,6 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::NullIf => { Signature::Uniform(2, SUPPORTED_NULLIF_TYPES.to_vec()) } - BuiltinScalarFunction::RegexpExtract => Signature::OneOf(vec![ - Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Int64]), - Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Int64]), - ]), BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![ Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), @@ -2857,51 +2815,17 @@ mod tests { ) } - #[test] - fn test_regexp_extract() -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; - - // concat(value, value) - let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); - let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); - let idx = lit(ScalarValue::Int64(Some(1))); - let expr = create_physical_expr( - &BuiltinScalarFunction::RegexpExtract, - &[col_value, pattern, idx], - &schema, - )?; - - // type is correct - assert_eq!(expr.data_type(&schema)?, DataType::Utf8); - - // evaluate works - let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - - // downcast works - let result = result.as_any().downcast_ref::().unwrap(); - - // value is correct - let expected = "555".to_string(); - assert_eq!(result.value(0).to_string(), expected); - - Ok(()) - } - #[test] fn test_regexp_match() -> Result<()> { - // any type works here: we evaluate against a literal of `value` - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); // concat(value, value) - let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); + let col_value: ArrayRef = Arc::new(StringArray::from(vec!["aaa-555"])); let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let columns: Vec = vec![col_value]; let expr = create_physical_expr( &BuiltinScalarFunction::RegexpMatch, - &[col_value, pattern], + &[col("a"), pattern], &schema, )?; diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index bcde59c66fa..fd5f4316ebb 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -604,20 +604,11 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { } /// extract a specific group from a string column, using a regular expression -pub fn regexp_extract(args: &[ArrayRef]) -> Result { - let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); - let pattern = pattern_expr.value(0); - let idx_expr = args[2].as_any().downcast_ref::().unwrap(); - let idx = idx_expr.value(0) as usize; - compute::regexp_extract(args[0].as_ref(), pattern, idx) - .map_err(DataFusionError::ArrowError) -} - -/// extract a specific group from a string column, using a regular expression -pub fn regexp_match(col: &ArrayRef, pattern: &str) -> Result { +pub fn regexp_match(args: &[ArrayRef]) -> Result { // let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); // let pattern = pattern_expr.value(0); - compute::regexp_match(col.as_ref(), pattern).map_err(DataFusionError::ArrowError) + compute::regexp_match(args[0].as_ref(), args[1].as_ref()) + .map_err(DataFusionError::ArrowError) } /// Repeats string the specified number of times. diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index a083b2862eb..0aa0576fb16 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2422,26 +2422,6 @@ async fn inner_join_qualified_names() -> Result<()> { Ok(()) } -#[tokio::test] -async fn query_regexp_extract() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(StringArray::from(vec!["aaa-0", "bb-1", "aa"]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); - let sql = r"SELECT regexp_extract(c1, '.*-(\d)', 1) FROM test"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0"], vec!["1"], vec!["NULL"]]; - assert_eq!(expected, actual); - Ok(()) -} - #[tokio::test] async fn query_regexp_match() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); From 476f16778e5168d9a760139ae3768da8412776b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Fri, 12 Mar 2021 16:52:13 +0100 Subject: [PATCH 08/17] feat: add flags --- rust/arrow/src/compute/kernels/regexp.rs | 123 +++++++++++++----- .../datafusion/src/physical_plan/functions.rs | 2 + .../src/physical_plan/string_expressions.rs | 14 +- 3 files changed, 102 insertions(+), 37 deletions(-) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index 2be01980a77..c5c17e5aeaa 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -33,51 +33,76 @@ use regex::Regex; fn generic_regexp_match( array: &GenericStringArray, regex_array: &StringArray, + flags_array: Option<&StringArray>, ) -> Result { - let mut patterns: HashMap<&str, Regex> = HashMap::new(); + let mut patterns: HashMap = HashMap::new(); let builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut list_builder = ListBuilder::new(builder); - for (maybe_value, maybe_pattern) in array.iter().zip(regex_array) { - match (maybe_value, maybe_pattern) { - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(pattern); - let re = match existing_pattern { - Some(re) => re.clone(), - None => { - let re = Regex::new(pattern).map_err(|e| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {:?}", - e - )) - })?; - patterns.insert(pattern, re.clone()); - re - } - }; - match re.captures(value) { - Some(caps) => { - for m in caps.iter().skip(1) { - if let Some(v) = m { - list_builder.values().append_value(v.as_str())?; + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(value) => format!("(?{}){}", value, pattern), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re.clone(), + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {:?}", + e + )) + })?; + patterns.insert(pattern, re.clone()); + re + } + }; + match re.captures(value) { + Some(caps) => { + for m in caps.iter().skip(1) { + if let Some(v) = m { + list_builder.values().append_value(v.as_str())?; + } } + list_builder.append(true)? + } + None => { + list_builder.values().append_value("")?; + list_builder.append(true)? } - list_builder.append(true)? - } - None => { - list_builder.values().append_value("")?; - list_builder.append(true)? } } + _ => list_builder.append(false)?, } - _ => list_builder.append(false)?, - } - } + Ok(()) + }) + .collect::>>()?; Ok(Arc::new(list_builder.finish())) } /// Extract all groups matched by a regular expression for a given String array. -pub fn regexp_match(array: &Array, pattern: &Array) -> Result { +pub fn regexp_match( + array: &Array, + pattern: &Array, + flags: Option<&Array>, +) -> Result { match array.data_type() { DataType::LargeUtf8 => generic_regexp_match( array @@ -88,6 +113,11 @@ pub fn regexp_match(array: &Array, pattern: &Array) -> Result { .as_any() .downcast_ref::() .expect("A string is expected"), + flags.map(|x| { + x.as_any() + .downcast_ref::() + .expect("A string is expected") + }), ), DataType::Utf8 => generic_regexp_match( array @@ -98,6 +128,11 @@ pub fn regexp_match(array: &Array, pattern: &Array) -> Result { .as_any() .downcast_ref::() .expect("A string is expected"), + flags.map(|x| { + x.as_any() + .downcast_ref::() + .expect("A string is expected") + }), ), _ => Err(ArrowError::ComputeError(format!( "regexp_match does not support type {:?}", @@ -116,7 +151,7 @@ mod tests { let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; let array = StringArray::from(values); let pattern = StringArray::from(vec![r".*-(\d*)-.*"; 4]); - let actual = regexp_match(&array, &pattern)?; + let actual = regexp_match(&array, &pattern, None)?; let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut expected_builder = ListBuilder::new(elem_builder); expected_builder.values().append_value("005")?; @@ -131,4 +166,26 @@ mod tests { assert_eq!(&expected, result); Ok(()) } + + #[test] + fn match_single_group_with_flags() -> Result<()> { + let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let array = StringArray::from(values); + let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]); + let flags = StringArray::from(vec!["i"; 4]); + let actual = regexp_match(&array, &pattern, Some(&flags))?; + let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let mut expected_builder = ListBuilder::new(elem_builder); + expected_builder.values().append_value("")?; + expected_builder.append(true)?; + expected_builder.values().append_value("7")?; + expected_builder.append(true)?; + expected_builder.values().append_value("")?; + expected_builder.append(true)?; + expected_builder.append(false)?; + let expected = expected_builder.finish(); + let result = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(&expected, result); + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 3a0f3a7131f..76140fcfcd7 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -981,6 +981,8 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::RegexpMatch => Signature::OneOf(vec![ Signature::Exact(vec![DataType::Utf8, DataType::Utf8]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), + Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]), ]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index fd5f4316ebb..89f0e793a54 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -605,10 +605,16 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { /// extract a specific group from a string column, using a regular expression pub fn regexp_match(args: &[ArrayRef]) -> Result { - // let pattern_expr = args[1].as_any().downcast_ref::().unwrap(); - // let pattern = pattern_expr.value(0); - compute::regexp_match(args[0].as_ref(), args[1].as_ref()) - .map_err(DataFusionError::ArrowError) + match args.len() { + 2 => compute::regexp_match(args[0].as_ref(), args[1].as_ref(), None) + .map_err(DataFusionError::ArrowError), + 3 => compute::regexp_match(args[0].as_ref(), args[1].as_ref(), Some(args[2].as_ref())) + .map_err(DataFusionError::ArrowError), + other => Err(DataFusionError::Internal(format!( + "regexp_match was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } } /// Repeats string the specified number of times. From 1199c17979dd107738e8759b857f5283ae886546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Mon, 15 Mar 2021 16:55:15 +0100 Subject: [PATCH 09/17] refactor: move regexp_match to regex expressions --- rust/arrow/src/compute/kernels/regexp.rs | 58 +++---------------- rust/datafusion/src/logical_plan/mod.rs | 8 +-- .../datafusion/src/physical_plan/functions.rs | 50 +++++++++++++++- .../src/physical_plan/regex_expressions.rs | 15 +++++ .../src/physical_plan/string_expressions.rs | 15 ----- rust/datafusion/src/scalar.rs | 6 +- 6 files changed, 79 insertions(+), 73 deletions(-) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index c5c17e5aeaa..b3b04c9ae28 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -19,10 +19,9 @@ //! expression of a \[Large\]StringArray use crate::array::{ - Array, ArrayRef, GenericStringArray, GenericStringBuilder, LargeStringArray, - ListBuilder, StringArray, StringOffsetSizeTrait, + ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder, + StringOffsetSizeTrait, }; -use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use std::collections::HashMap; @@ -30,10 +29,11 @@ use std::sync::Arc; use regex::Regex; -fn generic_regexp_match( +/// Extract all groups matched by a regular expression for a given String array. +pub fn regexp_match( array: &GenericStringArray, - regex_array: &StringArray, - flags_array: Option<&StringArray>, + regex_array: &GenericStringArray, + flags_array: Option<&GenericStringArray>, ) -> Result { let mut patterns: HashMap = HashMap::new(); let builder: GenericStringBuilder = GenericStringBuilder::new(0); @@ -97,54 +97,10 @@ fn generic_regexp_match( Ok(Arc::new(list_builder.finish())) } -/// Extract all groups matched by a regular expression for a given String array. -pub fn regexp_match( - array: &Array, - pattern: &Array, - flags: Option<&Array>, -) -> Result { - match array.data_type() { - DataType::LargeUtf8 => generic_regexp_match( - array - .as_any() - .downcast_ref::() - .expect("A large string is expected"), - pattern - .as_any() - .downcast_ref::() - .expect("A string is expected"), - flags.map(|x| { - x.as_any() - .downcast_ref::() - .expect("A string is expected") - }), - ), - DataType::Utf8 => generic_regexp_match( - array - .as_any() - .downcast_ref::() - .expect("A string is expected"), - pattern - .as_any() - .downcast_ref::() - .expect("A string is expected"), - flags.map(|x| { - x.as_any() - .downcast_ref::() - .expect("A string is expected") - }), - ), - _ => Err(ArrowError::ComputeError(format!( - "regexp_match does not support type {:?}", - array.data_type() - ))), - } -} - #[cfg(test)] mod tests { use super::*; - use crate::array::ListArray; + use crate::array::{ListArray, StringArray}; #[test] fn match_single_group() -> Result<()> { diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 7ea0ff61f20..f9be1ff9830 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -37,10 +37,10 @@ pub use expr::{ ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, - rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, - strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Expr, - ExprRewriter, ExpressionVisitor, Literal, Recursion, + octet_length, or, regexp_match, regexp_replace, repeat, replace, reverse, right, + round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, + starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper, when, + Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index b676d7d33ee..28d86bc0cad 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -871,10 +871,20 @@ pub fn create_physical_expr( }, BuiltinScalarFunction::RegexpMatch => |args| match args[0].data_type() { DataType::Utf8 => { - make_scalar_function(string_expressions::regexp_match)(args) + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i32, + "regexp_match" + ); + make_scalar_function(func)(args) } DataType::LargeUtf8 => { - make_scalar_function(string_expressions::regexp_match)(args) + let func = invoke_if_regex_expressions_feature_flag!( + regexp_match, + i64, + "regexp_match" + ); + make_scalar_function(func)(args) } other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function regexp_match", @@ -3716,4 +3726,40 @@ mod tests { Ok(()) } + + #[test] + fn test_regexp_match_all_literals() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + // concat(value, value) + let col_value = lit(ScalarValue::Utf8(Some("aaa-555".to_string()))); + let pattern = lit(ScalarValue::Utf8(Some(r".*-(\d*)".to_string()))); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + let expr = create_physical_expr( + &BuiltinScalarFunction::RegexpMatch, + &[col_value, pattern], + &schema, + )?; + + // type is correct + assert_eq!( + expr.data_type(&schema)?, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))) + ); + + // evaluate works + let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + + // downcast works + let result = result.as_any().downcast_ref::().unwrap(); + let first_row = result.value(0); + let first_row = first_row.as_any().downcast_ref::().unwrap(); + + // value is correct + let expected = "555".to_string(); + assert_eq!(first_row.value(0), expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/regex_expressions.rs b/rust/datafusion/src/physical_plan/regex_expressions.rs index 8df9a822f31..9faa358b41a 100644 --- a/rust/datafusion/src/physical_plan/regex_expressions.rs +++ b/rust/datafusion/src/physical_plan/regex_expressions.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use arrow::array::{ArrayRef, GenericStringArray, StringOffsetSizeTrait}; +use arrow::compute; use hashbrown::HashMap; use regex::Regex; @@ -54,6 +55,20 @@ fn regex_replace_posix_groups(replacement: &str) -> String { .into_owned() } +/// extract a specific group from a string column, using a regular expression +pub fn regexp_match(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), None) + .map_err(DataFusionError::ArrowError), + 3 => compute::regexp_match(downcast_string_arg!(args[0], "string", T), downcast_string_arg!(args[1], "pattern", T), Some(downcast_string_arg!(args[1], "flags", T))) + .map_err(DataFusionError::ArrowError), + other => Err(DataFusionError::Internal(format!( + "regexp_match was called with {} arguments. It requires at least 2 and at most 3.", + other + ))), + } +} + /// Replaces substring(s) matching a POSIX regular expression /// regexp_replace('Thomas', '.[mN]a.', 'M') = 'ThM' pub fn regexp_replace(args: &[ArrayRef]) -> Result { diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs index b904705b9e5..882fe30502f 100644 --- a/rust/datafusion/src/physical_plan/string_expressions.rs +++ b/rust/datafusion/src/physical_plan/string_expressions.rs @@ -33,7 +33,6 @@ use arrow::{ Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, PrimitiveArray, StringArray, StringOffsetSizeTrait, }, - compute, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; @@ -444,20 +443,6 @@ pub fn ltrim(args: &[ArrayRef]) -> Result { } } -/// extract a specific group from a string column, using a regular expression -pub fn regexp_match(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => compute::regexp_match(args[0].as_ref(), args[1].as_ref(), None) - .map_err(DataFusionError::ArrowError), - 3 => compute::regexp_match(args[0].as_ref(), args[1].as_ref(), Some(args[2].as_ref())) - .map_err(DataFusionError::ArrowError), - other => Err(DataFusionError::Internal(format!( - "regexp_match was called with {} arguments. It requires at least 2 and at most 3.", - other - ))), - } -} - /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs index ca0e27dd687..9ea127d34ff 100644 --- a/rust/datafusion/src/scalar.rs +++ b/rust/datafusion/src/scalar.rs @@ -108,7 +108,7 @@ macro_rules! build_list { for scalar_value in values { match scalar_value { ScalarValue::$SCALAR_TY(Some(v)) => { - builder.values().append_value(*v).unwrap() + builder.values().append_value(v.clone()).unwrap() } ScalarValue::$SCALAR_TY(None) => { builder.values().append_null().unwrap(); @@ -333,6 +333,10 @@ impl ScalarValue { DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), + DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), + DataType::LargeUtf8 => { + build_list!(LargeStringBuilder, LargeUtf8, values, size) + } _ => panic!("Unexpected DataType for list"), }), ScalarValue::Date32(e) => match e { From cac07f09a6a02267874288066c71b47597c1af93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Wed, 17 Mar 2021 07:22:15 +0100 Subject: [PATCH 10/17] fix: add regex feature flag to regexp_match tests --- rust/datafusion/src/physical_plan/functions.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index 28d86bc0cad..56365fec1dc 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -3692,6 +3692,7 @@ mod tests { } #[test] + #[cfg(feature = "regex_expressions")] fn test_regexp_match() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); @@ -3728,6 +3729,7 @@ mod tests { } #[test] + #[cfg(feature = "regex_expressions")] fn test_regexp_match_all_literals() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); From ec32f9a744bb00085e8636601d65a2af11a7470c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Wed, 17 Mar 2021 08:20:40 +0100 Subject: [PATCH 11/17] fix: add regex feature flag to regexp_match tests --- rust/datafusion/tests/sql.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 6fb818e7ee8..31fdb911c8b 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2485,6 +2485,7 @@ async fn inner_join_qualified_names() -> Result<()> { } #[tokio::test] +#[cfg(feature = "regex_expressions")] async fn query_regexp_match() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); From 9b5c80d4b8ff993fb52ce1fa88d9219a529cad8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Fri, 26 Mar 2021 19:05:51 +0100 Subject: [PATCH 12/17] fix: sql test after merge --- rust/datafusion/tests/sql.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 4c19134430d..de344dbe119 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2608,6 +2608,7 @@ async fn invalid_qualified_table_references() -> Result<()> { Ok(()) } +#[tokio::test] #[cfg(feature = "regex_expressions")] async fn query_regexp_match() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); @@ -2620,7 +2621,7 @@ async fn query_regexp_match() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table)); + ctx.register_table("test", Arc::new(table))?; let sql = r"SELECT regexp_match(c1, '.*-(\d)') FROM test"; let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["[0]"], vec!["[1]"], vec!["[]"]]; From 077a7dc0f6719f84f1074b37f523377c395d3e37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Sun, 28 Mar 2021 14:51:19 +0200 Subject: [PATCH 13/17] chore: let unmatching pattern return null --- rust/arrow/src/compute/kernels/regexp.rs | 30 ++++++++++++++---------- rust/datafusion/tests/sql.rs | 2 +- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index b3b04c9ae28..d8b5d3d682d 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -83,10 +83,7 @@ pub fn regexp_match( } list_builder.append(true)? } - None => { - list_builder.values().append_value("")?; - list_builder.append(true)? - } + None => list_builder.append(false)?, } } _ => list_builder.append(false)?, @@ -104,9 +101,19 @@ mod tests { #[test] fn match_single_group() -> Result<()> { - let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; + let values = vec![ + Some("abc-005-def"), + Some("X-7-5"), + Some("X545"), + None, + Some("foobarbequebaz"), + Some("foobarbequebaz"), + ]; let array = StringArray::from(values); - let pattern = StringArray::from(vec![r".*-(\d*)-.*"; 4]); + let mut pattern_values = vec![r".*-(\d*)-.*"; 4]; + pattern_values.push(r"(bar)(bequ1e)"); + pattern_values.push(""); + let pattern = StringArray::from(pattern_values); let actual = regexp_match(&array, &pattern, None)?; let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut expected_builder = ListBuilder::new(elem_builder); @@ -114,9 +121,10 @@ mod tests { expected_builder.append(true)?; expected_builder.values().append_value("7")?; expected_builder.append(true)?; - expected_builder.values().append_value("")?; - expected_builder.append(true)?; expected_builder.append(false)?; + expected_builder.append(false)?; + expected_builder.append(false)?; + expected_builder.append(true)?; let expected = expected_builder.finish(); let result = actual.as_any().downcast_ref::().unwrap(); assert_eq!(&expected, result); @@ -132,12 +140,10 @@ mod tests { let actual = regexp_match(&array, &pattern, Some(&flags))?; let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.values().append_value("")?; - expected_builder.append(true)?; + expected_builder.append(false)?; expected_builder.values().append_value("7")?; expected_builder.append(true)?; - expected_builder.values().append_value("")?; - expected_builder.append(true)?; + expected_builder.append(false)?; expected_builder.append(false)?; let expected = expected_builder.finish(); let result = actual.as_any().downcast_ref::().unwrap(); diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index de344dbe119..959e7041ed8 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2624,7 +2624,7 @@ async fn query_regexp_match() -> Result<()> { ctx.register_table("test", Arc::new(table))?; let sql = r"SELECT regexp_match(c1, '.*-(\d)') FROM test"; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["[0]"], vec!["[1]"], vec!["[]"]]; + let expected = vec![vec!["[0]"], vec!["[1]"], vec!["NULL"]]; assert_eq!(expected, actual); Ok(()) } From f7e17452c60c7b07a02fcb2557f2a772cb11deaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Sun, 28 Mar 2021 17:18:04 +0200 Subject: [PATCH 14/17] feat: add special case for empty string pattern --- rust/arrow/src/compute/kernels/regexp.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index d8b5d3d682d..6006c31f934 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -59,6 +59,12 @@ pub fn regexp_match( .zip(complete_pattern) .map(|(value, pattern)| { match (value, pattern) { + // Required for Postgres compatibility: + // SELECT regexp_match('foobarbequebaz', ''); = {""} + (Some(_), Some(pattern)) if pattern == "".to_string() => { + list_builder.values().append_value("")?; + list_builder.append(true)?; + } (Some(value), Some(pattern)) => { let existing_pattern = patterns.get(&pattern); let re = match existing_pattern { @@ -124,6 +130,7 @@ mod tests { expected_builder.append(false)?; expected_builder.append(false)?; expected_builder.append(false)?; + expected_builder.values().append_value("")?; expected_builder.append(true)?; let expected = expected_builder.finish(); let result = actual.as_any().downcast_ref::().unwrap(); From da3301719d0bac04c9d249d9d6a7d3c56659f512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Sun, 28 Mar 2021 21:15:36 +0200 Subject: [PATCH 15/17] fix: clippy --- rust/arrow/src/compute/kernels/regexp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/arrow/src/compute/kernels/regexp.rs b/rust/arrow/src/compute/kernels/regexp.rs index 6006c31f934..446d71d9f4a 100644 --- a/rust/arrow/src/compute/kernels/regexp.rs +++ b/rust/arrow/src/compute/kernels/regexp.rs @@ -61,7 +61,7 @@ pub fn regexp_match( match (value, pattern) { // Required for Postgres compatibility: // SELECT regexp_match('foobarbequebaz', ''); = {""} - (Some(_), Some(pattern)) if pattern == "".to_string() => { + (Some(_), Some(pattern)) if pattern == *"" => { list_builder.values().append_value("")?; list_builder.append(true)?; } From 852ba38fc58f1d924a96299ef60f6b502cb063cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Wed, 31 Mar 2021 21:59:01 +0200 Subject: [PATCH 16/17] refactor: simplify tests for regexp_match --- rust/datafusion/tests/sql.rs | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 959e7041ed8..fb18595373e 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2516,6 +2516,14 @@ async fn test_in_list_scalar() -> Result<()> { test_expression!("'2' IN ('a','b',NULL,1)", "NULL"); test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); + test_expression!("regexp_match('foobarbequebaz', '')", "[]"); + test_expression!("regexp_match('foobarbequebaz', '(bar)(beque)')", "[bar, beque]"); + test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL"); + test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]"); + test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]"); + test_expression!("regexp_match('aa', '.*-(\\d)')", "NULL"); + test_expression!("regexp_match(NULL, '.*-(\\d)')", "NULL"); + test_expression!("regexp_match('aaa-0', NULL)", "NULL"); Ok(()) } @@ -2606,25 +2614,4 @@ async fn invalid_qualified_table_references() -> Result<()> { assert!(matches!(ctx.sql(&sql), Err(DataFusionError::Plan(_)))); } Ok(()) -} - -#[tokio::test] -#[cfg(feature = "regex_expressions")] -async fn query_regexp_match() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(StringArray::from(vec!["aaa-0", "bb-1", "aa"]))], - )?; - - let table = MemTable::try_new(schema, vec![vec![data]])?; - - let mut ctx = ExecutionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = r"SELECT regexp_match(c1, '.*-(\d)') FROM test"; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["[0]"], vec!["[1]"], vec!["NULL"]]; - assert_eq!(expected, actual); - Ok(()) -} +} \ No newline at end of file From 346227c60568221dd6dca10ccbe1efbeb0531218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCller?= Date: Thu, 1 Apr 2021 07:24:08 +0200 Subject: [PATCH 17/17] chore: formatting / linting --- rust/datafusion/tests/sql.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index fb18595373e..dbeeffa7aeb 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -2517,7 +2517,10 @@ async fn test_in_list_scalar() -> Result<()> { test_expression!("'1' NOT IN ('a','b',NULL,1)", "false"); test_expression!("'2' NOT IN ('a','b',NULL,1)", "NULL"); test_expression!("regexp_match('foobarbequebaz', '')", "[]"); - test_expression!("regexp_match('foobarbequebaz', '(bar)(beque)')", "[bar, beque]"); + test_expression!( + "regexp_match('foobarbequebaz', '(bar)(beque)')", + "[bar, beque]" + ); test_expression!("regexp_match('foobarbequebaz', '(ba3r)(bequ34e)')", "NULL"); test_expression!("regexp_match('aaa-0', '.*-(\\d)')", "[0]"); test_expression!("regexp_match('bb-1', '.*-(\\d)')", "[1]"); @@ -2614,4 +2617,4 @@ async fn invalid_qualified_table_references() -> Result<()> { assert!(matches!(ctx.sql(&sql), Err(DataFusionError::Plan(_)))); } Ok(()) -} \ No newline at end of file +}