Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,15 @@ Result<Datum> IsValid(const Datum& values, ExecContext* ctx = NULLPTR);
ARROW_EXPORT
Result<Datum> IsNull(const Datum& values, ExecContext* ctx = NULLPTR);

// ----------------------------------------------------------------------
// String functions

/// \brief Options for the "binary_contains_exact" compute function.
struct ARROW_EXPORT BinaryContainsExactOptions : public FunctionOptions {
  /// \param pattern the exact byte sequence to look for. It is taken by value
  /// and moved into the member, so passing a temporary costs no extra copy.
  explicit BinaryContainsExactOptions(std::string pattern)
      : pattern(std::move(pattern)) {}

  /// The exact substring to search for in each input value.
  std::string pattern;
};

// ----------------------------------------------------------------------
// Temporal functions

Expand Down
113 changes: 113 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,118 @@ void AddAsciiLength(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// ----------------------------------------------------------------------
// exact pattern detection

// Callback signature used by StringBoolTransform. The arguments are, in order:
// (offsets, data, length, output_offset, output). `offsets` is passed as a
// type-erased pointer because its element type depends on the string type's
// offset_type; the callee casts it back.
using StrToBoolTransformFunc =
    std::function<void(const void*, const uint8_t*, int64_t, int64_t, uint8_t*)>;

// Apply `transform` to the input strings and store the boolean results in `out`.
//
// - Array input: `transform` receives the offsets (already advanced by the
//   input array's offset), the character data, the number of values, the bit
//   offset at which to start writing and the output data buffer. `transform`
//   is not called for an empty array.
// - Scalar input: a valid scalar is presented to `transform` as a one-element
//   array writing into a single byte, and the resulting bit is wrapped in a
//   BooleanScalar. For a null scalar `out` is left untouched.
template <typename Type>
void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch,
                         StrToBoolTransformFunc transform, Datum* out) {
  using offset_type = typename Type::offset_type;

  if (batch[0].kind() == Datum::ARRAY) {
    const ArrayData& input = *batch[0].array();
    ArrayData* out_arr = out->mutable_array();
    if (input.length > 0) {
      transform(
          reinterpret_cast<const offset_type*>(input.buffers[1]->data()) + input.offset,
          input.buffers[2]->data(), input.length, out_arr->offset,
          out_arr->buffers[1]->mutable_data());
    }
  } else {
    const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
    if (input.is_valid) {
      // Scratch byte receiving the single result bit (bit 0).
      uint8_t result_value = 0;
      // Synthetic offsets describing the scalar as the only value in `data`.
      std::array<offset_type, 2> offsets{0,
                                         static_cast<offset_type>(input.value->size())};
      transform(offsets.data(), input.value->data(), 1, /*output_offset=*/0,
                &result_value);
      out->value = std::make_shared<BooleanScalar>(result_value > 0);
    }
  }
}

// For each of the `length` values delimited by `offsets` inside `data`, write
// one bit into `output` (starting at bit `output_offset`): the bit is set when
// the value contains `pattern` as a contiguous byte sequence.
//
// `offsets` must provide `length + 1` entries. An empty pattern is considered
// to be contained in every value, including empty ones.
template <typename offset_type>
void TransformBinaryContainsExact(const uint8_t* pattern, int64_t pattern_length,
                                  const offset_type* offsets, const uint8_t* data,
                                  int64_t length, int64_t output_offset,
                                  uint8_t* output) {
  // This is an implementation of the Knuth-Morris-Pratt algorithm

  // Phase 1: Build the prefix table
  // prefix_table[p] is the length of the longest proper prefix of
  // pattern[0, p) that is also a suffix of it; prefix_table[0] is the sentinel
  // -1 meaning "no shorter candidate left, consume the next input byte".
  std::vector<int64_t> prefix_table(pattern_length + 1);
  int64_t prefix_length = -1;
  prefix_table[0] = -1;
  for (int64_t pos = 0; pos < pattern_length; ++pos) {
    // The prefix cannot be expanded: fall back to successively shorter
    // candidates until one can be extended (or none is left). A single
    // fallback step is not enough, e.g. for the pattern "aab".
    while (prefix_length >= 0 && pattern[pos] != pattern[prefix_length]) {
      prefix_length = prefix_table[prefix_length];
    }
    prefix_length++;
    prefix_table[pos + 1] = prefix_length;
  }

  // Phase 2: Find the pattern in the data
  FirstTimeBitmapWriter bitmap_writer(output, output_offset, length);
  for (int64_t i = 0; i < length; ++i) {
    const uint8_t* current_data = data + offsets[i];
    const int64_t current_length = offsets[i + 1] - offsets[i];

    // The empty pattern trivially matches; this also avoids ever reading
    // pattern[0] when there is no pattern byte to compare against.
    bool found = (pattern_length == 0);
    int64_t pattern_pos = 0;
    for (int64_t k = 0; !found && k < current_length; ++k) {
      // On mismatch keep falling back while re-testing the *same* input byte,
      // so that e.g. "ab" is found in "aab".
      while (pattern_pos >= 0 && pattern[pattern_pos] != current_data[k]) {
        pattern_pos = prefix_table[pattern_pos];
      }
      pattern_pos++;
      found = (pattern_pos == pattern_length);
    }
    if (found) {
      bitmap_writer.Set();
    }
    bitmap_writer.Next();
  }
  bitmap_writer.Finish();
}

