From fb42b30b510d0c0ed4be1ee8016a606ae1a9ffc1 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 8 Jun 2021 16:14:43 +0200 Subject: [PATCH] ARROW-12951: [C++] Reduce generated code size for string kernels Factor out type-agnostic string operations (such as finding a split pattern) in separate classes to avoid generating several versions of them when generating the typed kernel execution classes. --- .../compute/kernels/scalar_arithmetic_test.cc | 13 - .../arrow/compute/kernels/scalar_string.cc | 1003 +++++++++-------- .../compute/kernels/scalar_string_test.cc | 5 +- cpp/src/arrow/testing/gtest_util.cc | 13 + cpp/src/arrow/testing/gtest_util.h | 8 + 5 files changed, 532 insertions(+), 510 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index ff66fcf1d12..c4bfac459dc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -41,19 +41,6 @@ namespace arrow { namespace compute { -std::shared_ptr TweakValidityBit(const std::shared_ptr& array, - int64_t index, bool validity) { - auto data = array->data()->Copy(); - if (data->buffers[0] == nullptr) { - data->buffers[0] = *AllocateBitmap(data->length); - BitUtil::SetBitsTo(data->buffers[0]->mutable_data(), 0, data->length, true); - } - BitUtil::SetBitTo(data->buffers[0]->mutable_data(), index, validity); - data->null_count = kUnknownNullCount; - // Need to return a new array, because Array caches the null bitmap pointer - return MakeArray(data); -} - template class TestUnaryArithmetic : public TestBase { protected: diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index a1e19b608d9..8b740f3742a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -38,6 +38,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/utf8.h" 
#include "arrow/util/value_parsing.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -130,117 +131,170 @@ void EnsureLookupTablesFilled() { }); } +#else + +void EnsureLookupTablesFilled() {} + #endif // ARROW_WITH_UTF8PROC -/// Transform string -> string with a reasonable guess on the maximum number of codepoints -template -struct StringTransform { - using offset_type = typename Type::offset_type; - using ArrayType = typename TypeTraits::ArrayType; +constexpr int64_t kTransformError = -1; - virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) { - return input_ncodeunits; +struct StringTransformBase { + virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); } - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return Derived().Execute(ctx, batch, out); + // Return the maximum total size of the output in codeunits (i.e. bytes) + // given input characteristics. + virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) { + return input_ncodeunits; } - static Status InvalidStatus() { + virtual Status InvalidStatus() { return Status::Invalid("Invalid UTF8 sequence in input"); } - Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Derived classes should also define this method: + // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + // uint8_t* output); +}; + +template +struct StringTransformExecBase { + using offset_type = typename Type::offset_type; + using ArrayType = typename TypeTraits::ArrayType; + + static Status Execute(KernelContext* ctx, StringTransform* transform, + const ExecBatch& batch, Datum* out) { if (batch[0].kind() == Datum::ARRAY) { - const ArrayData& input = *batch[0].array(); - ArrayType input_boxed(batch[0].array()); - ArrayData* output = out->mutable_array(); + return ExecArray(ctx, transform, batch[0].array(), out); + } + DCHECK_EQ(batch[0].kind(), Datum::SCALAR); + return 
ExecScalar(ctx, transform, batch[0].scalar(), out); + } - offset_type input_ncodeunits = input_boxed.total_values_length(); - offset_type input_nstrings = static_cast(input.length); + static Status ExecArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data, Datum* out) { + ArrayType input(data); + ArrayData* output = out->mutable_array(); - const int64_t output_ncodeunits_max = - MaxCodeunits(input_nstrings, input_ncodeunits); - if (output_ncodeunits_max > std::numeric_limits::max()) { - return Status::CapacityError( - "Result might not fit in a 32bit utf8 array, convert to large_utf8"); - } + const int64_t input_ncodeunits = input.total_values_length(); + const int64_t input_nstrings = input.length(); - ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); - output->buffers[2] = values_buffer; + const int64_t output_ncodeunits_max = + transform->MaxCodeunits(input_nstrings, input_ncodeunits); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); + output->buffers[2] = values_buffer; - // String offsets are preallocated - offset_type* output_string_offsets = output->GetMutableValues(1); - uint8_t* output_str = output->buffers[2]->mutable_data(); - offset_type output_ncodeunits = 0; + // String offsets are preallocated + offset_type* output_string_offsets = output->GetMutableValues(1); + uint8_t* output_str = output->buffers[2]->mutable_data(); + offset_type output_ncodeunits = 0; - output_string_offsets[0] = 0; - for (int64_t i = 0; i < input_nstrings; i++) { + output_string_offsets[0] = 0; + for (int64_t i = 0; i < input_nstrings; i++) { + if (!input.IsNull(i)) { offset_type input_string_ncodeunits; - const uint8_t* input_string = input_boxed.GetValue(i, &input_string_ncodeunits); - offset_type encoded_nbytes = 0; 
- if (ARROW_PREDICT_FALSE(!static_cast(*this).Transform( - input_string, input_string_ncodeunits, output_str + output_ncodeunits, - &encoded_nbytes))) { - return Derived::InvalidStatus(); + const uint8_t* input_string = input.GetValue(i, &input_string_ncodeunits); + auto encoded_nbytes = static_cast(transform->Transform( + input_string, input_string_ncodeunits, output_str + output_ncodeunits)); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); } output_ncodeunits += encoded_nbytes; - output_string_offsets[i + 1] = output_ncodeunits; } - DCHECK_LE(output_ncodeunits, output_ncodeunits_max); + output_string_offsets[i + 1] = output_ncodeunits; + } + DCHECK_LE(output_ncodeunits, output_ncodeunits_max); - // Trim the codepoint buffer, since we allocated too much - return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); - } else { - DCHECK_EQ(batch[0].kind(), Datum::SCALAR); - const auto& input = checked_cast(*batch[0].scalar()); - if (!input.is_valid) { - return Status::OK(); - } - auto* result = checked_cast(out->scalar().get()); - result->is_valid = true; - offset_type data_nbytes = static_cast(input.value->size()); + // Trim the codepoint buffer, since we allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + } - int64_t output_ncodeunits_max = MaxCodeunits(1, data_nbytes); - if (output_ncodeunits_max > std::numeric_limits::max()) { - return Status::CapacityError( - "Result might not fit in a 32bit utf8 array, convert to large_utf8"); - } - ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); - result->value = value_buffer; - offset_type encoded_nbytes = 0; - if (ARROW_PREDICT_FALSE(!static_cast(*this).Transform( - input.value->data(), data_nbytes, value_buffer->mutable_data(), - &encoded_nbytes))) { - return Derived::InvalidStatus(); - } - DCHECK_LE(encoded_nbytes, output_ncodeunits_max); - return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); + static 
Status ExecScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar, Datum* out) { + const auto& input = checked_cast(*scalar); + if (!input.is_valid) { + return Status::OK(); + } + auto* result = checked_cast(out->scalar().get()); + result->is_valid = true; + const int64_t data_nbytes = static_cast(input.value->size()); + + const int64_t output_ncodeunits_max = transform->MaxCodeunits(1, data_nbytes); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); + result->value = value_buffer; + auto encoded_nbytes = static_cast(transform->Transform( + input.value->data(), data_nbytes, value_buffer->mutable_data())); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); } + DCHECK_LE(encoded_nbytes, output_ncodeunits_max); + return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); + } +}; + +template +struct StringTransformExec : public StringTransformExecBase { + using StringTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform; + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + +template +struct StringTransformExecWithState + : public StringTransformExecBase { + using State = typename StringTransform::State; + using StringTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform(State::Get(ctx)); + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); } }; #ifdef ARROW_WITH_UTF8PROC -// transforms per codepoint -template -struct StringTransformCodepoint : StringTransform { - using Base = StringTransform; - using offset_type = typename Base::offset_type; +template 
+struct StringTransformCodepoint : public StringTransformBase { + Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override { + EnsureLookupTablesFilled(); + return Status::OK(); + } - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return CodepointTransform::MaxCodeunits(ninputs, input_ncodeunits); + } + + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { uint8_t* output_start = output; if (ARROW_PREDICT_FALSE( !arrow::util::UTF8Transform(input, input + input_string_ncodeunits, &output, - Derived::TransformCodepoint))) { - return false; + CodepointTransform::TransformCodepoint))) { + return kTransformError; } - *output_written = static_cast(output - output_start); - return true; + return output - output_start; } +}; - int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { +// struct CaseMappingMixin { +struct CaseMappingTransform { + static int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) { // Section 5.18 of the Unicode spec claim that the number of codepoints for case // mapping can grow by a factor of 3. This means grow by a factor of 3 in bytes // However, since we don't support all casings (SpecialCasing.txt) the growth @@ -249,74 +303,67 @@ struct StringTransformCodepoint : StringTransform { // two code units (even) can grow to 3 code units. 
return static_cast(input_ncodeunits) * 3 / 2; } - - Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - EnsureLookupTablesFilled(); - return Base::Execute(ctx, batch, out); - } }; -template -struct UTF8Upper : StringTransformCodepoint> { - inline static uint32_t TransformCodepoint(uint32_t codepoint) { +struct UTF8UpperTransform : public CaseMappingTransform { + static uint32_t TransformCodepoint(uint32_t codepoint) { return codepoint <= kMaxCodepointLookup ? lut_upper_codepoint[codepoint] : utf8proc_toupper(codepoint); } }; template -struct UTF8Lower : StringTransformCodepoint> { - inline static uint32_t TransformCodepoint(uint32_t codepoint) { +using UTF8Upper = StringTransformExec>; + +struct UTF8LowerTransform : public CaseMappingTransform { + static uint32_t TransformCodepoint(uint32_t codepoint) { return codepoint <= kMaxCodepointLookup ? lut_lower_codepoint[codepoint] : utf8proc_tolower(codepoint); } }; -#else - -void EnsureLookupTablesFilled() {} +template +using UTF8Lower = StringTransformExec>; #endif // ARROW_WITH_UTF8PROC -template -struct AsciiReverse : StringTransform> { - using Base = StringTransform>; - using offset_type = typename Base::offset_type; - - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { +struct AsciiReverseTransform : public StringTransformBase { + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { uint8_t utf8_char_found = 0; - for (offset_type i = 0; i < input_string_ncodeunits; i++) { + for (int64_t i = 0; i < input_string_ncodeunits; i++) { // if a utf8 char is found, report to utf8_char_found utf8_char_found |= input[i] & 0x80; output[input_string_ncodeunits - i - 1] = input[i]; } - *output_written = input_string_ncodeunits; - return utf8_char_found == 0; + return utf8_char_found ? 
kTransformError : input_string_ncodeunits; } - static Status InvalidStatus() { return Status::Invalid("Non-ASCII sequence in input"); } + Status InvalidStatus() override { + return Status::Invalid("Non-ASCII sequence in input"); + } }; template -struct Utf8Reverse : StringTransform> { - using Base = StringTransform>; - using offset_type = typename Base::offset_type; +using AsciiReverse = StringTransformExec; - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { - offset_type i = 0; +struct Utf8ReverseTransform : public StringTransformBase { + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + int64_t i = 0; while (i < input_string_ncodeunits) { - uint8_t offset = util::ValidUtf8CodepointByteSize(input + i); - offset_type stride = std::min(i + offset, input_string_ncodeunits); - std::copy(input + i, input + stride, output + input_string_ncodeunits - stride); - i += offset; + int64_t char_end = std::min(i + util::ValidUtf8CodepointByteSize(input + i), + input_string_ncodeunits); + std::copy(input + i, input + char_end, output + input_string_ncodeunits - char_end); + i = char_end; } - *output_written = input_string_ncodeunits; - return true; + return input_string_ncodeunits; } }; +template +using Utf8Reverse = StringTransformExec; + using TransformFunc = std::function; // Transform a buffer of offsets to one which begins with 0 and has same @@ -973,187 +1020,182 @@ void AddCountSubstring(FunctionRegistry* registry) { // Slicing -template -struct SliceBase : StringTransform { - using Base = StringTransform; - using offset_type = typename Base::offset_type; +struct SliceTransformBase : public StringTransformBase { using State = OptionsWrapper; - SliceOptions options; - - explicit SliceBase(SliceOptions options) : options(options) {} + const SliceOptions* options; - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - SliceOptions 
options = State::Get(ctx); - if (options.step == 0) { + Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override { + options = &State::Get(ctx); + if (options->step == 0) { return Status::Invalid("Slice step cannot be zero"); } - return Derived(options).Execute(ctx, batch, out); + return Status::OK(); } }; -#define PROPAGATE_FALSE(expr) \ +struct SliceCodeunitsTransform : SliceTransformBase { + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + const SliceOptions& opt = *this->options; + if ((opt.start >= 0) != (opt.stop >= 0)) { + // If start and stop don't have the same sign, we can't guess an upper bound + // on the resulting slice lengths, so return a worst case estimate. + return input_ncodeunits; + } + int64_t max_slice_codepoints = (opt.stop - opt.start + opt.step - 1) / opt.step; + // The maximum UTF8 byte size of a codepoint is 4 + return std::min(input_ncodeunits, + 4 * ninputs * std::max(0, max_slice_codepoints)); + } + + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + if (options->step >= 1) { + return SliceForward(input, input_string_ncodeunits, output); + } + return SliceBackward(input, input_string_ncodeunits, output); + } + +#define RETURN_IF_UTF8_ERROR(expr) \ do { \ if (ARROW_PREDICT_FALSE(!expr)) { \ - return false; \ + return kTransformError; \ } \ } while (0) -bool SliceCodeunitsTransform(const uint8_t* input, int64_t input_string_ncodeunits, - uint8_t* output, int64_t* output_written, - const SliceOptions& options) { - const uint8_t* begin = input; - const uint8_t* end = input + input_string_ncodeunits; - const uint8_t* begin_sliced = begin; - const uint8_t* end_sliced = end; + int64_t SliceForward(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + // Slice in forward order (step > 0) + const SliceOptions& opt = *this->options; + const uint8_t* begin = input; + const uint8_t* end = input + input_string_ncodeunits; + const 
uint8_t* begin_sliced = begin; + const uint8_t* end_sliced = end; - if (options.step >= 1) { - if (options.start >= 0) { + // First, compute begin_sliced and end_sliced + if (opt.start >= 0) { // start counting from the left - PROPAGATE_FALSE( - arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, options.start)); - if (options.stop > options.start) { + RETURN_IF_UTF8_ERROR( + arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opt.start)); + if (opt.stop > opt.start) { // continue counting from begin_sliced - int64_t length = options.stop - options.start; - PROPAGATE_FALSE( + const int64_t length = opt.stop - opt.start; + RETURN_IF_UTF8_ERROR( arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)); - } else if (options.stop < 0) { + } else if (opt.stop < 0) { // or from the end (but we will never need to < begin_sliced) - PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse( - begin_sliced, end, &end_sliced, -options.stop)); + RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( + begin_sliced, end, &end_sliced, -opt.stop)); } else { // zero length slice - *output_written = 0; - return true; + return 0; } } else { // start counting from the right - PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced, - -options.start)); - if (options.stop > 0) { + RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( + begin, end, &begin_sliced, -opt.start)); + if (opt.stop > 0) { // continue counting from the left, we cannot start from begin_sliced because we // don't know how many codepoints are between begin and begin_sliced - PROPAGATE_FALSE( - arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, options.stop)); + RETURN_IF_UTF8_ERROR( + arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opt.stop)); // and therefore we also needs this if (end_sliced <= begin_sliced) { // zero length slice - *output_written = 0; - return true; + return 0; } - } else if 
((options.stop < 0) && (options.stop > options.start)) { + } else if ((opt.stop < 0) && (opt.stop > opt.start)) { // stop is negative, but larger than start, so we count again from the right // in some cases we can optimize this, depending on the shortest path (from end - // or begin_sliced), but begin_sliced and options.start can be 'out of sync', + // or begin_sliced), but begin_sliced and opt.start can be 'out of sync', // for instance when start=-100, when the string length is only 10. - PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse( - begin_sliced, end, &end_sliced, -options.stop)); + RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( + begin_sliced, end, &end_sliced, -opt.stop)); } else { // zero length slice - *output_written = 0; - return true; + return 0; } } + + // Second, copy computed slice to output DCHECK(begin_sliced <= end_sliced); - if (options.step == 1) { + if (opt.step == 1) { // fast case, where we simply can finish with a memcpy std::copy(begin_sliced, end_sliced, output); - *output_written = end_sliced - begin_sliced; - } else { - uint8_t* dest = output; - const uint8_t* i = begin_sliced; - - while (i < end_sliced) { - uint32_t codepoint = 0; - // write a single codepoint - PROPAGATE_FALSE(arrow::util::UTF8Decode(&i, &codepoint)); - dest = arrow::util::UTF8Encode(dest, codepoint); - // and skip the remainder - int64_t skips = options.step - 1; - while ((skips--) && (i < end_sliced)) { - PROPAGATE_FALSE(arrow::util::UTF8Decode(&i, &codepoint)); - } + return end_sliced - begin_sliced; + } + uint8_t* dest = output; + const uint8_t* i = begin_sliced; + + while (i < end_sliced) { + uint32_t codepoint = 0; + // write a single codepoint + RETURN_IF_UTF8_ERROR(arrow::util::UTF8Decode(&i, &codepoint)); + dest = arrow::util::UTF8Encode(dest, codepoint); + // and skip the remainder + int64_t skips = opt.step - 1; + while ((skips--) && (i < end_sliced)) { + RETURN_IF_UTF8_ERROR(arrow::util::UTF8Decode(&i, &codepoint)); } - 
*output_written = dest - output; } - return true; - } else { // step < 0 - // serious +1 -1 kung fu because now begin_slice and end_slice act like reverse - // iterators. + return dest - output; + } - if (options.start >= 0) { + int64_t SliceBackward(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + // Slice in reverse order (step < 0) + const SliceOptions& opt = *this->options; + const uint8_t* begin = input; + const uint8_t* end = input + input_string_ncodeunits; + const uint8_t* begin_sliced = begin; + const uint8_t* end_sliced = end; + + // Serious +1 -1 kung fu because begin_sliced and end_sliced act like + // reverse iterators. + if (opt.start >= 0) { // +1 because begin_sliced acts as as the end of a reverse iterator - PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, - options.start + 1)); - // and make it point at the last codeunit of the previous codeunit - begin_sliced--; + RETURN_IF_UTF8_ERROR( + arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opt.start + 1)); } else { // -1 because start=-1 means the last codeunit, which is 0 advances - PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced, - -options.start - 1)); - // and make it point at the last codeunit of the previous codeunit - begin_sliced--; - } - // similar to options.start - if (options.stop >= 0) { - PROPAGATE_FALSE( - arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, options.stop + 1)); - end_sliced--; + RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( + begin, end, &begin_sliced, -opt.start - 1)); + } + // make it point at the last codeunit of the previous codepoint + begin_sliced--; + + // similar to opt.start + if (opt.stop >= 0) { + RETURN_IF_UTF8_ERROR( + arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opt.stop + 1)); } else { - PROPAGATE_FALSE(arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &end_sliced, - -options.stop - 1)); - end_sliced--; + 
RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( + begin, end, &end_sliced, -opt.stop - 1)); } + end_sliced--; + // Copy computed slice to output uint8_t* dest = output; const uint8_t* i = begin_sliced; - while (i > end_sliced) { uint32_t codepoint = 0; // write a single codepoint - PROPAGATE_FALSE(arrow::util::UTF8DecodeReverse(&i, &codepoint)); + RETURN_IF_UTF8_ERROR(arrow::util::UTF8DecodeReverse(&i, &codepoint)); dest = arrow::util::UTF8Encode(dest, codepoint); // and skip the remainder - int64_t skips = -options.step - 1; + int64_t skips = -opt.step - 1; while ((skips--) && (i > end_sliced)) { - PROPAGATE_FALSE(arrow::util::UTF8DecodeReverse(&i, &codepoint)); + RETURN_IF_UTF8_ERROR(arrow::util::UTF8DecodeReverse(&i, &codepoint)); } } - *output_written = dest - output; - return true; + return dest - output; } -} -#undef PROPAGATE_FALSE +#undef RETURN_IF_UTF8_ERROR +}; template -struct SliceCodeunits : SliceBase> { - using Base = SliceBase>; - using offset_type = typename Base::offset_type; - using Base::Base; - - int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { - const SliceOptions& opt = this->options; - if ((opt.start >= 0) != (opt.stop >= 0)) { - // If start and stop don't have the same sign, we can't guess an upper bound - // on the resulting slice lengths, so return a worst case estimate. 
- return input_ncodeunits; - } - int64_t max_slice_codepoints = (opt.stop - opt.start + opt.step - 1) / opt.step; - // The maximum UTF8 byte size of a codepoint is 4 - return std::min(input_ncodeunits, - 4 * ninputs * std::max(0, max_slice_codepoints)); - } - - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { - int64_t output_written_64; - bool res = SliceCodeunitsTransform(input, input_string_ncodeunits, output, - &output_written_64, this->options); - *output_written = static_cast(output_written_64); - return res; - } -}; +using SliceCodeunits = StringTransformExec; const FunctionDoc utf8_slice_codeunits_doc( "Slice string ", @@ -1170,10 +1212,13 @@ void AddSlice(FunctionRegistry* registry) { &utf8_slice_codeunits_doc); using t32 = SliceCodeunits; using t64 = SliceCodeunits; - DCHECK_OK(func->AddKernel({utf8()}, utf8(), t32::Exec, t32::State::Init)); - DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), t64::Exec, t64::State::Init)); + DCHECK_OK( + func->AddKernel({utf8()}, utf8(), t32::Exec, SliceCodeunitsTransform::State::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), t64::Exec, + SliceCodeunitsTransform::State::Init)); DCHECK_OK(registry->AddFunction(std::move(func))); } + // IsAlpha/Digit etc #ifdef ARROW_WITH_UTF8PROC @@ -1583,8 +1628,25 @@ struct IsUpperAscii : CharacterPredicateAscii { // splitting -template -struct SplitBaseTransform { +template +struct SplitFinderBase { + virtual Status PreExec(const Options& options) { return Status::OK(); } + + // Derived classes should also define these methods: + // static bool Find(const uint8_t* begin, const uint8_t* end, + // const uint8_t** separator_begin, + // const uint8_t** separator_end, + // const SplitPatternOptions& options); + // + // static bool FindReverse(const uint8_t* begin, const uint8_t* end, + // const uint8_t** separator_begin, + // const uint8_t** separator_end, + // const SplitPatternOptions& 
options); +}; + +template +struct SplitExec { using string_offset_type = typename Type::offset_type; using list_offset_type = typename ListType::offset_type; using ArrayType = typename TypeTraits::ArrayType; @@ -1595,12 +1657,75 @@ struct SplitBaseTransform { using ListOffsetsBuilderType = TypedBufferBuilder; using State = OptionsWrapper; + // Keep the temporary storage across individual values, to minimize reallocations std::vector parts; Options options; - explicit SplitBaseTransform(Options options) : options(options) {} + explicit SplitExec(const Options& options) : options(options) {} + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return SplitExec{State::Get(ctx)}.Execute(ctx, batch, out); + } + + Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + SplitFinder finder; + RETURN_NOT_OK(finder.PreExec(options)); + if (batch[0].kind() == Datum::ARRAY) { + return Execute(ctx, &finder, batch[0].array(), out); + } + DCHECK_EQ(batch[0].kind(), Datum::SCALAR); + return Execute(ctx, &finder, batch[0].scalar(), out); + } + + Status Execute(KernelContext* ctx, SplitFinder* finder, + const std::shared_ptr& data, Datum* out) { + const ArrayType input(data); + + BuilderType builder(input.type(), ctx->memory_pool()); + // A slight overestimate of the data needed + RETURN_NOT_OK(builder.ReserveData(input.total_values_length())); + // The minimum amount of strings needed + RETURN_NOT_OK(builder.Resize(input.length() - input.null_count())); + + ArrayData* output_list = out->mutable_array(); + // List offsets were preallocated + auto* list_offsets = output_list->GetMutableValues(1); + DCHECK_NE(list_offsets, nullptr); + // Initial value + *list_offsets++ = 0; + for (int64_t i = 0; i < input.length(); ++i) { + if (!input.IsNull(i)) { + RETURN_NOT_OK(SplitString(input.GetView(i), finder, &builder)); + if (ARROW_PREDICT_FALSE(builder.length() > + std::numeric_limits::max())) { + return Status::CapacityError("List offset 
does not fit into 32 bit"); + } + } + *list_offsets++ = static_cast(builder.length()); + } + // Assign string array to list child data + std::shared_ptr string_array; + RETURN_NOT_OK(builder.Finish(&string_array)); + output_list->child_data.push_back(string_array->data()); + return Status::OK(); + } + + Status Execute(KernelContext* ctx, SplitFinder* finder, + const std::shared_ptr& scalar, Datum* out) { + const auto& input = checked_cast(*scalar); + auto result = checked_cast(out->scalar().get()); + if (input.is_valid) { + result->is_valid = true; + BuilderType builder(input.type, ctx->memory_pool()); + util::string_view s(*input.value); + RETURN_NOT_OK(SplitString(s, finder, &builder)); + RETURN_NOT_OK(builder.Finish(&result->value)); + } + return Status::OK(); + } - Status Split(const util::string_view& s, BuilderType* builder) { + Status SplitString(const util::string_view& s, SplitFinder* finder, + BuilderType* builder) { const uint8_t* begin = reinterpret_cast(s.data()); const uint8_t* end = begin + s.length(); @@ -1618,8 +1743,7 @@ struct SplitBaseTransform { while (max_splits != 0) { const uint8_t *separator_begin, *separator_end; // find with whatever algo the part we will 'cut out' - if (static_cast(*this).FindReverse(begin, i, &separator_begin, - &separator_end, options)) { + if (finder->FindReverse(begin, i, &separator_begin, &separator_end, options)) { parts.emplace_back(reinterpret_cast(separator_end), i - separator_end); i = separator_begin; @@ -1639,8 +1763,7 @@ struct SplitBaseTransform { while (max_splits != 0) { const uint8_t *separator_begin, *separator_end; // find with whatever algo the part we will 'cut out' - if (static_cast(*this).Find(i, end, &separator_begin, &separator_end, - options)) { + if (finder->Find(i, end, &separator_begin, &separator_end, options)) { // the part till the beginning of the 'cut' RETURN_NOT_OK( builder->Append(i, static_cast(separator_begin - i))); @@ -1656,85 +1779,13 @@ struct SplitBaseTransform { } return 
Status::OK(); } - - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - Options options = State::Get(ctx); - Derived splitter(options); // we make an instance to reuse the parts vectors - RETURN_NOT_OK(splitter.CheckOptions()); - return splitter.Split(ctx, batch, out); - } - - Status CheckOptions() { return Status::OK(); } - - Status Split(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - EnsureLookupTablesFilled(); // only needed for unicode - - if (batch[0].kind() == Datum::ARRAY) { - const ArrayData& input = *batch[0].array(); - ArrayType input_boxed(batch[0].array()); - - BuilderType builder(input.type, ctx->memory_pool()); - // a slight overestimate of the data needed - RETURN_NOT_OK(builder.ReserveData(input_boxed.total_values_length())); - // the minimum amount of strings needed - RETURN_NOT_OK(builder.Resize(input.length)); - - ArrayData* output_list = out->mutable_array(); - // list offsets were preallocated - auto* list_offsets = output_list->GetMutableValues(1); - DCHECK_NE(list_offsets, nullptr); - // initial value - *list_offsets++ = 0; - RETURN_NOT_OK(VisitArrayDataInline( - input, - [&](util::string_view s) { - RETURN_NOT_OK(Split(s, &builder)); - if (ARROW_PREDICT_FALSE(builder.length() > - std::numeric_limits::max())) { - return Status::CapacityError("List offset does not fit into 32 bit"); - } - *list_offsets++ = static_cast(builder.length()); - return Status::OK(); - }, - [&]() { - // null value is already taken from input - *list_offsets++ = static_cast(builder.length()); - return Status::OK(); - })); - // assign list child data - std::shared_ptr string_array; - RETURN_NOT_OK(builder.Finish(&string_array)); - output_list->child_data.push_back(string_array->data()); - - } else { - const auto& input = checked_cast(*batch[0].scalar()); - auto result = checked_pointer_cast(MakeNullScalar(out->type())); - if (input.is_valid) { - result->is_valid = true; - BuilderType builder(input.type, ctx->memory_pool()); - 
util::string_view s(*input.value); - RETURN_NOT_OK(Split(s, &builder)); - RETURN_NOT_OK(builder.Finish(&result->value)); - } - out->value = result; - } - - return Status::OK(); - } }; -template -struct SplitPatternTransform : SplitBaseTransform> { - using Base = SplitBaseTransform>; - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - using string_offset_type = typename Type::offset_type; - using Base::Base; +struct SplitPatternFinder : public SplitFinderBase { + using Options = SplitPatternOptions; - Status CheckOptions() { - if (Base::options.pattern.length() == 0) { + Status PreExec(const SplitPatternOptions& options) override { + if (options.pattern.length() == 0) { return Status::Invalid("Empty separator"); } return Status::OK(); @@ -1782,6 +1833,9 @@ struct SplitPatternTransform : SplitBaseTransform +using SplitPatternExec = SplitExec; + const FunctionDoc split_pattern_doc( "Split string according to separator", ("Split each string according to the exact `pattern` defined in\n" @@ -1815,29 +1869,22 @@ const FunctionDoc utf8_split_whitespace_doc( void AddSplitPattern(FunctionRegistry* registry) { auto func = std::make_shared("split_pattern", Arity::Unary(), &split_pattern_doc); - using t32 = SplitPatternTransform; - using t64 = SplitPatternTransform; + using t32 = SplitPatternExec; + using t64 = SplitPatternExec; DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); DCHECK_OK( func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); DCHECK_OK(registry->AddFunction(std::move(func))); } -template -struct SplitWhitespaceAsciiTransform - : SplitBaseTransform> { - using Base = SplitBaseTransform>; - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - using string_offset_type = typename Type::offset_type; - using Base::Base; +struct SplitWhitespaceAsciiFinder : public SplitFinderBase { + using Options = 
SplitOptions; + static bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin, const uint8_t** separator_end, const SplitOptions& options) { const uint8_t* i = begin; - while ((i < end)) { + while (i < end) { if (IsSpaceCharacterAscii(*i)) { *separator_begin = i; do { @@ -1850,6 +1897,7 @@ struct SplitWhitespaceAsciiTransform } return false; } + static bool FindReverse(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin, const uint8_t** separator_end, const SplitOptions& options) { @@ -1869,13 +1917,16 @@ struct SplitWhitespaceAsciiTransform } }; +template +using SplitWhitespaceAsciiExec = SplitExec; + void AddSplitWhitespaceAscii(FunctionRegistry* registry) { static const SplitOptions default_options{}; auto func = std::make_shared("ascii_split_whitespace", Arity::Unary(), &ascii_split_whitespace_doc, &default_options); - using t32 = SplitWhitespaceAsciiTransform; - using t64 = SplitWhitespaceAsciiTransform; + using t32 = SplitWhitespaceAsciiExec; + using t64 = SplitWhitespaceAsciiExec; DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); DCHECK_OK( func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); @@ -1883,19 +1934,16 @@ void AddSplitWhitespaceAscii(FunctionRegistry* registry) { } #ifdef ARROW_WITH_UTF8PROC -template -struct SplitWhitespaceUtf8Transform - : SplitBaseTransform> { - using Base = SplitBaseTransform>; - using ArrayType = typename TypeTraits::ArrayType; - using string_offset_type = typename Type::offset_type; - using ScalarType = typename TypeTraits::ScalarType; - using Base::Base; - static bool Find(const uint8_t* begin, const uint8_t* end, - const uint8_t** separator_begin, const uint8_t** separator_end, - const SplitOptions& options) { +struct SplitWhitespaceUtf8Finder : public SplitFinderBase { + using Options = SplitOptions; + + Status PreExec(const SplitOptions& options) override { + EnsureLookupTablesFilled(); + return 
Status::OK(); + } + + bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin, + const uint8_t** separator_end, const SplitOptions& options) { const uint8_t* i = begin; while ((i < end)) { uint32_t codepoint = 0; @@ -1915,9 +1963,10 @@ struct SplitWhitespaceUtf8Transform } return false; } - static bool FindReverse(const uint8_t* begin, const uint8_t* end, - const uint8_t** separator_begin, const uint8_t** separator_end, - const SplitOptions& options) { + + bool FindReverse(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitOptions& options) { const uint8_t* i = end - 1; while ((i >= begin)) { uint32_t codepoint = 0; @@ -1939,73 +1988,68 @@ struct SplitWhitespaceUtf8Transform } }; +template +using SplitWhitespaceUtf8Exec = SplitExec; + void AddSplitWhitespaceUTF8(FunctionRegistry* registry) { static const SplitOptions default_options{}; auto func = std::make_shared("utf8_split_whitespace", Arity::Unary(), &utf8_split_whitespace_doc, &default_options); - using t32 = SplitWhitespaceUtf8Transform; - using t64 = SplitWhitespaceUtf8Transform; + using t32 = SplitWhitespaceUtf8Exec; + using t64 = SplitWhitespaceUtf8Exec; DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); DCHECK_OK( func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); DCHECK_OK(registry->AddFunction(std::move(func))); } -#endif +#endif // ARROW_WITH_UTF8PROC #ifdef ARROW_WITH_RE2 -template -struct SplitRegexTransform : SplitBaseTransform> { - using Base = SplitBaseTransform>; - using ArrayType = typename TypeTraits::ArrayType; - using string_offset_type = typename Type::offset_type; - using ScalarType = typename TypeTraits::ScalarType; +struct SplitRegexFinder : public SplitFinderBase { + using Options = SplitPatternOptions; - const RE2 regex_split; + util::optional regex_split; - explicit SplitRegexTransform(SplitPatternOptions options) - : 
Base(options), regex_split(MakePattern(options)) {} - - static std::string MakePattern(const SplitPatternOptions& options) { + Status PreExec(const SplitPatternOptions& options) override { + if (options.reverse) { + return Status::NotImplemented("Cannot split in reverse with regex"); + } // RE2 does *not* give you the full match! Must wrap the regex in a capture group // There is FindAndConsume, but it would give only the end of the separator std::string pattern = "("; pattern.reserve(options.pattern.size() + 2); pattern += options.pattern; pattern += ')'; - return pattern; - } - - Status CheckOptions() { - if (Base::options.reverse) { - return Status::NotImplemented("Cannot split in reverse with regex"); - } - return RegexStatus(regex_split); + regex_split.emplace(std::move(pattern)); + return RegexStatus(*regex_split); } bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin, - const uint8_t** separator_end, const SplitOptions& options) { + const uint8_t** separator_end, const SplitPatternOptions& options) { re2::StringPiece piece(reinterpret_cast(begin), std::distance(begin, end)); // "StringPiece is mutated to point to matched piece" re2::StringPiece result; - if (!re2::RE2::PartialMatch(piece, regex_split, &result)) { + if (!re2::RE2::PartialMatch(piece, *regex_split, &result)) { return false; } *separator_begin = reinterpret_cast(result.data()); *separator_end = reinterpret_cast(result.data() + result.size()); return true; } + bool FindReverse(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin, const uint8_t** separator_end, - const SplitOptions& options) { - // Not easily supportable, unfortunately + const SplitPatternOptions& options) { + // Unsupported (see PreExec) return false; } }; +template +using SplitRegexExec = SplitExec; + const FunctionDoc split_pattern_regex_doc( "Split string according to regex pattern", ("Split each string according to the regex `pattern` defined in\n" @@ -2019,14 +2063,14 @@ 
const FunctionDoc split_pattern_regex_doc( void AddSplitRegex(FunctionRegistry* registry) { auto func = std::make_shared("split_pattern_regex", Arity::Unary(), &split_pattern_regex_doc); - using t32 = SplitRegexTransform; - using t64 = SplitRegexTransform; + using t32 = SplitRegexExec; + using t64 = SplitRegexExec; DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); DCHECK_OK( func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); DCHECK_OK(registry->AddFunction(std::move(func))); } -#endif +#endif // ARROW_WITH_RE2 void AddSplit(FunctionRegistry* registry) { AddSplitPattern(registry); @@ -2477,56 +2521,54 @@ Result StrptimeResolve(KernelContext* ctx, const std::vector -struct UTF8TrimWhitespaceBase : StringTransform { - using Base = StringTransform; - using offset_type = typename Base::offset_type; - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { +template +struct UTF8TrimWhitespaceTransform : public StringTransformBase { + Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override { + EnsureLookupTablesFilled(); + return Status::OK(); + } + + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { const uint8_t* begin = input; const uint8_t* end = input + input_string_ncodeunits; const uint8_t* end_trimmed = end; const uint8_t* begin_trimmed = begin; auto predicate = [](uint32_t c) { return !IsSpaceCharacterUnicode(c); }; - if (left && !ARROW_PREDICT_TRUE( - arrow::util::UTF8FindIf(begin, end, predicate, &begin_trimmed))) { - return false; + if (TrimLeft && !ARROW_PREDICT_TRUE( + arrow::util::UTF8FindIf(begin, end, predicate, &begin_trimmed))) { + return kTransformError; } - if (right && (begin_trimmed < end)) { + if (TrimRight && begin_trimmed < end) { if (!ARROW_PREDICT_TRUE(arrow::util::UTF8FindIfReverse(begin_trimmed, end, predicate, &end_trimmed))) { - return 
false; + return kTransformError; } } std::copy(begin_trimmed, end_trimmed, output); - *output_written = static_cast(end_trimmed - begin_trimmed); - return true; - } - Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - EnsureLookupTablesFilled(); - return Base::Execute(ctx, batch, out); + return end_trimmed - begin_trimmed; } }; template -struct UTF8TrimWhitespace - : UTF8TrimWhitespaceBase> {}; +using UTF8TrimWhitespace = + StringTransformExec>; template -struct UTF8LTrimWhitespace - : UTF8TrimWhitespaceBase> {}; +using UTF8LTrimWhitespace = + StringTransformExec>; template -struct UTF8RTrimWhitespace - : UTF8TrimWhitespaceBase> {}; +using UTF8RTrimWhitespace = + StringTransformExec>; -struct TrimStateUTF8 { +struct UTF8TrimState { TrimOptions options_; std::vector codepoints_; Status status_ = Status::OK(); - explicit TrimStateUTF8(KernelContext* ctx, TrimOptions options) + explicit UTF8TrimState(KernelContext* ctx, TrimOptions options) : options_(std::move(options)) { if (!ARROW_PREDICT_TRUE( arrow::util::UTF8ForEach(options_.characters, [&](uint32_t c) { @@ -2539,167 +2581,136 @@ struct TrimStateUTF8 { } }; -template -struct UTF8TrimBase : StringTransform { - using Base = StringTransform; - using offset_type = typename Base::offset_type; - using State = KernelStateFromFunctionOptions; - TrimStateUTF8 state_; +template +struct UTF8TrimTransform : public StringTransformBase { + using State = KernelStateFromFunctionOptions; - explicit UTF8TrimBase(TrimStateUTF8 state) : state_(std::move(state)) {} + const UTF8TrimState& state_; - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - TrimStateUTF8 state = State::Get(ctx); - RETURN_NOT_OK(state.status_); - return Derived(state).Execute(ctx, batch, out); - } + explicit UTF8TrimTransform(const UTF8TrimState& state) : state_(state) {} - Status Execute(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - EnsureLookupTablesFilled(); - return Base::Execute(ctx, batch, 
out); + Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override { + return state_.status_; } - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { const uint8_t* begin = input; const uint8_t* end = input + input_string_ncodeunits; const uint8_t* end_trimmed = end; const uint8_t* begin_trimmed = begin; - auto predicate = [&](uint32_t c) { - bool contains = state_.codepoints_[c]; - return !contains; - }; - if (left && !ARROW_PREDICT_TRUE( - arrow::util::UTF8FindIf(begin, end, predicate, &begin_trimmed))) { - return false; + auto predicate = [&](uint32_t c) { return !state_.codepoints_[c]; }; + if (TrimLeft && !ARROW_PREDICT_TRUE( + arrow::util::UTF8FindIf(begin, end, predicate, &begin_trimmed))) { + return kTransformError; } - if (right && (begin_trimmed < end)) { + if (TrimRight && begin_trimmed < end) { if (!ARROW_PREDICT_TRUE(arrow::util::UTF8FindIfReverse(begin_trimmed, end, predicate, &end_trimmed))) { - return false; + return kTransformError; } } std::copy(begin_trimmed, end_trimmed, output); - *output_written = static_cast(end_trimmed - begin_trimmed); - return true; + return end_trimmed - begin_trimmed; } }; template -struct UTF8Trim : UTF8TrimBase> { - using Base = UTF8TrimBase>; - using Base::Base; -}; +using UTF8Trim = StringTransformExecWithState>; template -struct UTF8LTrim : UTF8TrimBase> { - using Base = UTF8TrimBase>; - using Base::Base; -}; +using UTF8LTrim = StringTransformExecWithState>; template -struct UTF8RTrim : UTF8TrimBase> { - using Base = UTF8TrimBase>; - using Base::Base; -}; +using UTF8RTrim = StringTransformExecWithState>; #endif -template -struct AsciiTrimWhitespaceBase : StringTransform { - using offset_type = typename Type::offset_type; - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* 
output_written) { +template +struct AsciiTrimWhitespaceTransform : public StringTransformBase { + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { const uint8_t* begin = input; const uint8_t* end = input + input_string_ncodeunits; const uint8_t* end_trimmed = end; + const uint8_t* begin_trimmed = begin; auto predicate = [](unsigned char c) { return !IsSpaceCharacterAscii(c); }; - const uint8_t* begin_trimmed = left ? std::find_if(begin, end, predicate) : begin; - if (right & (begin_trimmed < end)) { + if (TrimLeft) { + begin_trimmed = std::find_if(begin, end, predicate); + } + if (TrimRight && begin_trimmed < end) { std::reverse_iterator rbegin(end); std::reverse_iterator rend(begin_trimmed); end_trimmed = std::find_if(rbegin, rend, predicate).base(); } std::copy(begin_trimmed, end_trimmed, output); - *output_written = static_cast(end_trimmed - begin_trimmed); - return true; + return end_trimmed - begin_trimmed; } }; template -struct AsciiTrimWhitespace - : AsciiTrimWhitespaceBase> {}; +using AsciiTrimWhitespace = + StringTransformExec>; template -struct AsciiLTrimWhitespace - : AsciiTrimWhitespaceBase> {}; +using AsciiLTrimWhitespace = + StringTransformExec>; template -struct AsciiRTrimWhitespace - : AsciiTrimWhitespaceBase> {}; - -template -struct AsciiTrimBase : StringTransform { - using Base = StringTransform; - using offset_type = typename Base::offset_type; - using State = OptionsWrapper; +using AsciiRTrimWhitespace = + StringTransformExec>; + +struct AsciiTrimState { TrimOptions options_; std::vector characters_; - explicit AsciiTrimBase(TrimOptions options) + explicit AsciiTrimState(KernelContext* ctx, TrimOptions options) : options_(std::move(options)), characters_(256) { - std::for_each(options_.characters.begin(), options_.characters.end(), - [&](char c) { characters_[static_cast(c)] = true; }); + for (const auto c : options_.characters) { + characters_[static_cast(c)] = true; + } } +}; - static Status 
Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - TrimOptions options = State::Get(ctx); - return Derived(options).Execute(ctx, batch, out); - } +template +struct AsciiTrimTransform : public StringTransformBase { + using State = KernelStateFromFunctionOptions; + + const AsciiTrimState& state_; + + explicit AsciiTrimTransform(const AsciiTrimState& state) : state_(state) {} - bool Transform(const uint8_t* input, offset_type input_string_ncodeunits, - uint8_t* output, offset_type* output_written) { + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { const uint8_t* begin = input; const uint8_t* end = input + input_string_ncodeunits; const uint8_t* end_trimmed = end; - const uint8_t* begin_trimmed; - - auto predicate = [&](unsigned char c) { - bool contains = characters_[c]; - return !contains; - }; + const uint8_t* begin_trimmed = begin; - begin_trimmed = left ? std::find_if(begin, end, predicate) : begin; - if (right & (begin_trimmed < end)) { + auto predicate = [&](uint8_t c) { return !state_.characters_[c]; }; + if (TrimLeft) { + begin_trimmed = std::find_if(begin, end, predicate); + } + if (TrimRight && begin_trimmed < end) { std::reverse_iterator rbegin(end); std::reverse_iterator rend(begin_trimmed); end_trimmed = std::find_if(rbegin, rend, predicate).base(); } std::copy(begin_trimmed, end_trimmed, output); - *output_written = static_cast(end_trimmed - begin_trimmed); - return true; + return end_trimmed - begin_trimmed; } }; template -struct AsciiTrim : AsciiTrimBase> { - using Base = AsciiTrimBase>; - using Base::Base; -}; +using AsciiTrim = StringTransformExecWithState>; template -struct AsciiLTrim : AsciiTrimBase> { - using Base = AsciiTrimBase>; - using Base::Base; -}; +using AsciiLTrim = StringTransformExecWithState>; template -struct AsciiRTrim : AsciiTrimBase> { - using Base = AsciiTrimBase>; - using Base::Base; -}; +using AsciiRTrim = StringTransformExecWithState>; const FunctionDoc 
utf8_trim_whitespace_doc( "Trim leading and trailing whitespace characters", diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index f015e339423..c4b6956be2b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -144,10 +144,13 @@ TYPED_TEST(TestStringKernels, AsciiReverse) { this->CheckUnary("ascii_reverse", R"(["abcd", null, "", "bbb"])", this->type(), R"(["dcba", null, "", "bbb"])"); - Datum invalid_input = ArrayFromJSON(this->type(), R"(["aAazZæÆ&", null, "", "bbb"])"); + auto invalid_input = ArrayFromJSON(this->type(), R"(["aAazZæÆ&", null, "", "bcd"])"); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("Non-ASCII sequence in input"), CallFunction("ascii_reverse", {invalid_input})); + auto masked_input = TweakValidityBit(invalid_input, 0, false); + CheckScalarUnary("ascii_reverse", masked_input, + ArrayFromJSON(this->type(), R"([null, null, "", "dcb"])")); } TYPED_TEST(TestStringKernels, Utf8Reverse) { diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 39bd665d5b6..eb0edd56566 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -553,6 +553,19 @@ void ApproxCompareBatch(const RecordBatch& left, const RecordBatch& right, [](const Array& left, const Array& right) { return left.ApproxEquals(right); }); } +std::shared_ptr TweakValidityBit(const std::shared_ptr& array, + int64_t index, bool validity) { + auto data = array->data()->Copy(); + if (data->buffers[0] == nullptr) { + data->buffers[0] = *AllocateBitmap(data->length); + BitUtil::SetBitsTo(data->buffers[0]->mutable_data(), 0, data->length, true); + } + BitUtil::SetBitTo(data->buffers[0]->mutable_data(), index, validity); + data->null_count = kUnknownNullCount; + // Need to return a new array, because Array caches the null bitmap pointer + return MakeArray(data); +} + class 
LocaleGuard::Impl { public: explicit Impl(const char* new_locale) : global_locale_(std::locale()) { diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index b8ea8e76298..9d01cd4bf27 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -444,6 +444,14 @@ inline void BitmapFromVector(const std::vector& is_valid, ASSERT_OK(GetBitmapFromVector(is_valid, out)); } +// Given an array, return a new identical array except for one validity bit +// set to a new value. +// This is useful to force the underlying "value" of null entries to otherwise +// invalid data and check that errors don't get reported. +ARROW_TESTING_EXPORT +std::shared_ptr TweakValidityBit(const std::shared_ptr& array, + int64_t index, bool validity); + ARROW_TESTING_EXPORT void SleepFor(double seconds);