diff --git a/ci/conda_env_cpp.yml b/ci/conda_env_cpp.yml index 90cef3ea2d1..4388df4237b 100644 --- a/ci/conda_env_cpp.yml +++ b/ci/conda_env_cpp.yml @@ -35,6 +35,7 @@ ninja pkg-config python rapidjson +re2 snappy thrift-cpp>=0.11.0 zlib diff --git a/ci/conda_env_gandiva.yml b/ci/conda_env_gandiva.yml index 5056456fc66..22c70a32e5e 100644 --- a/ci/conda_env_gandiva.yml +++ b/ci/conda_env_gandiva.yml @@ -17,4 +17,3 @@ clangdev=11 llvmdev=11 -re2 diff --git a/ci/conda_env_gandiva_win.yml b/ci/conda_env_gandiva_win.yml index 49b3b8c1de1..9098b53d1f5 100644 --- a/ci/conda_env_gandiva_win.yml +++ b/ci/conda_env_gandiva_win.yml @@ -18,4 +18,3 @@ # llvmdev=9 or later require Visual Studio 2017 clangdev=8 llvmdev=8 -re2 diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index f9ab1548fbd..e12d8b5744d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -330,6 +330,10 @@ if(ARROW_BUILD_BENCHMARKS set(ARROW_TESTING ON) endif() +if(ARROW_GANDIVA) + set(ARROW_WITH_RE2 ON) +endif() + if(ARROW_CUDA OR ARROW_FLIGHT OR ARROW_PARQUET @@ -746,6 +750,14 @@ if(ARROW_WITH_UTF8PROC) endif() endif() +if(ARROW_WITH_RE2) + list(APPEND ARROW_LINK_LIBS RE2::re2) + list(APPEND ARROW_STATIC_LINK_LIBS RE2::re2) + if(RE2_SOURCE STREQUAL "SYSTEM") + list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS RE2::re2) + endif() +endif() + add_custom_target(arrow_dependencies) add_custom_target(arrow_benchmark_dependencies) add_custom_target(arrow_test_dependencies) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index a68c3a92cc7..43fa9e88b2c 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -363,6 +363,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_WITH_UTF8PROC "Build with support for Unicode properties using the utf8proc library" ON) + define_option(ARROW_WITH_RE2 + "Build with support for regular expressions using the re2 library" ON) 
#---------------------------------------------------------------------- if(MSVC_TOOLCHAIN) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 22531fcfc57..a0a27f17998 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -257,6 +257,9 @@ if(NOT ARROW_COMPUTE) # utf8proc is only potentially used in kernels for now set(ARROW_WITH_UTF8PROC OFF) endif() +if((NOT ARROW_COMPUTE) AND (NOT ARROW_GANDIVA)) + set(ARROW_WITH_RE2 OFF) +endif() # ---------------------------------------------------------------------- # Versions and URLs for toolchain builds, which also can be used to configure @@ -2090,8 +2093,9 @@ macro(build_re2) list(APPEND ARROW_BUNDLED_STATIC_LIBS RE2::re2) endmacro() -if(ARROW_GANDIVA) +if(ARROW_WITH_RE2) resolve_dependency(RE2) + add_definitions(-DARROW_WITH_RE2) # TODO: Don't use global includes but rather target_include_directories get_target_property(RE2_INCLUDE_DIR RE2::re2 INTERFACE_INCLUDE_DIRECTORIES) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 789ac909ccf..8189ad8c5b4 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -68,6 +68,13 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { std::string pattern; }; +struct ARROW_EXPORT RE2Options : public FunctionOptions { + explicit RE2Options(std::string regex) : regex(regex) {} + + /// Regular expression + std::string regex; +}; + /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { explicit SetLookupOptions(Datum value_set, bool skip_nulls) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index eff60b80481..ce5dfaab3d6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -23,9 +23,12 @@ #include #endif +#include 
#include "arrow/array/builder_binary.h" #include "arrow/array/builder_nested.h" #include "arrow/buffer_builder.h" + +#include "arrow/builder.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" #include "arrow/util/utf8.h" @@ -1194,6 +1197,121 @@ void AddSplit(FunctionRegistry* registry) { #endif } +// ---------------------------------------------------------------------- +// re2 regex + +template +struct ExtractRE2 { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using BuilderType = typename TypeTraits::BuilderType; + using State = OptionsWrapper; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RE2Options options = State::Get(ctx); + RE2 regex(options.regex); + + if (!regex.ok()) { + ctx->SetStatus(Status::Invalid("Regular expression error")); + return; + } + std::vector> fields; + int group_count = regex.NumberOfCapturingGroups(); + fields.reserve(group_count); + const std::map name_map = regex.CapturingGroupNames(); + + // We need to pass RE2 a Args* array, which all point to a std::string + std::vector found_values(group_count); + std::vector args; + std::vector args_pointers; + args.reserve(group_count); + args_pointers.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + auto item = name_map.find(i + 1); // re2 starts counting from 1 + if (item == name_map.end()) { + ctx->SetStatus(Status::Invalid("Regular expression contains unnamed groups")); + return; + } + fields.emplace_back(new Field(item->second, batch[0].type())); + args.emplace_back(&found_values[i]); + // since we reserved capacity, we're guaranteed std::vector does not reallocate + // (which would cause the pointer to be invalid) + args_pointers.push_back(&args[i]); + } + auto type = struct_(fields); + + if (batch[0].kind() == Datum::ARRAY) { + std::unique_ptr array_builder_tmp; + MakeBuilder(ctx->memory_pool(), type, &array_builder_tmp); + std::shared_ptr 
struct_builder; + struct_builder.reset(checked_cast(array_builder_tmp.release())); + + const ArrayData& input = *batch[0].array(); + KERNEL_RETURN_IF_ERROR( + ctx, + VisitArrayDataInline( + input, + [&](util::string_view s) { + re2::StringPiece piece(s.data(), s.length()); + if (re2::RE2::FullMatchN(piece, regex, &args_pointers[0], group_count)) { + for (int i = 0; i < group_count; i++) { + BuilderType* builder = + static_cast(struct_builder->field_builder(i)); + RETURN_NOT_OK(builder->Append(found_values[i])); + } + RETURN_NOT_OK(struct_builder->Append()); + } else { + RETURN_NOT_OK(struct_builder->AppendNull()); + } + return Status::OK(); + }, + [&]() { + RETURN_NOT_OK(struct_builder->AppendNull()); + return Status::OK(); + })); + std::shared_ptr struct_array = + std::make_shared(out->array()); + KERNEL_RETURN_IF_ERROR(ctx, struct_builder->Finish(&struct_array)); + ArrayData* output = out->mutable_array(); + output->type = type; + output->child_data = struct_array->data()->child_data; + + } else { + const auto& input = checked_cast(*batch[0].scalar()); + auto result = std::make_shared(type); + if (input.is_valid) { + util::string_view s = static_cast(*input.value); + re2::StringPiece piece(s.data(), s.length()); + if (re2::RE2::FullMatchN(piece, regex, &args_pointers[0], group_count)) { + for (int i = 0; i < group_count; i++) { + result->value.push_back(std::make_shared(found_values[i])); + } + result->is_valid = true; + } else { + result->is_valid = false; + } + } else { + result->is_valid = false; + } + out->value = result; + } + } +}; + +const FunctionDoc utf8_extract_re2_doc("Extract", ("Long.."), {"strings"}, "RE2Options"); + +void AddExtractRE2(FunctionRegistry* registry) { + auto func = std::make_shared("utf8_extract_re2", Arity::Unary(), + &utf8_extract_re2_doc); + using t32 = ExtractRE2; + using t64 = ExtractRE2; + DCHECK_OK(func->AddKernel({utf8()}, {struct_({})}, t32::Exec, t32::State::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, {struct_({})}, 
t64::Exec, t64::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} +void AddRE2(FunctionRegistry* registry) { AddExtractRE2(registry); } + // ---------------------------------------------------------------------- // strptime string parsing @@ -1496,6 +1614,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { #endif AddSplit(registry); + AddRE2(registry); AddBinaryLength(registry); AddMatchSubstring(registry); AddStrptime(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index c76b50ee176..014af200c81 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -416,6 +416,20 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) { &options_max); } +TYPED_TEST(TestStringKernels, ExtractRE2) { + RE2Options options{"(?P[ab])(?P\\d)"}; + auto type = struct_({field("letter", this->type()), field("digit", this->type())}); + // TODO: enable test when the following issue is fixed: + // https://issues.apache.org/jira/browse/ARROW-10208 + // this->CheckUnary( + // "utf8_extract_re2", R"(["a1", "b2", "c3", null])", type, + // R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}, null, + // null])", &options); + this->CheckUnary("utf8_extract_re2", R"(["a1", "b2"])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}])", + &options); +} + TYPED_TEST(TestStringKernels, Strptime) { std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index c1d3ac7e61b..e8ab64327fd 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -434,6 +434,20 @@ when a positive ``max_splits`` is given. (``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen as separator. 
+String extraction +~~~~~~~~~~~~~~~~~ + ++--------------------+------------+------------------------------------+---------------+----------------------------------------+ +| Function name | Arity | Input types | Output type | Options class | ++====================+============+====================================+===============+========================================+ +| utf8_extract_re2 | Unary | String-like | Struct (1) | :struct:`RE2Options` | ++--------------------+------------+------------------------------------+---------------+----------------------------------------+ + +* \(1) Extract substrings defined by a regular expression using the Google RE2 +library. Struct field names refer to the named groups, e.g. 'letter' and 'digit' +for the following regular expression: '(?P<letter>[ab])(?P<digit>\\d)'. + + Structural transforms ~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 84340229b16..721c67e8156 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -606,6 +606,18 @@ class MatchSubstringOptions(_MatchSubstringOptions): self._set_options(pattern) +cdef class RE2Options(FunctionOptions): + cdef: + unique_ptr[CRE2Options] re2_options + + def __init__(self, regex): + self.re2_options.reset( + new CRE2Options(tobytes(regex))) + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.re2_options.get() + + cdef class _FilterOptions(FunctionOptions): cdef: CFilterOptions filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index ddfd8057db2..95c27498c8b 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -36,6 +36,7 @@ MinMaxOptions, ModeOptions, PartitionNthOptions, + RE2Options, SetLookupOptions, StrptimeOptions, TakeOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 4d49532715f..06b71581541 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ 
b/python/pyarrow/includes/libarrow.pxd @@ -1716,6 +1716,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool reverse) c_string pattern + cdef cppclass CRE2Options \ + "arrow::compute::RE2Options"(CFunctionOptions): + CRE2Options(c_string regex) + c_string regex + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 981121a3672..edd6438cfa8 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -523,6 +523,12 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +def test_extract_re2(): + ar = pa.array(['a1', 'b2']) + struct = pc.utf8_extract_re2(ar, regex='(?P<letter>[ab])(?P<digit>\\d)') + assert struct.tolist() == [{'letter': 'a', 'digit': '1'}, {'letter': 'b', 'digit': '2'}] + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty) diff --git a/testing b/testing index d6c4deb22c4..860376d4e58 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit d6c4deb22c4b4e9e3247a2f291046e3c671ad235 +Subproject commit 860376d4e586a3ac34ec93089889da624ead6c2a