// Kernel state type for BinaryContainsExactOptions: its Init is registered
// with each kernel and its Get(ctx) gives the kernel access to the options.
using BinaryContainsExactState = OptionsWrapper<BinaryContainsExactOptions>;

template <typename Type>
struct BinaryContainsExact {
using offset_type = typename Type::offset_type;
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
BinaryContainsExactOptions arg = BinaryContainsExactState::Get(ctx);
const uint8_t* pat = reinterpret_cast<const uint8_t*>(arg.pattern.c_str());
const int64_t pat_size = arg.pattern.length();
StringBoolTransform<Type>(
ctx, batch,
[pat, pat_size](const void* offsets, const uint8_t* data, int64_t length,
int64_t output_offset, uint8_t* output) {
TransformBinaryContainsExact<offset_type>(
pat, pat_size, reinterpret_cast<const offset_type*>(offsets), data, length,
output_offset, output);
},
out);
}
};

// Register the unary "binary_contains_exact" scalar function, with one kernel
// for utf8 input and one for large_utf8 input, both producing booleans.
void AddBinaryContainsExact(FunctionRegistry* registry) {
  auto contains_func =
      std::make_shared<ScalarFunction>("binary_contains_exact", Arity::Unary());

  // 32-bit offsets
  DCHECK_OK(contains_func->AddKernel({utf8()}, boolean(),
                                     BinaryContainsExact<StringType>::Exec,
                                     BinaryContainsExactState::Init));
  // 64-bit offsets
  DCHECK_OK(contains_func->AddKernel({large_utf8()}, boolean(),
                                     BinaryContainsExact<LargeStringType>::Exec,
                                     BinaryContainsExactState::Init));

  DCHECK_OK(registry->AddFunction(std::move(contains_func)));
}

// ----------------------------------------------------------------------
// strptime string parsing

Expand Down Expand Up @@ -377,6 +489,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
MakeUnaryStringUtf8TransformKernel<Utf8Lower>("utf8_lower", registry);
#endif
AddAsciiLength(registry);
AddBinaryContainsExact(registry);
AddStrptime(registry);
}

Expand Down
13 changes: 10 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace compute {

constexpr auto kSeed = 0x94378165;

static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name) {
static void UnaryStringBenchmark(benchmark::State& state, const std::string& func_name,
const FunctionOptions* options = nullptr) {
const int64_t array_length = 1 << 20;
const int64_t value_min_size = 0;
const int64_t value_max_size = 32;
Expand All @@ -39,10 +40,10 @@ static void UnaryStringBenchmark(benchmark::State& state, const std::string& fun
auto values =
rng.String(array_length, value_min_size, value_max_size, null_probability);
// Make sure lookup tables are initialized before measuring
ABORT_NOT_OK(CallFunction(func_name, {values}));
ABORT_NOT_OK(CallFunction(func_name, {values}, options));

for (auto _ : state) {
ABORT_NOT_OK(CallFunction(func_name, {values}));
ABORT_NOT_OK(CallFunction(func_name, {values}, options));
}
state.SetItemsProcessed(state.iterations() * array_length);
state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size());
Expand All @@ -56,6 +57,11 @@ static void AsciiUpper(benchmark::State& state) {
UnaryStringBenchmark(state, "ascii_upper");
}

// Benchmark "binary_contains_exact" when searching for the pattern "abac".
static void BinaryContainsExact(benchmark::State& state) {
  const BinaryContainsExactOptions contains_options{"abac"};
  UnaryStringBenchmark(state, "binary_contains_exact", &contains_options);
}

#ifdef ARROW_WITH_UTF8PROC
static void Utf8Upper(benchmark::State& state) {
UnaryStringBenchmark(state, "utf8_upper");
Expand All @@ -68,6 +74,7 @@ static void Utf8Lower(benchmark::State& state) {

BENCHMARK(AsciiLower);
BENCHMARK(AsciiUpper);
BENCHMARK(BinaryContainsExact);
#ifdef ARROW_WITH_UTF8PROC
BENCHMARK(Utf8Lower);
BENCHMARK(Utf8Upper);
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ TYPED_TEST(TestStringKernels, Utf8Lower) {

#endif // ARROW_WITH_UTF8PROC

TYPED_TEST(TestStringKernels, BinaryContainsExact) {
  const std::string func_name = "binary_contains_exact";

  // Pattern without any internal repetition; also covers the empty array.
  BinaryContainsExactOptions simple_pattern{"ab"};
  this->CheckUnary(func_name, "[]", boolean(), "[]", &simple_pattern);
  this->CheckUnary(func_name, R"(["abc", "acb", "cab", null, "bac"])", boolean(),
                   "[true, false, true, null, false]", &simple_pattern);

  // Pattern whose prefix "ab" occurs again inside the pattern itself.
  BinaryContainsExactOptions repeated_pattern{"abab"};
  this->CheckUnary(func_name, R"(["abab", "ab", "cababc", null, "bac"])", boolean(),
                   "[true, false, true, null, false]", &repeated_pattern);
}

TYPED_TEST(TestStringKernels, Strptime) {
std::string input1 = R"(["5/1/2020", null, "12/11/1900"])";
std::string output1 = R"(["2020-05-01", null, "1900-12-11"])";
Expand Down
12 changes: 12 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,18 @@ cdef class CastOptions(FunctionOptions):
self.options.allow_invalid_utf8 = flag


cdef class BinaryContainsExactOptions(FunctionOptions):
    """
    Options for the "binary_contains_exact" compute function.

    Parameters
    ----------
    pattern
        Pattern to search for; it is converted with ``tobytes`` before being
        handed to the C++ options object.
    """
    cdef:
        # Owns the underlying C++ options object.
        unique_ptr[CBinaryContainsExactOptions] binary_contains_exact_options

    def __init__(self, pattern):
        # (Re)create the C++ options from the converted pattern.
        self.binary_contains_exact_options.reset(
            new CBinaryContainsExactOptions(tobytes(pattern)))

    cdef const CFunctionOptions* get_options(self) except NULL:
        # Non-owning pointer; the unique_ptr member keeps ownership.
        return self.binary_contains_exact_options.get()


cdef class FilterOptions(FunctionOptions):
cdef:
CFilterOptions filter_options
Expand Down
18 changes: 18 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ def func(left, right):
multiply = _simple_binary_function('multiply')


def binary_contains_exact(array, pattern):
    """
    Check, value by value, whether an array's elements contain a pattern.

    The pattern is matched exactly, as a contiguous substring.

    Parameters
    ----------
    array : pyarrow.Array or pyarrow.ChunkedArray
        Values to search in.
    pattern : str
        Pattern to look for (exact match).

    Returns
    -------
    result : pyarrow.Array or pyarrow.ChunkedArray
    """
    options = _pc.BinaryContainsExactOptions(pattern)
    return call_function("binary_contains_exact", [array], options)


def sum(array):
"""
Sum the values in a numerical (chunked) array.
Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:

CFunctionRegistry* GetFunctionRegistry()

# Cython declaration of the C++ arrow::compute::BinaryContainsExactOptions
# struct, exposed as a CFunctionOptions subclass.
cdef cppclass CBinaryContainsExactOptions \
        "arrow::compute::BinaryContainsExactOptions"(CFunctionOptions):
    # Constructor taking the pattern to search for.
    CBinaryContainsExactOptions(c_string pattern)
    # The pattern stored in the options object.
    c_string pattern

cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions):
CCastOptions()
CCastOptions(c_bool safe)
Expand Down
7 changes: 7 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def test_sum_chunked_array(arrow_type):
assert pc.sum(arr) == None # noqa: E711


def test_binary_contains_exact():
    # "ab" occurs in the first two values only; the null stays null.
    values = pa.array(["ab", "abc", "ba", None])
    expected = pa.array([True, True, False, None])

    actual = pc.binary_contains_exact(values, "ab")

    assert expected.equals(actual)


@pytest.mark.parametrize(('ty', 'values'), all_array_types)
def test_take(ty, values):
arr = pa.array(values, type=ty)
Expand Down