From 09959a7de163a96bacdf91ff55368829d4df57af Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 29 Apr 2026 15:59:11 -0600 Subject: [PATCH 1/3] feat: support Spark expression `regexp_extract` Implement regexp_extract using the Rust regex crate. The expression is marked Incompatible because the Rust regex engine differs from the Java engine that Spark uses; users must opt in via spark.comet.expression.RegExpExtract.allowIncompatible=true. --- native/spark-expr/src/comet_scalar_funcs.rs | 4 + native/spark-expr/src/string_funcs/mod.rs | 2 + .../src/string_funcs/regexp_extract.rs | 301 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../org/apache/comet/serde/strings.scala | 36 ++- .../expressions/string/regexp_extract.sql | 35 ++ .../string/regexp_extract_enabled.sql | 73 +++++ 7 files changed, 451 insertions(+), 1 deletion(-) create mode 100644 native/spark-expr/src/string_funcs/regexp_extract.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 0957868a60..9a2dd33f97 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -188,6 +188,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) } + "regexp_extract" => { + let func = Arc::new(crate::string_funcs::spark_regexp_extract); + make_comet_scalar_udf!("regexp_extract", func, without data_type) + } "get_json_object" => { let func = Arc::new(crate::string_funcs::spark_get_json_object); make_comet_scalar_udf!("get_json_object", func, without data_type) diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index bb785bdb44..6655866bd3 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -17,10 +17,12 @@ mod contains; mod get_json_object; +mod regexp_extract; mod split; mod substring; pub use contains::SparkContains; pub use get_json_object::spark_get_json_object; +pub use regexp_extract::spark_regexp_extract; pub use split::spark_split; pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs new file mode 100644 index 0000000000..7364ef72a9 --- /dev/null +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, GenericStringArray, GenericStringBuilder}; +use arrow::datatypes::DataType; +use datafusion::common::{ + cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, + ScalarValue, +}; +use datafusion::logical_expr::ColumnarValue; +use regex::Regex; +use std::sync::Arc; + +/// Spark-compatible `regexp_extract(subject, pattern, idx)`. +/// +/// Returns the substring of `subject` matched by group `idx` of the first match of `pattern`. +/// `idx = 0` returns the entire match. Returns an empty string when there is no match or the +/// matched group is unset (optional group). Returns null when any input is null. Errors when +/// `idx` is out of range for the pattern's group count. +/// +/// Note: this uses the Rust `regex` crate, whose syntax differs from Java's regex engine in +/// some ways. The expression is therefore reported as Incompatible. +pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "regexp_extract expects 2 or 3 arguments (subject, pattern, [idx]), got {}", + args.len() + ); + } + + let idx: i32 = if args.len() == 3 { + match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) => *i, + ColumnarValue::Scalar(ScalarValue::Int32(None)) => { + return Ok(null_result(subject_len(&args[0]))); + } + _ => { + return exec_err!("regexp_extract idx must be an Int32 scalar"); + } + } + } else { + 1 + }; + + if idx < 0 { + return exec_err!("regexp_extract idx must be non-negative, got {}", idx); + } + + let pattern = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p.clone(), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + return Ok(null_result(subject_len(&args[0]))); + } + _ => { + return exec_err!("regexp_extract pattern must be a scalar string"); + } + }; + + let regex = Regex::new(&pattern).map_err(|e| { + DataFusionError::Execution(format!("Invalid regex pattern '{pattern}': {e}")) + })?; + + let group_count = regex.captures_len() as i32 - 1; + if idx > group_count { + return Err(DataFusionError::Execution(format!( + "Regex group count is {group_count}, but the specified group index is {idx}" + ))); + } + let group_idx = idx as usize; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let strings = as_generic_string_array::(array.as_ref())?; + Ok(ColumnarValue::Array(extract_array::( + strings, ®ex, group_idx, + ))) + } + DataType::LargeUtf8 => { + let strings = as_generic_string_array::(array.as_ref())?; + Ok(ColumnarValue::Array(extract_array::( + strings, ®ex, group_idx, + ))) + } + other => exec_err!( + "regexp_extract expects Utf8 or LargeUtf8 subject, got {:?}", + other + ), + }, + ColumnarValue::Scalar(ScalarValue::Utf8(s)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(s)) => match s { + None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + extract_one(s, ®ex, group_idx), + )))), + }, + _ => exec_err!("regexp_extract subject must be a string"), + } +} + +fn extract_array( + array: &GenericStringArray, + regex: &Regex, + group_idx: usize, +) -> ArrayRef { + let mut builder = GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(extract_one(array.value(i), regex, group_idx)); + } + } + Arc::new(builder.finish()) +} + +fn extract_one(input: &str, regex: &Regex, group_idx: usize) -> String { + match regex.captures(input) { + Some(caps) => caps + .get(group_idx) + .map(|m| m.as_str().to_string()) + .unwrap_or_default(), + None => String::new(), + } +} + +fn subject_len(value: &ColumnarValue) -> Option { + match value { + ColumnarValue::Array(a) => Some(a.len()), + ColumnarValue::Scalar(_) => None, + } +} + +fn null_result(len: Option) -> ColumnarValue { + match len { + Some(n) => { + let mut builder = GenericStringBuilder::::with_capacity(n, 0); + for _ in 0..n { + builder.append_null(); + } + ColumnarValue::Array(Arc::new(builder.finish())) + } + None => ColumnarValue::Scalar(ScalarValue::Utf8(None)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::StringArray; + + fn run(args: Vec) -> DataFusionResult>> { + let result = spark_regexp_extract(&args)?; + match result { + ColumnarValue::Array(arr) => { + let s = arr + .as_any() + .downcast_ref::>() + .expect("expected Utf8 array"); + Ok((0..s.len()) + .map(|i| { + if s.is_null(i) { + None + } else { + Some(s.value(i).to_string()) + } + }) + .collect()) + } + ColumnarValue::Scalar(ScalarValue::Utf8(v)) => Ok(vec![v]), + other => panic!("unexpected result: {other:?}"), + } + } + + fn array(values: Vec>) -> ColumnarValue { + ColumnarValue::Array(Arc::new(StringArray::from(values))) + } + + fn pattern(p: &str) -> ColumnarValue { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(p.to_string()))) + } + + fn idx(i: i32) -> ColumnarValue { + ColumnarValue::Scalar(ScalarValue::Int32(Some(i))) + } + + #[test] + fn basic_group_extraction() { + let result = run(vec![ + array(vec![Some("100-200"), Some("foo-bar"), Some("nodelim")]), + pattern(r"(\d+)-(\d+)"), + idx(1), + ]) + .unwrap(); + assert_eq!( + result, + vec![ + Some("100".to_string()), + Some(String::new()), + Some(String::new()), + ] + ); + } + + #[test] + fn idx_zero_returns_whole_match() { + let result = run(vec![ + array(vec![Some("abc123def456")]), + pattern(r"\d+"), + idx(0), + ]) + .unwrap(); + assert_eq!(result, vec![Some("123".to_string())]); + } + + #[test] + fn default_idx_is_one() { + let result = run(vec![array(vec![Some("100-200")]), pattern(r"(\d+)-(\d+)")]).unwrap(); + assert_eq!(result, vec![Some("100".to_string())]); + } + + #[test] + fn null_subject_returns_null() { + let result = run(vec![ + array(vec![Some("a1b"), None, Some("c2d")]), + pattern(r"(\d)"), + idx(1), + ]) + .unwrap(); + assert_eq!( + result, + vec![Some("1".to_string()), None, Some("2".to_string())] + ); + } + + #[test] + fn null_pattern_returns_null() { + let result = run(vec![ + array(vec![Some("abc")]), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + idx(1), + ]) + .unwrap(); + assert_eq!(result, vec![None]); + } + + #[test] + fn unmatched_optional_group_returns_empty_string() { + let result = run(vec![ + array(vec![Some("foo")]), + pattern(r"(foo)(bar)?"), + idx(2), + ]) + .unwrap(); + assert_eq!(result, vec![Some(String::new())]); + } + + #[test] + fn group_index_out_of_range_errors() { + let err = spark_regexp_extract(&[ + array(vec![Some("abc")]), + pattern(r"(a)(b)"), + idx(3), + ]) + .err() + .unwrap(); + assert!(err.to_string().contains("group count")); + } + + #[test] + fn negative_index_errors() { + let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)"), idx(-1)]) + .err() + .unwrap(); + assert!(err.to_string().contains("non-negative")); + } + + #[test] + fn invalid_regex_errors() { + let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(unclosed"), idx(0)]) + .err() + .unwrap(); + assert!(err.to_string().contains("Invalid regex")); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 448c2c2cb3..515ca35525 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -170,6 +170,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 968fe8cd69..cde6fedef3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -350,6 +350,40 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] { } } +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + + override def getIncompatibleReasons(): Seq[String] = Seq( + "Uses Rust regexp engine, which has different behavior to Java regexp engine") + + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + if (!expr.regexp.isInstanceOf[Literal]) { + return Unsupported(Some("Only scalar regexp patterns are supported")) + } + if (!expr.idx.isInstanceOf[Literal]) { + return Unsupported(Some("idx must be an integer literal")) + } + Incompatible( + Some("Uses Rust regexp engine, which has different behavior to Java regexp engine")) + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) + val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) + val idxExpr = exprToProtoInternal(expr.idx, inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "regexp_extract", + expr.dataType, + failOnError = true, + subjectExpr, + patternExpr, + idxExpr) + optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.idx) + } +} + object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { override def getIncompatibleReasons(): Seq[String] = Seq( "Regexp pattern may not be compatible with Spark") diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql new file mode 100644 index 0000000000..6c125b27d0 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql @@ -0,0 +1,35 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract default behaviour: Comet marks the expression Incompatible +-- (Rust regex engine differs from Java) and should fall back to Spark unless the user +-- opts in via spark.comet.expression.RegExpExtract.allowIncompatible=true. + +statement +CREATE TABLE test_regexp_extract(s string) USING parquet + +statement +INSERT INTO test_regexp_extract VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890') + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 1) FROM test_regexp_extract + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract + +query expect_fallback(Rust regexp engine) +SELECT regexp_extract(s, '(\\d+)-(\\d+)') FROM test_regexp_extract diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql new file mode 100644 index 0000000000..70a371e132 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql @@ -0,0 +1,73 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract() with the per-expression allowIncompatible flag enabled (happy path). +-- Config: spark.comet.expression.RegExpExtract.allowIncompatible=true + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_regexp_extract_enabled(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_enabled VALUES + ('100-200'), + ('foo-bar'), + ('nodelim'), + ('12-34-56'), + (''), + (NULL), + ('phone 123-456-7890') + +-- group 1 of the first match +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 1) FROM test_regexp_extract_enabled + +-- group 2 of the first match +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract_enabled + +-- idx = 0 returns the entire match +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 0) FROM test_regexp_extract_enabled + +-- default idx (no third arg) is 1 +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)') FROM test_regexp_extract_enabled + +-- single-group match; no match should produce empty string, NULL input -> NULL +query +SELECT regexp_extract(s, '(\\d+)', 1) FROM test_regexp_extract_enabled + +-- optional unmatched group should return empty string +query +SELECT regexp_extract(s, '(\\w+)( \\d+)?', 2) FROM test_regexp_extract_enabled + +-- anchors and character classes +query +SELECT regexp_extract(s, '^(\\w+)', 1) FROM test_regexp_extract_enabled + +query +SELECT regexp_extract(s, '(\\d+)$', 1) FROM test_regexp_extract_enabled + +-- literal arguments +query +SELECT + regexp_extract('alice@example.com', '^([\\w.+-]+)@([\\w.-]+)$', 1), + regexp_extract('alice@example.com', '^([\\w.+-]+)@([\\w.-]+)$', 2), + regexp_extract('not-an-email', '^([\\w.+-]+)@([\\w.-]+)$', 1), + regexp_extract(NULL, '(\\d+)', 1) From ba762330e62d86025d485350c97067ab449744b1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 29 Apr 2026 16:13:52 -0600 Subject: [PATCH 2/3] test: cover error and unicode cases for `regexp_extract` Audit follow-ups: - Align Rust error messages with Spark's `INVALID_PARAMETER_VALUE` templates so `expect_error` substrings can match both engines. - Override `getUnsupportedReasons` in `CometRegExpExtract` so the non-literal pattern and non-literal idx reasons are picked up by the Compatibility Guide generator. - Add Comet SQL test cases for: NULL pattern and NULL idx, idx=0 with no capture groups, multibyte / Unicode subjects, idx out of range, pattern with no groups + idx>=1, negative idx, invalid regex syntax, and a Java-only lookahead that Rust regex rejects (marked `ignore`). - Add fallback test cases for non-literal pattern and non-literal idx. - Mark the expression supported in `spark_expressions_support.md` with per-version audit notes. --- .../spark_expressions_support.md | 5 +- .../src/string_funcs/regexp_extract.rs | 34 +++++----- .../org/apache/comet/serde/strings.scala | 20 ++++-- .../expressions/string/regexp_extract.sql | 13 ++++ .../string/regexp_extract_enabled.sql | 62 +++++++++++++++++++ 5 files changed, 110 insertions(+), 24 deletions(-) diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index 1e4b4e34bc..36b0680582 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -439,7 +439,10 @@ - [ ] position - [ ] printf - [ ] regexp_count -- [ ] regexp_extract +- [x] regexp_extract + - Spark 3.4.3 audited 2026-04-29 (Incompatible: Rust regex engine differs from Java; `idx` out-of-range check happens at compile time in Comet vs per-row in Spark) + - Spark 3.5.8 audited 2026-04-29 (same as 3.4.3) + - Spark 4.0.1 audited 2026-04-29 (collation support added in Spark; Comet does not honour `UTF8_LCASE` and runs case-sensitively) - [ ] regexp_extract_all - [ ] regexp_instr - [ ] regexp_replace diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 7364ef72a9..7c5546bd78 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -56,10 +56,6 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult p.clone(), @@ -73,13 +69,16 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult group_count { + if idx < 0 || idx > group_count { return Err(DataFusionError::Execution(format!( - "Regex group count is {group_count}, but the specified group index is {idx}" + "The value of parameter `idx` in `regexp_extract` is invalid: \ + Expects group index between 0 and {group_count}, but got {idx}." ))); } let group_idx = idx as usize; @@ -273,14 +272,13 @@ mod tests { #[test] fn group_index_out_of_range_errors() { - let err = spark_regexp_extract(&[ - array(vec![Some("abc")]), - pattern(r"(a)(b)"), - idx(3), - ]) - .err() - .unwrap(); - assert!(err.to_string().contains("group count")); + let err = + spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) + .err() + .unwrap(); + let msg = err.to_string(); + assert!(msg.contains("group index"), "{msg}"); + assert!(msg.contains("but got 3"), "{msg}"); } #[test] @@ -288,7 +286,9 @@ mod tests { let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)"), idx(-1)]) .err() .unwrap(); - assert!(err.to_string().contains("non-negative")); + let msg = err.to_string(); + assert!(msg.contains("group index"), "{msg}"); + assert!(msg.contains("but got -1"), "{msg}"); } #[test] @@ -296,6 +296,6 @@ mod tests { let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(unclosed"), idx(0)]) .err() .unwrap(); - assert!(err.to_string().contains("Invalid regex")); + assert!(err.to_string().contains("`regexp`")); } } diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index cde6fedef3..0407e59840 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -352,18 +352,26 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] { object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { - override def getIncompatibleReasons(): Seq[String] = Seq( - "Uses Rust regexp engine, which has different behavior to Java regexp engine") + private val incompatReason: String = + "Uses Rust regexp engine, which has different behavior to Java regexp engine" + private val nonLiteralPatternReason: String = + "Only scalar regexp patterns are supported" + private val nonLiteralIdxReason: String = + "idx must be an integer literal" + + override def getIncompatibleReasons(): Seq[String] = Seq(incompatReason) + + override def getUnsupportedReasons(): Seq[String] = + Seq(nonLiteralPatternReason, nonLiteralIdxReason) override def getSupportLevel(expr: RegExpExtract): SupportLevel = { if (!expr.regexp.isInstanceOf[Literal]) { - return Unsupported(Some("Only scalar regexp patterns are supported")) + return Unsupported(Some(nonLiteralPatternReason)) } if (!expr.idx.isInstanceOf[Literal]) { - return Unsupported(Some("idx must be an integer literal")) + return Unsupported(Some(nonLiteralIdxReason)) } - Incompatible( - Some("Uses Rust regexp engine, which has different behavior to Java regexp engine")) + Incompatible(Some(incompatReason)) } override def convert( diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql index 6c125b27d0..ef4ac8aa78 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql @@ -33,3 +33,16 @@ SELECT regexp_extract(s, '(\\d+)-(\\d+)', 2) FROM test_regexp_extract query expect_fallback(Rust regexp engine) SELECT regexp_extract(s, '(\\d+)-(\\d+)') FROM test_regexp_extract + +-- Non-literal pattern: Comet falls back regardless of the allowIncompatible flag. +statement +CREATE TABLE test_regexp_extract_nonliteral(s string, p string, i int) USING parquet + +statement +INSERT INTO test_regexp_extract_nonliteral VALUES ('abc', '(a)(b)', 1), ('xyz', '(x)', 1) + +query expect_fallback(Only scalar regexp patterns) +SELECT regexp_extract(s, p, 1) FROM test_regexp_extract_nonliteral + +query expect_fallback(idx must be an integer literal) +SELECT regexp_extract(s, '(\\w+)', i) FROM test_regexp_extract_nonliteral diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql index 70a371e132..6e30d312c0 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_enabled.sql @@ -71,3 +71,65 @@ SELECT regexp_extract('alice@example.com', '^([\\w.+-]+)@([\\w.-]+)$', 2), regexp_extract('not-an-email', '^([\\w.+-]+)@([\\w.-]+)$', 1), regexp_extract(NULL, '(\\d+)', 1) + +-- NULL pattern propagates as NULL (Spark and Comet both return NULL) +query +SELECT regexp_extract(s, CAST(NULL AS STRING), 1) FROM test_regexp_extract_enabled + +-- NULL idx propagates as NULL +query +SELECT regexp_extract(s, '(\\d+)-(\\d+)', CAST(NULL AS INT)) FROM test_regexp_extract_enabled + +-- idx = 0 with no capture groups returns the whole match +query +SELECT regexp_extract(s, '\\d+', 0) FROM test_regexp_extract_enabled + +-- multibyte / Unicode subject +statement +CREATE TABLE test_regexp_extract_unicode(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_unicode VALUES + ('café=42'), + ('café=99'), + ('世界=1'), + ('日本=東京'), + ('🔥=hot'), + ('मानक=हिन्दी') + +-- ASCII anchors and capture groups against multibyte data +query +SELECT regexp_extract(s, '^(.+)=(.+)$', 1) FROM test_regexp_extract_unicode + +query +SELECT regexp_extract(s, '^(.+)=(.+)$', 2) FROM test_regexp_extract_unicode + +-- digit class against multibyte data +query +SELECT regexp_extract(s, '=(\\d+)$', 1) FROM test_regexp_extract_unicode + +-- ERROR CASES +-- idx > groupCount (pattern has 2 groups, ask for 3) +query expect_error(group index) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', 3) FROM test_regexp_extract_enabled + +-- pattern with no capture groups but idx >= 1 +query expect_error(group index) +SELECT regexp_extract(s, '\\d+', 1) FROM test_regexp_extract_enabled + +-- negative idx +query expect_error(group index) +SELECT regexp_extract(s, '(\\d+)-(\\d+)', -1) FROM test_regexp_extract_enabled + +-- invalid regex syntax (unclosed group): both engines fail at pattern compile time. +-- Spark surfaces INVALID_PARAMETER_VALUE.PATTERN, Comet surfaces a regex parse error. +-- Both messages mention `regexp_extract`. +query expect_error(regexp_extract) +SELECT regexp_extract(s, '(unclosed', 1) FROM test_regexp_extract_enabled + +-- Java-only regex feature: lookahead. Rust regex rejects this at compile time; +-- Spark accepts it and returns "" for every row. This is one of the documented +-- incompatibilities behind the Incompatible support level, not an invariant we +-- test for cross-engine equivalence. +query ignore(Rust regex does not support lookahead, unlike Java regex) +SELECT regexp_extract(s, '(?=\\d)\\w+', 0) FROM test_regexp_extract_enabled From a143cf1e6ab65bfc2dd620ca47929535da1c3a5b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 29 Apr 2026 16:21:28 -0600 Subject: [PATCH 3/3] refactor: simplify `regexp_extract` Rust UDF Address review feedback: - Make `extract_array` build a `GenericStringBuilder` matching the input offset size so a `LargeUtf8` subject no longer silently outputs `Utf8` (avoids potential i32-offset overflow on >2GB inputs). - Inline group extraction so the per-row `String` allocation is gone; the only remaining `to_string` is on the rare scalar code path. - Replace the manual append-null loop in `null_result` with `StringArray::new_null(n)`. - Borrow the pattern as `&str` instead of cloning it before calling `Regex::new`. - Pass `failOnError = false` to the proto, matching `CometStringSplit`. The Rust UDF does not branch on this flag, so `true` was misleading. --- .../src/string_funcs/regexp_extract.rs | 42 +++++++++---------- .../org/apache/comet/serde/strings.scala | 2 +- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/native/spark-expr/src/string_funcs/regexp_extract.rs b/native/spark-expr/src/string_funcs/regexp_extract.rs index 7c5546bd78..7ac3473adb 100644 --- a/native/spark-expr/src/string_funcs/regexp_extract.rs +++ b/native/spark-expr/src/string_funcs/regexp_extract.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, GenericStringArray, GenericStringBuilder}; +use arrow::array::{ + Array, ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray, +}; use arrow::datatypes::DataType; use datafusion::common::{ cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult, @@ -56,9 +58,9 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult p.clone(), + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(p))) => p, ColumnarValue::Scalar(ScalarValue::Utf8(None)) | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { return Ok(null_result(subject_len(&args[0]))); @@ -68,7 +70,7 @@ pub fn spark_regexp_extract(args: &[ColumnarValue]) -> DataFusionResult DataFusionResult match s { None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - extract_one(s, ®ex, group_idx), - )))), + Some(s) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(extract_one( + s, ®ex, group_idx, + ))))), }, _ => exec_err!("regexp_extract subject must be a string"), } } -fn extract_array( +fn extract_array( array: &GenericStringArray, regex: &Regex, group_idx: usize, ) -> ArrayRef { - let mut builder = GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); + let mut builder = + GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); for i in 0..array.len() { if array.is_null(i) { builder.append_null(); } else { - builder.append_value(extract_one(array.value(i), regex, group_idx)); + let extracted = match regex.captures(array.value(i)) { + Some(caps) => caps.get(group_idx).map(|m| m.as_str()).unwrap_or(""), + None => "", + }; + builder.append_value(extracted); } } Arc::new(builder.finish()) @@ -148,13 +155,7 @@ fn subject_len(value: &ColumnarValue) -> Option { fn null_result(len: Option) -> ColumnarValue { match len { - Some(n) => { - let mut builder = GenericStringBuilder::::with_capacity(n, 0); - for _ in 0..n { - builder.append_null(); - } - ColumnarValue::Array(Arc::new(builder.finish())) - } + Some(n) => ColumnarValue::Array(Arc::new(StringArray::new_null(n))), None => ColumnarValue::Scalar(ScalarValue::Utf8(None)), } } @@ -272,10 +273,9 @@ mod tests { #[test] fn group_index_out_of_range_errors() { - let err = - spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) - .err() - .unwrap(); + let err = spark_regexp_extract(&[array(vec![Some("abc")]), pattern(r"(a)(b)"), idx(3)]) + .err() + .unwrap(); let msg = err.to_string(); assert!(msg.contains("group index"), "{msg}"); assert!(msg.contains("but got 3"), "{msg}"); diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 0407e59840..322977378c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -384,7 +384,7 @@ object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { val optExpr = scalarFunctionExprToProtoWithReturnType( "regexp_extract", expr.dataType, - failOnError = true, + failOnError = false, subjectExpr, patternExpr, idxExpr)