diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index e4abd13abc2..d513173d76f 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -259,6 +259,15 @@ Result IsValid(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result IsNull(const Datum& values, ExecContext* ctx = NULLPTR); +// ---------------------------------------------------------------------- +// String functions + +struct ARROW_EXPORT BinaryContainsExactOptions : public FunctionOptions { + explicit BinaryContainsExactOptions(std::string pattern) : pattern(pattern) {} + + std::string pattern; +}; + // ---------------------------------------------------------------------- // Temporal functions diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index b261743aab0..b7d2fee1a3a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -297,6 +297,118 @@ void AddAsciiLength(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(func))); } +// ---------------------------------------------------------------------- +// exact pattern detection + +using StrToBoolTransformFunc = + std::function; + +// Apply `transform` to input character data- this function cannot change the +// length +template +void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, + StrToBoolTransformFunc transform, Datum* out) { + using offset_type = typename Type::offset_type; + + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& input = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + if (input.length > 0) { + transform( + reinterpret_cast(input.buffers[1]->data()) + input.offset, + input.buffers[2]->data(), input.length, out_arr->offset, + out_arr->buffers[1]->mutable_data()); + } + } else { + const auto& input = checked_cast(*batch[0].scalar()); + if (input.is_valid) { + auto result = checked_pointer_cast(MakeNullScalar(out->type())); + uint8_t result_value = 0; + result->is_valid = true; + std::array offsets{0, + static_cast(input.value->size())}; + transform(offsets.data(), input.value->data(), 1, /*output_offset=*/0, + &result_value); + out->value = std::make_shared(result_value > 0); + } + } +} + +template +void TransformBinaryContainsExact(const uint8_t* pattern, int64_t pattern_length, + const offset_type* offsets, const uint8_t* data, + int64_t length, int64_t output_offset, + uint8_t* output) { + // This is an implementation of the Knuth-Morris-Pratt algorithm + + // Phase 1: Build the prefix table + std::vector prefix_table(pattern_length + 1); + offset_type prefix_length = -1; + prefix_table[0] = -1; + for (offset_type pos = 0; pos < pattern_length; ++pos) { + // The prefix cannot be expanded, reset. + if (prefix_length >= 0 && pattern[pos] != pattern[prefix_length]) { + prefix_length = prefix_table[prefix_length]; + } + prefix_length++; + prefix_table[pos + 1] = prefix_length; + } + + // Phase 2: Find the prefix in the data + FirstTimeBitmapWriter bitmap_writer(output, output_offset, length); + for (int64_t i = 0; i < length; ++i) { + const uint8_t* current_data = data + offsets[i]; + int64_t current_length = offsets[i + 1] - offsets[i]; + + int64_t pattern_pos = 0; + for (int64_t k = 0; k < current_length; k++) { + if (pattern[pattern_pos] == current_data[k]) { + pattern_pos++; + if (pattern_pos == pattern_length) { + bitmap_writer.Set(); + break; + } + } else { + pattern_pos = std::max(0, prefix_table[pattern_pos]); + } + } + bitmap_writer.Next(); + } + bitmap_writer.Finish(); +} + +using BinaryContainsExactState = OptionsWrapper; + +template +struct BinaryContainsExact { + using offset_type = typename Type::offset_type; + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + BinaryContainsExactOptions arg = BinaryContainsExactState::Get(ctx); + const uint8_t* pat = reinterpret_cast(arg.pattern.c_str()); + const int64_t pat_size = arg.pattern.length(); + StringBoolTransform( + ctx, batch, + [pat, pat_size](const void* offsets, const uint8_t* data, int64_t length, + int64_t output_offset, uint8_t* output) { + TransformBinaryContainsExact( + pat, pat_size, reinterpret_cast(offsets), data, length, + output_offset, output); + }, + out); + } +}; + +void AddBinaryContainsExact(FunctionRegistry* registry) { + auto func = std::make_shared("binary_contains_exact", Arity::Unary()); + auto exec_32 = BinaryContainsExact::Exec; + auto exec_64 = BinaryContainsExact::Exec; + DCHECK_OK( + func->AddKernel({utf8()}, boolean(), exec_32, BinaryContainsExactState::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, boolean(), exec_64, + BinaryContainsExactState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + // ---------------------------------------------------------------------- // strptime string parsing @@ -377,6 +489,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringUtf8TransformKernel("utf8_lower", registry); #endif AddAsciiLength(registry); + AddBinaryContainsExact(registry); AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index fbfd2352a70..46ee129b03c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -28,7 +28,8 @@ namespace compute { constexpr auto kSeed = 0x94378165; -static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name) { +static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name, + const FunctionOptions* options = nullptr) { const int64_t array_length = 1 << 20; const int64_t value_min_size = 0; const int64_t value_max_size = 32; @@ -39,10 +40,10 @@ static void UnaryStringBenchmark(benchmark::State& state, const std::string& fun auto values = rng.String(array_length, value_min_size, value_max_size, null_probability); // Make sure lookup tables are initialized before measuring - ABORT_NOT_OK(CallFunction(func_name, {values})); + ABORT_NOT_OK(CallFunction(func_name, {values}, options)); for (auto _ : state) { - ABORT_NOT_OK(CallFunction(func_name, {values})); + ABORT_NOT_OK(CallFunction(func_name, {values}, options)); } state.SetItemsProcessed(state.iterations() * array_length); state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size()); @@ -56,6 +57,11 @@ static void AsciiUpper(benchmark::State& state) { UnaryStringBenchmark(state, "ascii_upper"); } +static void BinaryContainsExact(benchmark::State& state) { + BinaryContainsExactOptions options("abac"); + UnaryStringBenchmark(state, "binary_contains_exact", &options); +} + #ifdef ARROW_WITH_UTF8PROC static void Utf8Upper(benchmark::State& state) { UnaryStringBenchmark(state, "utf8_upper"); @@ -68,6 +74,7 @@ static void Utf8Lower(benchmark::State& state) { BENCHMARK(AsciiLower); BENCHMARK(AsciiUpper); +BENCHMARK(BinaryContainsExact); #ifdef ARROW_WITH_UTF8PROC BENCHMARK(Utf8Lower); BENCHMARK(Utf8Upper); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 9e5ba1ffaa7..0989401d034 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -147,6 +147,17 @@ TYPED_TEST(TestStringKernels, Utf8Lower) { #endif // ARROW_WITH_UTF8PROC +TYPED_TEST(TestStringKernels, BinaryContainsExact) { + BinaryContainsExactOptions options{"ab"}; + this->CheckUnary("binary_contains_exact", "[]", boolean(), "[]", &options); + this->CheckUnary("binary_contains_exact", R"(["abc", "acb", "cab", null, "bac"])", + boolean(), "[true, false, true, null, false]", &options); + + BinaryContainsExactOptions options_repeated{"abab"}; + this->CheckUnary("binary_contains_exact", R"(["abab", "ab", "cababc", null, "bac"])", + boolean(), "[true, false, true, null, false]", &options_repeated); +} + TYPED_TEST(TestStringKernels, Strptime) { std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0b8cbb5c955..201a59ec763 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -398,6 +398,18 @@ cdef class CastOptions(FunctionOptions): self.options.allow_invalid_utf8 = flag +cdef class BinaryContainsExactOptions(FunctionOptions): + cdef: + unique_ptr[CBinaryContainsExactOptions] binary_contains_exact_options + + def __init__(self, pattern): + self.binary_contains_exact_options.reset( + new CBinaryContainsExactOptions(tobytes(pattern))) + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.binary_contains_exact_options.get() + + cdef class FilterOptions(FunctionOptions): cdef: CFilterOptions filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 7a92a158812..babab27e58f 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -113,6 +113,24 @@ def func(left, right): multiply = _simple_binary_function('multiply') +def binary_contains_exact(array, pattern): + """ + Test if pattern is contained within a value of a binary array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("binary_contains_exact", [array], + _pc.BinaryContainsExactOptions(pattern)) + + def sum(array): """ Sum the values in a numerical (chunked) array. diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index c4d992aab25..e1e28e09da3 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1566,6 +1566,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CFunctionRegistry* GetFunctionRegistry() + cdef cppclass CBinaryContainsExactOptions \ + "arrow::compute::BinaryContainsExactOptions"(CFunctionOptions): + CBinaryContainsExactOptions(c_string pattern) + c_string pattern + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index cce46a8fe53..9854748e8bd 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -90,6 +90,13 @@ def test_sum_chunked_array(arrow_type): assert pc.sum(arr) == None # noqa: E711 +def test_binary_contains_exact(): + arr = pa.array(["ab", "abc", "ba", None]) + result = pc.binary_contains_exact(arr, "ab") + expected = pa.array([True, True, False, None]) + assert expected.equals(result) + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty)