From 4cf15209f9be5f41295a45acefae2aca3e32726d Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Tue, 30 Jun 2020 14:29:30 +0200 Subject: [PATCH 1/8] ARROW-9160: [C++] Implement contains for exact matches --- cpp/src/arrow/compute/api_scalar.h | 9 ++ .../arrow/compute/kernels/scalar_string.cc | 94 +++++++++++++++++++ .../compute/kernels/scalar_string_test.cc | 7 ++ 3 files changed, 110 insertions(+) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index e4abd13abc2..c330774717e 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -259,6 +259,15 @@ Result IsValid(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result IsNull(const Datum& values, ExecContext* ctx = NULLPTR); +// ---------------------------------------------------------------------- +// String functions + +struct ARROW_EXPORT ContainsExactOptions : public FunctionOptions { + explicit ContainsExactOptions(std::string pattern) : pattern(pattern) {} + + std::string pattern; +}; + // ---------------------------------------------------------------------- // Temporal functions diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index b261743aab0..e257a2cc188 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -297,6 +297,99 @@ void AddAsciiLength(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(func))); } +// ---------------------------------------------------------------------- +// exact pattern detection + +template +using StrToBoolTransformFunc = + std::function; + +// Apply `transform` to input character data- this function cannot change the +// length +template +void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, + StrToBoolTransformFunc transform, + Datum* out) { + using ArrayType = typename TypeTraits::ArrayType; + using offset_type = typename Type::offset_type; + + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& input = *batch[0].array(); + ArrayType input_boxed(batch[0].array()); + + ArrayData* out_arr = out->mutable_array(); + + // Allocate space for output data + KERNEL_RETURN_IF_ERROR( + ctx, + ctx->Allocate(BitUtil::BytesForBits(input.length)).Value(&out_arr->buffers[1])); + if (input.length > 0) { + transform(reinterpret_cast(input.buffers[1]->data()), + input.buffers[2]->data(), input.length, + out_arr->buffers[1]->mutable_data()); + } + } else { + const auto& input = checked_cast(*batch[0].scalar()); + auto result = checked_pointer_cast(MakeNullScalar(out->type())); + uint8_t result_value = 0; + if (input.is_valid) { + result->is_valid = true; + KERNEL_RETURN_IF_ERROR(ctx, ctx->Allocate(1).Value(&result->value)); + std::array offsets{0, + static_cast(input.value->size())}; + transform(offsets.data(), input.value->data(), 1, &result_value); + out->value = std::make_shared(result_value > 0); + } + } +} + +template +void TransformContainsExact(const uint8_t* pattern, int64_t pattern_length, + const offset_type* offsets, const uint8_t* data, + int64_t length, uint8_t* output) { + FirstTimeBitmapWriter bitmap_writer(output, 0, length); + for (int64_t i = 0; i < length; ++i) { + int64_t current_length = offsets[i + 1] - offsets[i]; + + // Search for the pattern at every possible position + for (int64_t k = 0; k < (current_length - pattern_length + 1); k++) { + if (memcmp(pattern, data + offsets[i] + k, pattern_length) == 0) { + bitmap_writer.Set(); + break; + } + } + bitmap_writer.Next(); + } + bitmap_writer.Finish(); +} + +using ContainsExactState = OptionsWrapper; + +template +struct ContainsExact { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ContainsExactOptions arg = + checked_cast(*ctx->state()).options; + auto transform_func = + std::bind(TransformContainsExact, + reinterpret_cast(arg.pattern.c_str()), + arg.pattern.length(), std::placeholders::_1, std::placeholders::_2, + std::placeholders::_3, std::placeholders::_4); + + StringBoolTransform(ctx, batch, transform_func, out); + } +}; + +void AddContainsExact(FunctionRegistry* registry) { + auto func = std::make_shared("contains_exact", Arity::Unary()); + auto exec_32 = ContainsExact::Exec; + auto exec_64 = ContainsExact::Exec; + DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, ContainsExactState::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, boolean(), exec_64, ContainsExactState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + // ---------------------------------------------------------------------- // strptime string parsing @@ -377,6 +470,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringUtf8TransformKernel("utf8_lower", registry); #endif AddAsciiLength(registry); + AddContainsExact(registry); AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 9e5ba1ffaa7..f8021f434fd 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -147,6 +147,13 @@ TYPED_TEST(TestStringKernels, Utf8Lower) { #endif // ARROW_WITH_UTF8PROC +TYPED_TEST(TestStringKernels, ContainsExact) { + ContainsExactOptions options{"ab"}; + this->CheckUnary("contains_exact", "[]", boolean(), "[]", &options); + this->CheckUnary("contains_exact", R"(["abc", "acb", "cab", null, "bac"])", boolean(), + "[true, false, true, null, false]", &options); +} + TYPED_TEST(TestStringKernels, Strptime) { std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; From 9dfb2181a42ee09653a736599351e3a21e72d183 Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Tue, 30 Jun 2020 15:01:40 +0200 Subject: [PATCH 2/8] Fix offsetted arrays --- cpp/src/arrow/compute/kernels/scalar_string.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index e257a2cc188..d0a0c451697 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -310,12 +310,10 @@ template void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, StrToBoolTransformFunc transform, Datum* out) { - using ArrayType = typename TypeTraits::ArrayType; using offset_type = typename Type::offset_type; if (batch[0].kind() == Datum::ARRAY) { const ArrayData& input = *batch[0].array(); - ArrayType input_boxed(batch[0].array()); ArrayData* out_arr = out->mutable_array(); @@ -324,9 +322,9 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, ctx, ctx->Allocate(BitUtil::BytesForBits(input.length)).Value(&out_arr->buffers[1])); if (input.length > 0) { - transform(reinterpret_cast(input.buffers[1]->data()), - input.buffers[2]->data(), input.length, - out_arr->buffers[1]->mutable_data()); + transform( + reinterpret_cast(input.buffers[1]->data()) + input.offset, + input.buffers[2]->data(), input.length, out_arr->buffers[1]->mutable_data()); } } else { const auto& input = checked_cast(*batch[0].scalar()); From 1e31947b6c1459389924a257becb50c1b333353e Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Tue, 30 Jun 2020 19:30:47 +0200 Subject: [PATCH 3/8] Add benchmark, implement KMP --- .../arrow/compute/kernels/scalar_string.cc | 32 ++++++++++++++++--- .../kernels/scalar_string_benchmark.cc | 13 ++++++-- .../compute/kernels/scalar_string_test.cc | 4 +++ 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index d0a0c451697..dd16bc4ab2f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -345,15 +345,37 @@ template void TransformContainsExact(const uint8_t* pattern, int64_t pattern_length, const offset_type* offsets, const uint8_t* data, int64_t length, uint8_t* output) { + // This is an implementation of the Knuth-Morris-Pratt algorithm + + // Phase 1: Build the prefix table + std::vector prefix_table(pattern_length + 1); + offset_type prefix_length = -1; + prefix_table[0] = -1; + for (offset_type pos = 0; pos < pattern_length; ++pos) { + // The prefix cannot be expanded, reset. + if (prefix_length >= 0 && pattern[pos] != pattern[prefix_length]) { + prefix_length = prefix_table[prefix_length]; + } + prefix_length++; + prefix_table[pos + 1] = prefix_length; + } + + // Phase 2: Find the prefix in the data FirstTimeBitmapWriter bitmap_writer(output, 0, length); for (int64_t i = 0; i < length; ++i) { + const uint8_t* current_data = data + offsets[i]; int64_t current_length = offsets[i + 1] - offsets[i]; - // Search for the pattern at every possible position - for (int64_t k = 0; k < (current_length - pattern_length + 1); k++) { - if (memcmp(pattern, data + offsets[i] + k, pattern_length) == 0) { - bitmap_writer.Set(); - break; + int64_t pattern_pos = 0; + for (int64_t k = 0; k < current_length; k++) { + if (pattern[pattern_pos] == current_data[k]) { + pattern_pos++; + if (pattern_pos == pattern_length) { + bitmap_writer.Set(); + break; + } + } else { + pattern_pos = std::max(0, prefix_table[pattern_pos]); } } bitmap_writer.Next(); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index fbfd2352a70..a0842e65109 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -28,7 +28,8 @@ namespace compute { constexpr auto kSeed = 0x94378165; -static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name) { +static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name, + const FunctionOptions* options = nullptr) { const int64_t array_length = 1 << 20; const int64_t value_min_size = 0; const int64_t value_max_size = 32; @@ -39,10 +40,10 @@ static void UnaryStringBenchmark(benchmark::State& state, const std::string& fun auto values = rng.String(array_length, value_min_size, value_max_size, null_probability); // Make sure lookup tables are initialized before measuring - ABORT_NOT_OK(CallFunction(func_name, {values})); + ABORT_NOT_OK(CallFunction(func_name, {values}, options)); for (auto _ : state) { - ABORT_NOT_OK(CallFunction(func_name, {values})); + ABORT_NOT_OK(CallFunction(func_name, {values}, options)); } state.SetItemsProcessed(state.iterations() * array_length); state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size()); @@ -56,6 +57,11 @@ static void AsciiUpper(benchmark::State& state) { UnaryStringBenchmark(state, "ascii_upper"); } +static void ContainsExact(benchmark::State& state) { + ContainsExactOptions options("abac"); + UnaryStringBenchmark(state, "contains_exact", &options); +} + #ifdef ARROW_WITH_UTF8PROC static void Utf8Upper(benchmark::State& state) { UnaryStringBenchmark(state, "utf8_upper"); @@ -68,6 +74,7 @@ static void Utf8Lower(benchmark::State& state) { BENCHMARK(AsciiLower); BENCHMARK(AsciiUpper); +BENCHMARK(ContainsExact); #ifdef ARROW_WITH_UTF8PROC BENCHMARK(Utf8Lower); BENCHMARK(Utf8Upper); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index f8021f434fd..1053f525572 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -152,6 +152,10 @@ TYPED_TEST(TestStringKernels, ContainsExact) { this->CheckUnary("contains_exact", "[]", boolean(), "[]", &options); this->CheckUnary("contains_exact", R"(["abc", "acb", "cab", null, "bac"])", boolean(), "[true, false, true, null, false]", &options); + + ContainsExactOptions options_repeated{"abab"}; + this->CheckUnary("contains_exact", R"(["abab", "ab", "cababc", null, "bac"])", + boolean(), "[true, false, true, null, false]", &options_repeated); } TYPED_TEST(TestStringKernels, Strptime) { From 2a883791df960ff9cea163f68f743fb596044674 Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Tue, 30 Jun 2020 20:59:46 +0200 Subject: [PATCH 4/8] Expose kernel in Python --- cpp/src/arrow/compute/api_scalar.h | 2 +- python/pyarrow/_compute.pyx | 11 +++++++++++ python/pyarrow/compute.py | 18 ++++++++++++++++++ python/pyarrow/includes/libarrow.pxd | 5 +++++ python/pyarrow/tests/test_compute.py | 7 +++++++ 5 files changed, 42 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index c330774717e..f7917766ff2 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -263,7 +263,7 @@ Result IsNull(const Datum& values, ExecContext* ctx = NULLPTR); // String functions struct ARROW_EXPORT ContainsExactOptions : public FunctionOptions { - explicit ContainsExactOptions(std::string pattern) : pattern(pattern) {} + explicit ContainsExactOptions(std::string pattern = "") : pattern(pattern) {} std::string pattern; }; diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0b8cbb5c955..6ae1fed4c09 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -398,6 +398,17 @@ cdef class CastOptions(FunctionOptions): self.options.allow_invalid_utf8 = flag +cdef class ContainsExactOptions(FunctionOptions): + cdef: + CContainsExactOptions contains_exact_options + + def __init__(self, pattern): + self.contains_exact_options.pattern = tobytes(pattern) + + cdef const CFunctionOptions* get_options(self) except NULL: + return &self.contains_exact_options + + cdef class FilterOptions(FunctionOptions): cdef: CFilterOptions filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 7a92a158812..dc996f0e78a 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -113,6 +113,24 @@ def func(left, right): multiply = _simple_binary_function('multiply') +def contains_exact(array, pattern): + """ + Check whether a pattern occurs as part of the values of the array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("contains_exact", [array], + _pc.ContainsExactOptions(pattern)) + + def sum(array): """ Sum the values in a numerical (chunked) array. diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index c4d992aab25..0dd4368a4e1 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1582,6 +1582,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool allow_float_truncate c_bool allow_invalid_utf8 + cdef cppclass CContainsExactOptions \ + "arrow::compute::ContainsExactOptions"(CFunctionOptions): + ContainsExactOptions(c_string pattern) + c_string pattern + enum CFilterNullSelectionBehavior \ "arrow::compute::FilterOptions::NullSelectionBehavior": CFilterNullSelectionBehavior_DROP \ diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index cce46a8fe53..9230032a25a 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -90,6 +90,13 @@ def test_sum_chunked_array(arrow_type): assert pc.sum(arr) == None # noqa: E711 +def test_contains_exact(): + arr = pa.array(["ab", "abc", "ba", None]) + result = pc.contains_exact(arr, "ab") + expected = pa.array([True, True, False, None]) + assert expected.equals(result) + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty) From 2fa458b7da8f0577f8089caf34ff8447529bc8fd Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Tue, 30 Jun 2020 16:11:58 -0500 Subject: [PATCH 5/8] Do not allocate memory in kernel. Respect offset of output array --- .../arrow/compute/kernels/scalar_string.cc | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index dd16bc4ab2f..9e32705d52a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -300,31 +300,24 @@ void AddAsciiLength(FunctionRegistry* registry) { // ---------------------------------------------------------------------- // exact pattern detection -template using StrToBoolTransformFunc = - std::function; + std::function; // Apply `transform` to input character data- this function cannot change the // length template void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, - StrToBoolTransformFunc transform, - Datum* out) { + StrToBoolTransformFunc transform, Datum* out) { using offset_type = typename Type::offset_type; if (batch[0].kind() == Datum::ARRAY) { const ArrayData& input = *batch[0].array(); - ArrayData* out_arr = out->mutable_array(); - - // Allocate space for output data - KERNEL_RETURN_IF_ERROR( - ctx, - ctx->Allocate(BitUtil::BytesForBits(input.length)).Value(&out_arr->buffers[1])); if (input.length > 0) { transform( reinterpret_cast(input.buffers[1]->data()) + input.offset, - input.buffers[2]->data(), input.length, out_arr->buffers[1]->mutable_data()); + input.buffers[2]->data(), input.length, out_arr->offset, + out_arr->buffers[1]->mutable_data()); } } else { const auto& input = checked_cast(*batch[0].scalar()); @@ -332,10 +325,10 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, uint8_t result_value = 0; if (input.is_valid) { result->is_valid = true; - KERNEL_RETURN_IF_ERROR(ctx, ctx->Allocate(1).Value(&result->value)); std::array offsets{0, static_cast(input.value->size())}; - transform(offsets.data(), input.value->data(), 1, &result_value); + transform(offsets.data(), input.value->data(), 1, /*output_offset=*/0, + &result_value); out->value = std::make_shared(result_value > 0); } } @@ -344,7 +337,7 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, template void TransformContainsExact(const uint8_t* pattern, int64_t pattern_length, const offset_type* offsets, const uint8_t* data, - int64_t length, uint8_t* output) { + int64_t length, int64_t output_offset, uint8_t* output) { // This is an implementation of the Knuth-Morris-Pratt algorithm // Phase 1: Build the prefix table @@ -361,7 +354,7 @@ void TransformContainsExact(const uint8_t* pattern, int64_t pattern_length, } // Phase 2: Find the prefix in the data - FirstTimeBitmapWriter bitmap_writer(output, 0, length); + FirstTimeBitmapWriter bitmap_writer(output, output_offset, length); for (int64_t i = 0; i < length; ++i) { const uint8_t* current_data = data + offsets[i]; int64_t current_length = offsets[i + 1] - offsets[i]; @@ -387,16 +380,20 @@ using ContainsExactState = OptionsWrapper; template struct ContainsExact { + using offset_type = typename Type::offset_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ContainsExactOptions arg = - checked_cast(*ctx->state()).options; - auto transform_func = - std::bind(TransformContainsExact, - reinterpret_cast(arg.pattern.c_str()), - arg.pattern.length(), std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3, std::placeholders::_4); - - StringBoolTransform(ctx, batch, transform_func, out); + ContainsExactOptions arg = ContainsExactState::Get(ctx); + const uint8_t* pat = reinterpret_cast(arg.pattern.c_str()); + const int64_t pat_size = arg.pattern.length(); + StringBoolTransform( + ctx, batch, + [pat, pat_size](const void* offsets, const uint8_t* data, int64_t length, + int64_t output_offset, uint8_t* output) { + TransformContainsExact( + pat, pat_size, reinterpret_cast(offsets), data, length, + output_offset, output); + }, + out); } }; From 6f80852f4d6e8969d2d69b5a3833da7b18399feb Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Thu, 2 Jul 2020 18:22:42 +0200 Subject: [PATCH 6/8] Improve docstring, rename to binary_contains_exact --- cpp/src/arrow/compute/api_scalar.h | 4 +-- .../arrow/compute/kernels/scalar_string.cc | 30 ++++++++++--------- .../kernels/scalar_string_benchmark.cc | 6 ++-- .../compute/kernels/scalar_string_test.cc | 16 +++++----- python/pyarrow/_compute.pyx | 9 +++--- python/pyarrow/compute.py | 8 ++--- python/pyarrow/includes/libarrow.pxd | 10 +++---- python/pyarrow/tests/test_compute.py | 4 +-- 8 files changed, 45 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index f7917766ff2..d513173d76f 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -262,8 +262,8 @@ Result IsNull(const Datum& values, ExecContext* ctx = NULLPTR); // ---------------------------------------------------------------------- // String functions -struct ARROW_EXPORT ContainsExactOptions : public FunctionOptions { - explicit ContainsExactOptions(std::string pattern = "") : pattern(pattern) {} +struct ARROW_EXPORT BinaryContainsExactOptions : public FunctionOptions { + explicit BinaryContainsExactOptions(std::string pattern) : pattern(pattern) {} std::string pattern; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 9e32705d52a..c9124e17987 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -335,9 +335,10 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, } template -void TransformContainsExact(const uint8_t* pattern, int64_t pattern_length, - const offset_type* offsets, const uint8_t* data, - int64_t length, int64_t output_offset, uint8_t* output) { +void TransformBinaryContainsExact(const uint8_t* pattern, int64_t pattern_length, + const offset_type* offsets, const uint8_t* data, + int64_t length, int64_t output_offset, + uint8_t* output) { // This is an implementation of the Knuth-Morris-Pratt algorithm // Phase 1: Build the prefix table @@ -376,20 +377,20 @@ void TransformContainsExact(const uint8_t* pattern, int64_t pattern_length, bitmap_writer.Finish(); } -using ContainsExactState = OptionsWrapper; +using BinaryContainsExactState = OptionsWrapper; template -struct ContainsExact { +struct BinaryContainsExact { using offset_type = typename Type::offset_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ContainsExactOptions arg = ContainsExactState::Get(ctx); + BinaryContainsExactOptions arg = BinaryContainsExactState::Get(ctx); const uint8_t* pat = reinterpret_cast(arg.pattern.c_str()); const int64_t pat_size = arg.pattern.length(); StringBoolTransform( ctx, batch, [pat, pat_size](const void* offsets, const uint8_t* data, int64_t length, int64_t output_offset, uint8_t* output) { - TransformContainsExact( + TransformBinaryContainsExact( pat, pat_size, reinterpret_cast(offsets), data, length, output_offset, output); }, @@ -397,13 +398,14 @@ struct ContainsExact { } }; -void AddContainsExact(FunctionRegistry* registry) { - auto func = std::make_shared("contains_exact", Arity::Unary()); - auto exec_32 = ContainsExact::Exec; - auto exec_64 = ContainsExact::Exec; - DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, ContainsExactState::Init)); +void AddBinaryContainsExact(FunctionRegistry* registry) { + auto func = std::make_shared("binary_contains_exact", Arity::Unary()); + auto exec_32 = BinaryContainsExact::Exec; + auto exec_64 = BinaryContainsExact::Exec; DCHECK_OK( - func->AddKernel({large_utf8()}, boolean(), exec_64, ContainsExactState::Init)); + func->AddKernel({utf8()}, boolean(), exec_32, BinaryContainsExactState::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, boolean(), exec_64, + BinaryContainsExactState::Init)); DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -487,7 +489,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringUtf8TransformKernel("utf8_lower", registry); #endif AddAsciiLength(registry); - AddContainsExact(registry); + AddBinaryContainsExact(registry); AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index a0842e65109..dfd2e259783 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -57,8 +57,8 @@ static void AsciiUpper(benchmark::State& state) { UnaryStringBenchmark(state, "ascii_upper"); } -static void ContainsExact(benchmark::State& state) { - ContainsExactOptions options("abac"); +static void BinaryContainsExact(benchmark::State& state) { + BinaryContainsExactOptions options("abac"); UnaryStringBenchmark(state, "contains_exact", &options); } @@ -74,7 +74,7 @@ static void Utf8Lower(benchmark::State& state) { BENCHMARK(AsciiLower); BENCHMARK(AsciiUpper); -BENCHMARK(ContainsExact); +BENCHMARK(BinaryContainsExact); #ifdef ARROW_WITH_UTF8PROC BENCHMARK(Utf8Lower); BENCHMARK(Utf8Upper); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 1053f525572..0989401d034 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -147,14 +147,14 @@ TYPED_TEST(TestStringKernels, Utf8Lower) { #endif // ARROW_WITH_UTF8PROC -TYPED_TEST(TestStringKernels, ContainsExact) { - ContainsExactOptions options{"ab"}; - this->CheckUnary("contains_exact", "[]", boolean(), "[]", &options); - this->CheckUnary("contains_exact", R"(["abc", "acb", "cab", null, "bac"])", boolean(), - "[true, false, true, null, false]", &options); - - ContainsExactOptions options_repeated{"abab"}; - this->CheckUnary("contains_exact", R"(["abab", "ab", "cababc", null, "bac"])", +TYPED_TEST(TestStringKernels, BinaryContainsExact) { + BinaryContainsExactOptions options{"ab"}; + this->CheckUnary("binary_contains_exact", "[]", boolean(), "[]", &options); + this->CheckUnary("binary_contains_exact", R"(["abc", "acb", "cab", null, "bac"])", + boolean(), "[true, false, true, null, false]", &options); + + BinaryContainsExactOptions options_repeated{"abab"}; + this->CheckUnary("binary_contains_exact", R"(["abab", "ab", "cababc", null, "bac"])", boolean(), "[true, false, true, null, false]", &options_repeated); } diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 6ae1fed4c09..7268616c08b 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -398,15 +398,16 @@ cdef class CastOptions(FunctionOptions): self.options.allow_invalid_utf8 = flag -cdef class ContainsExactOptions(FunctionOptions): +cdef class BinaryContainsExactOptions(FunctionOptions): cdef: - CContainsExactOptions contains_exact_options + unique_ptr[CBinaryContainsExactOptions] binary_contains_exact_options def __init__(self, pattern): - self.contains_exact_options.pattern = tobytes(pattern) + self.binary_contains_exact_options.pattern.reset( + new CBinaryContainsExactOptions(tobytes(pattern))) cdef const CFunctionOptions* get_options(self) except NULL: - return &self.contains_exact_options + return &self.binary_contains_exact_options cdef class FilterOptions(FunctionOptions): diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index dc996f0e78a..babab27e58f 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -113,9 +113,9 @@ def func(left, right): multiply = _simple_binary_function('multiply') -def contains_exact(array, pattern): +def binary_contains_exact(array, pattern): """ - Check whether a pattern occurs as part of the values of the array. + Test if pattern is contained within a value of a binary array. Parameters ---------- @@ -127,8 +127,8 @@ def contains_exact(array, pattern): ------- result : pyarrow.Array or pyarrow.ChunkedArray """ - return call_function("contains_exact", [array], - _pc.ContainsExactOptions(pattern)) + return call_function("binary_contains_exact", [array], + _pc.BinaryContainsExactOptions(pattern)) def sum(array): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0dd4368a4e1..dd44c30551b 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1566,6 +1566,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CFunctionRegistry* GetFunctionRegistry() + cdef cppclass CBinaryContainsExactOptions \ + "arrow::compute::BinaryContainsExactOptions"(CFunctionOptions): + BinaryContainsExactOptions(c_string pattern) + c_string pattern + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) @@ -1582,11 +1587,6 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool allow_float_truncate c_bool allow_invalid_utf8 - cdef cppclass CContainsExactOptions \ - "arrow::compute::ContainsExactOptions"(CFunctionOptions): - ContainsExactOptions(c_string pattern) - c_string pattern - enum CFilterNullSelectionBehavior \ "arrow::compute::FilterOptions::NullSelectionBehavior": CFilterNullSelectionBehavior_DROP \ diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 9230032a25a..9854748e8bd 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -90,9 +90,9 @@ def test_sum_chunked_array(arrow_type): assert pc.sum(arr) == None # noqa: E711 -def test_contains_exact(): +def test_binary_contains_exact(): arr = pa.array(["ab", "abc", "ba", None]) - result = pc.contains_exact(arr, "ab") + result = pc.binary_contains_exact(arr, "ab") expected = pa.array([True, True, False, None]) assert expected.equals(result) From 34d77a99781df9f500b1d1b553b08085d931849b Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Thu, 2 Jul 2020 18:25:05 +0200 Subject: [PATCH 7/8] Make result_value more local in scope --- cpp/src/arrow/compute/kernels/scalar_string.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index c9124e17987..b7d2fee1a3a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -321,9 +321,9 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, } } else { const auto& input = checked_cast(*batch[0].scalar()); - auto result = checked_pointer_cast(MakeNullScalar(out->type())); - uint8_t result_value = 0; if (input.is_valid) { + auto result = checked_pointer_cast(MakeNullScalar(out->type())); + uint8_t result_value = 0; result->is_valid = true; std::array offsets{0, static_cast(input.value->size())}; From 4e587adc1e5adda9201f5e5ffa10416b6e8d9007 Mon Sep 17 00:00:00 2001 From: "Uwe L. Korn" Date: Thu, 2 Jul 2020 18:47:43 +0200 Subject: [PATCH 8/8] Fix cython compilation Follow function rename in benchmarks --- cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc | 2 +- python/pyarrow/_compute.pyx | 4 ++-- python/pyarrow/includes/libarrow.pxd | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index dfd2e259783..46ee129b03c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -59,7 +59,7 @@ static void AsciiUpper(benchmark::State& state) { static void BinaryContainsExact(benchmark::State& state) { BinaryContainsExactOptions options("abac"); - UnaryStringBenchmark(state, "contains_exact", &options); + UnaryStringBenchmark(state, "binary_contains_exact", &options); } #ifdef ARROW_WITH_UTF8PROC diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 7268616c08b..201a59ec763 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -403,11 +403,11 @@ cdef class BinaryContainsExactOptions(FunctionOptions): unique_ptr[CBinaryContainsExactOptions] binary_contains_exact_options def __init__(self, pattern): - self.binary_contains_exact_options.pattern.reset( + self.binary_contains_exact_options.reset( new CBinaryContainsExactOptions(tobytes(pattern))) cdef const CFunctionOptions* get_options(self) except NULL: - return &self.binary_contains_exact_options + return self.binary_contains_exact_options.get() cdef class FilterOptions(FunctionOptions): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index dd44c30551b..e1e28e09da3 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1568,7 +1568,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CBinaryContainsExactOptions \ "arrow::compute::BinaryContainsExactOptions"(CFunctionOptions): - BinaryContainsExactOptions(c_string pattern) + CBinaryContainsExactOptions(c_string pattern) c_string pattern cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions):