From 282c6baa43d187e1ad59cb3077e6d7bc814c9f53 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 3 Sep 2021 19:03:22 -0400 Subject: [PATCH 01/84] add kernel executor for string binary functions --- .../arrow/compute/kernels/scalar_string.cc | 219 +++++++++++++++++- 1 file changed, 208 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 3fce9dd8e4a..80617fc7410 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -346,21 +346,23 @@ struct StringTransformExecBase { static Status Execute(KernelContext* ctx, StringTransform* transform, const ExecBatch& batch, Datum* out) { + if (batch.num_values() != 1) { + return Status::Invalid("Invalid arity for unary string transform"); + } + if (batch[0].kind() == Datum::ARRAY) { return ExecArray(ctx, transform, batch[0].array(), out); + } else if (batch[0].kind() == Datum::SCALAR) { + return ExecScalar(ctx, transform, batch[0].scalar(), out); } - DCHECK_EQ(batch[0].kind(), Datum::SCALAR); - return ExecScalar(ctx, transform, batch[0].scalar(), out); + return Status::Invalid("Invalid ExecBatch kind for unary string transform"); } static Status ExecArray(KernelContext* ctx, StringTransform* transform, const std::shared_ptr& data, Datum* out) { ArrayType input(data); - ArrayData* output = out->mutable_array(); - const int64_t input_ncodeunits = input.total_values_length(); const int64_t input_nstrings = input.length(); - const int64_t output_ncodeunits_max = transform->MaxCodeunits(input_nstrings, input_ncodeunits); if (output_ncodeunits_max > std::numeric_limits::max()) { @@ -368,6 +370,7 @@ struct StringTransformExecBase { "Result might not fit in a 32bit utf8 array, convert to large_utf8"); } + ArrayData* output = out->mutable_array(); ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); output->buffers[2] = values_buffer; @@ -375,7 +378,6 @@ struct StringTransformExecBase { offset_type* output_string_offsets = output->GetMutableValues(1); uint8_t* output_str = output->buffers[2]->mutable_data(); offset_type output_ncodeunits = 0; - output_string_offsets[0] = 0; for (int64_t i = 0; i < input_nstrings; i++) { if (!input.IsNull(i)) { @@ -402,16 +404,16 @@ struct StringTransformExecBase { if (!input.is_valid) { return Status::OK(); } - auto* result = checked_cast(out->scalar().get()); - result->is_valid = true; const int64_t data_nbytes = static_cast(input.value->size()); - const int64_t output_ncodeunits_max = transform->MaxCodeunits(1, data_nbytes); if (output_ncodeunits_max > std::numeric_limits::max()) { return Status::CapacityError( "Result might not fit in a 32bit utf8 array, convert to large_utf8"); } + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); + auto* result = checked_cast(out->scalar().get()); + result->is_valid = true; result->value = value_buffer; auto encoded_nbytes = static_cast(transform->Transform( input.value->data(), data_nbytes, value_buffer->mutable_data())); @@ -537,6 +539,203 @@ struct FixedSizeBinaryTransformExecWithState } }; +/// Kernel exec generator for binary string transforms. +/// The first parameter is expected to always be a string type while the second parameter +/// is generic. It supports executions of the form: +/// * Scalar, Scalar +/// * Array, Scalar - scalar is broadcasted and paired with all values of array +/// * Array, Array - arrays are processed element-wise +/// * Scalar, Array - not supported by default +template +struct StringBinaryTransformExecBase { + using offset_type = typename Type1::offset_type; + using ArrayType1 = typename TypeTraits::ArrayType; + using ArrayType2 = typename TypeTraits::ArrayType; + + static Status Execute(KernelContext* ctx, StringTransform* transform, + const ExecBatch& batch, Datum* out) { + if (batch.num_values() != 2) { + return Status::Invalid("Invalid arity for binary string transform"); + } + + if (batch[0].is_array()) { + if (batch[1].is_array()) { + return ExecArrayArray(ctx, transform, batch[0].array(), batch[1].array(), out); + } else if (batch[1].is_scalar()) { + return ExecArrayScalar(ctx, transform, batch[0].array(), batch[1].scalar(), out); + } + } else if (batch[0].is_scalar()) { + if (batch[1].is_array()) { + return ExecScalarArray(ctx, transform, batch[0].scalar(), batch[1].array(), out); + } else if (batch[1].is_scalar()) { + return ExecScalarScalar(ctx, transform, batch[0].scalar(), batch[1].scalar(), + out); + } + } + return Status::Invalid("Invalid ExecBatch kind for binary string transform"); + } + + private: + static Status ExecScalarScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar1, + const std::shared_ptr& scalar2, Datum* out) { + const auto& input1 = checked_cast(*scalar1); + // TODO(edponce): How to validate inputs? For some kernels, returning null is ok + // (ascii_lower) but others not necessarily (concatenate) + if (!input1.is_valid) { + return Status::OK(); + } + + auto input_ncodeunits = input1.value->size(); + auto input_nstrings = 1; + auto output_ncodeunits_max = + transform->MaxCodeunits(input_nstrings, input_ncodeunits); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); + auto result = checked_cast(out->scalar().get()); + result->is_valid = true; + result->value = value_buffer; + auto output_str = value_buffer->mutable_data(); + + auto input1_string = input1.value->data(); + auto encoded_nbytes = + transform->Transform(input1_string, input_ncodeunits, scalar2, output_str); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); + } + DCHECK_LE(encoded_nbytes, output_ncodeunits_max); + return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); + } + + static Status ExecArrayScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data1, + const std::shared_ptr& scalar2, Datum* out) { + ArrayType1 input1(data1); + + auto input1_ncodeunits = input1.total_values_length(); + auto input1_nstrings = input1.length(); + auto output_ncodeunits_max = + transform->MaxCodeunits(input1_nstrings, input1_ncodeunits); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ArrayData* output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); + output->buffers[2] = values_buffer; + + // String offsets are preallocated + auto output_string_offsets = output->GetMutableValues(1); + auto output_str = output->buffers[2]->mutable_data(); + output_string_offsets[0] = 0; + + offset_type output_ncodeunits = 0; + for (int64_t i = 0; i < input1_nstrings; ++i) { + if (!input1.IsNull(i)) { + offset_type input1_string_ncodeunits; + auto input1_string = input1.GetValue(i, &input1_string_ncodeunits); + auto encoded_nbytes = + transform->Transform(input1_string, input1_string_ncodeunits, scalar2, + output_str + output_ncodeunits); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); + } + output_ncodeunits += encoded_nbytes; + } + output_string_offsets[i + 1] = output_ncodeunits; + } + DCHECK_LE(output_ncodeunits, output_ncodeunits_max); + + // Trim the codepoint buffer, since we allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + return Status::OK(); + } + + static Status ExecScalarArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar1, + const std::shared_ptr& data2, Datum* out) { + return Status::NotImplemented( + "Binary string transforms with (scalar, array) inputs are not supported for the " + "general case"); + } + + static Status ExecArrayArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data1, + const std::shared_ptr& data2, Datum* out) { + ArrayType1 input1(data1); + ArrayType2 input2(data2); + + auto input1_ncodeunits = input1.total_values_length(); + auto input1_nstrings = input1.length(); + auto output_ncodeunits_max = + transform->MaxCodeunits(input1_nstrings, input1_ncodeunits); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ArrayData* output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); + output->buffers[2] = values_buffer; + + // String offsets are preallocated + auto output_string_offsets = output->GetMutableValues(1); + auto output_str = output->buffers[2]->mutable_data(); + output_string_offsets[0] = 0; + + offset_type output_ncodeunits = 0; + for (int64_t i = 0; i < input1_nstrings; ++i) { + if (!input1.IsNull(i)) { + offset_type input1_string_ncodeunits; + auto input1_string = input1.GetValue(i, &input1_string_ncodeunits); + auto scalar2 = *input2.GetScalar(i); + auto encoded_nbytes = + transform->Transform(input1_string, input1_string_ncodeunits, scalar2, + output_str + output_ncodeunits); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); + } + output_ncodeunits += encoded_nbytes; + } + output_string_offsets[i + 1] = output_ncodeunits; + } + DCHECK_LE(output_ncodeunits, output_ncodeunits_max); + + // Trim the codepoint buffer, since we allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + } +}; + +template +struct StringBinaryTransformExec + : public StringBinaryTransformExecBase { + using StringBinaryTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform; + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + +template +struct StringBinaryTransformExecWithState + : public StringBinaryTransformExecBase { + using State = typename StringTransform::State; + using StringBinaryTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform(State::Get(ctx)); + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + #ifdef ARROW_WITH_UTF8PROC struct FunctionalCaseMappingTransform : public StringTransformBase { @@ -4430,7 +4629,6 @@ const FunctionDoc utf8_reverse_doc( "clusters. Hence, it will not correctly reverse grapheme clusters\n" "composed of multiple codepoints."), {"strings"}); - } // namespace void RegisterScalarStringAscii(FunctionRegistry* registry) { @@ -4454,7 +4652,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { &ascii_rtrim_whitespace_doc); MakeUnaryStringBatchKernel("ascii_reverse", registry, &ascii_reverse_doc); MakeUnaryStringBatchKernel("utf8_reverse", registry, &utf8_reverse_doc); - MakeUnaryStringBatchKernelWithState("ascii_center", registry, &ascii_center_doc); MakeUnaryStringBatchKernelWithState("ascii_lpad", registry, &ascii_lpad_doc); From b390acd25d73f5b8ece21ad8cf8311d63027de05 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 3 Sep 2021 22:21:42 -0400 Subject: [PATCH 02/84] add comments/notes --- cpp/src/arrow/compute/kernels/scalar_string.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 80617fc7410..dfdfad1c6ea 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -334,9 +334,13 @@ struct StringTransformBase { return Status::Invalid("Invalid UTF8 sequence in input"); } - // Derived classes should also define this method: + // Unary derived classes should define this method: // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, // uint8_t* output); + + // Binary derived classes should define this method: + // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, const + // std::shared_ptr& input2, uint8_t* output); }; template @@ -546,6 +550,8 @@ struct FixedSizeBinaryTransformExecWithState /// * Array, Scalar - scalar is broadcasted and paired with all values of array /// * Array, Array - arrays are processed element-wise /// * Scalar, Array - not supported by default +// TODO(edponce): For when second parameter is an array, need to specify a corresponding +// iterator/visitor. template struct StringBinaryTransformExecBase { using offset_type = typename Type1::offset_type; From e327e62d73d93d08f0180d4443c34b783f12c92e Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 6 Sep 2021 20:07:30 -0400 Subject: [PATCH 03/84] add second input to MaxCodeunits --- .../arrow/compute/kernels/scalar_string.cc | 57 ++++++++++++------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index dfdfad1c6ea..590947f0102 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -333,14 +333,6 @@ struct StringTransformBase { virtual Status InvalidStatus() { return Status::Invalid("Invalid UTF8 sequence in input"); } - - // Unary derived classes should define this method: - // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, - // uint8_t* output); - - // Binary derived classes should define this method: - // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, const - // std::shared_ptr& input2, uint8_t* output); }; template @@ -427,6 +419,10 @@ struct StringTransformExecBase { DCHECK_LE(encoded_nbytes, output_ncodeunits_max); return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); } + + // Unary derived classes should define this method: + // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + // uint8_t* output); }; template @@ -543,6 +539,26 @@ struct FixedSizeBinaryTransformExecWithState } }; +struct StringBinaryTransformBase { + virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); + } + + // Return the maximum total size of the output in codeunits (i.e. bytes) + // given input characteristics. + virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits, + const std::shared_ptr& input2) { + return input_ncodeunits; + } + + // Return the maximum total size of the output in codeunits (i.e. bytes) + // given input characteristics. + virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits, + const std::shared_ptr& data2) { + return input_ncodeunits; + } +}; + /// Kernel exec generator for binary string transforms. /// The first parameter is expected to always be a string type while the second parameter /// is generic. It supports executions of the form: @@ -550,8 +566,6 @@ struct FixedSizeBinaryTransformExecWithState /// * Array, Scalar - scalar is broadcasted and paired with all values of array /// * Array, Array - arrays are processed element-wise /// * Scalar, Array - not supported by default -// TODO(edponce): For when second parameter is an array, need to specify a corresponding -// iterator/visitor. template struct StringBinaryTransformExecBase { using offset_type = typename Type1::offset_type; @@ -585,17 +599,15 @@ struct StringBinaryTransformExecBase { static Status ExecScalarScalar(KernelContext* ctx, StringTransform* transform, const std::shared_ptr& scalar1, const std::shared_ptr& scalar2, Datum* out) { - const auto& input1 = checked_cast(*scalar1); - // TODO(edponce): How to validate inputs? For some kernels, returning null is ok - // (ascii_lower) but others not necessarily (concatenate) - if (!input1.is_valid) { + if (!scalar1->is_valid || !scalar2->is_valid) { return Status::OK(); } + const auto& input1 = checked_cast(*scalar1); auto input_ncodeunits = input1.value->size(); auto input_nstrings = 1; auto output_ncodeunits_max = - transform->MaxCodeunits(input_nstrings, input_ncodeunits); + transform->MaxCodeunits(input_nstrings, input_ncodeunits, scalar2); if (output_ncodeunits_max > std::numeric_limits::max()) { return Status::CapacityError( "Result might not fit in a 32bit utf8 array, convert to large_utf8"); @@ -620,12 +632,15 @@ struct StringBinaryTransformExecBase { static Status ExecArrayScalar(KernelContext* ctx, StringTransform* transform, const std::shared_ptr& data1, const std::shared_ptr& scalar2, Datum* out) { - ArrayType1 input1(data1); + if (!scalar2->is_valid) { + return Status::OK(); + } + ArrayType1 input1(data1); auto input1_ncodeunits = input1.total_values_length(); auto input1_nstrings = input1.length(); auto output_ncodeunits_max = - transform->MaxCodeunits(input1_nstrings, input1_ncodeunits); + transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, scalar2); if (output_ncodeunits_max > std::numeric_limits::max()) { return Status::CapacityError( "Result might not fit in a 32bit utf8 array, convert to large_utf8"); @@ -679,7 +694,7 @@ struct StringBinaryTransformExecBase { auto input1_ncodeunits = input1.total_values_length(); auto input1_nstrings = input1.length(); auto output_ncodeunits_max = - transform->MaxCodeunits(input1_nstrings, input1_ncodeunits); + transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, data2); if (output_ncodeunits_max > std::numeric_limits::max()) { return Status::CapacityError( "Result might not fit in a 32bit utf8 array, convert to large_utf8"); @@ -696,7 +711,7 @@ struct StringBinaryTransformExecBase { offset_type output_ncodeunits = 0; for (int64_t i = 0; i < input1_nstrings; ++i) { - if (!input1.IsNull(i)) { + if (!input1.IsNull(i) || !input2.IsNull(i)) { offset_type input1_string_ncodeunits; auto input1_string = input1.GetValue(i, &input1_string_ncodeunits); auto scalar2 = *input2.GetScalar(i); @@ -715,6 +730,10 @@ struct StringBinaryTransformExecBase { // Trim the codepoint buffer, since we allocated too much return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); } + + // Binary derived classes should define this method: + // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, const + // std::shared_ptr& input2, uint8_t* output); }; template From 91baa30009cb8e53478c970bb34b916a8993ee37 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 6 Sep 2021 20:32:53 -0400 Subject: [PATCH 04/84] add inheritance to have PreExec and InvalidStatus --- cpp/src/arrow/compute/kernels/scalar_string.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 590947f0102..629ff1aa0cb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -539,11 +539,7 @@ struct FixedSizeBinaryTransformExecWithState } }; -struct StringBinaryTransformBase { - virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - return Status::OK(); - } - +struct StringBinaryTransformBase : StringTransformBase { // Return the maximum total size of the output in codeunits (i.e. bytes) // given input characteristics. virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits, From 673b41342611f842baf82c6d35fcf0a8c4acd3c9 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 27 Aug 2021 12:08:57 -0400 Subject: [PATCH 05/84] add RepeatOptions --- cpp/src/arrow/compute/api_scalar.cc | 8 ++++++++ cpp/src/arrow/compute/api_scalar.h | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index e3fe1bdf73d..0c95db809d5 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -209,6 +209,8 @@ static auto kSplitPatternOptionsType = GetFunctionOptionsType(DataMember("repeats", &RepeatOptions::repeats)); static auto kReplaceSliceOptionsType = GetFunctionOptionsType( DataMember("start", &ReplaceSliceOptions::start), DataMember("stop", &ReplaceSliceOptions::stop), @@ -319,6 +321,11 @@ SplitPatternOptions::SplitPatternOptions(std::string pattern, int64_t max_splits SplitPatternOptions::SplitPatternOptions() : SplitPatternOptions("", -1, false) {} constexpr char SplitPatternOptions::kTypeName[]; +RepeatOptions::RepeatOptions(int64_t repeats) + : FunctionOptions(internal::kRepeatOptionsType), repeats(repeats) {} +RepeatOptions::RepeatOptions() : RepeatOptions(1){}; +constexpr char RepeatOptions::kTypeName[]; + ReplaceSliceOptions::ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement) : FunctionOptions(internal::kReplaceSliceOptionsType), @@ -440,6 +447,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSplitOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSplitPatternOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kRepeatOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSliceOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSubstringOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType)); diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 4bb18b37527..83b3917ef1a 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -166,6 +166,16 @@ class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { bool reverse; }; +class ARROW_EXPORT RepeatOptions : public FunctionOptions { + public: + explicit RepeatOptions(int64_t repeats); + RepeatOptions(); + constexpr static char const kTypeName[] = "RepeatOptions"; + + /// Number of repeats + int64_t repeats; +}; + class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { public: explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement); From 543174cfa5de7a24f19d248bb5cb80f5741dd548 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 27 Aug 2021 17:58:25 -0400 Subject: [PATCH 06/84] add str repeat kernel --- .../arrow/compute/kernels/scalar_string.cc | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 629ff1aa0cb..79ddc95b1e2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2733,6 +2733,31 @@ void AddSplit(FunctionRegistry* registry) { #endif } +struct StrRepeatTransform : public StringTransformBase { + using Options = RepeatOptions; + using State = OptionsWrapper; + + const Options* options; + + explicit StrRepeatTransform(const Options& options) : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return input_ncodeunits * std::max(options->repeats, 0); + } + + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + uint8_t* output_start = output; + for (int i = 0; i < options->repeats; ++i) { + output = std::copy(input, input + input_string_ncodeunits, output); + } + return output - output_start; + } +}; + +template +using StrRepeat = StringTransformExecWithState; + // ---------------------------------------------------------------------- // Replace substring (plain, regex) From 00be5a0146e60a58b853511e85fd06d41e5702c9 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 27 Aug 2021 18:26:28 -0400 Subject: [PATCH 07/84] add tests --- .../compute/kernels/scalar_string_test.cc | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index be22ef4a7c1..a2987ff32bf 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include @@ -1041,6 +1043,27 @@ TYPED_TEST(TestStringKernels, Utf8Title) { R"([null, "", "B", "Aaaz;Zææ&", "Ɑɽɽow", "Ii", "Ⱥ.Ⱥ.Ⱥ..Ⱥ", "Hello, World!", "Foo Bar;Héhé0Zop", "!%$^.,;"])"); } +TYPED_TEST(TestStringKernels, StrRepeat) { + RepeatOptions options; + this->CheckUnary("str_repeat", "[]", this->type(), "[]", &options); + + std::string values( + R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"); + std::vector> repeats_and_expected_map({ + {-1, R"(["", null, "", "", "", "", "", "", "", ""])"}, + {0, R"(["", null, "", "", "", "", "", "", "", ""])"}, + {1, + R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"}, + {3, + R"(["aAazZæÆ&aAazZæÆ&aAazZæÆ&", null, "", "bbb", "ɑɽⱤoWɑɽⱤoWɑɽⱤoW", "ıIıIıI", "ⱥⱥⱥȺⱥⱥⱥȺⱥⱥⱥȺ", "hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!", "$. A3$. A3$. A3", "!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])"}, + }); + + for (const auto& pair : repeats_and_expected_map) { + options.repeats = pair.first; + this->CheckUnary("str_repeat", values, this->type(), pair.second, &options); + } +} + TYPED_TEST(TestStringKernels, IsAlphaNumericUnicode) { // U+08BE (utf8: \xE0\xA2\xBE) is undefined, but utf8proc things it is // UTF8PROC_CATEGORY_LO From cabe6a7d212846e585b5b037b3167a73242ace95 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 27 Aug 2021 18:31:02 -0400 Subject: [PATCH 08/84] update docs --- docs/source/cpp/compute.rst | 80 +++++++++++++++--------------- docs/source/python/api/compute.rst | 1 + 2 files changed, 42 insertions(+), 39 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 0a87752e92d..1312edca61b 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -812,45 +812,47 @@ The third set of functions examines string elements on a byte-per-byte basis: String transforms ~~~~~~~~~~~~~~~~~ -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| Function name | Arity | Input types | Output type | Options class | Notes | -+=========================+=======+========================+========================+===================================+=======+ -| ascii_capitalize | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_lower | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_reverse | Unary | String-like | String-like | | \(2) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_swapcase | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_title | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_upper | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | | \(3) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| binary_replace_slice | Unary | Binary- or String-like | Binary- or String-like | :struct:`ReplaceSliceOptions` | \(4) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| replace_substring | Unary | Binary- or String-like | Binary- or String-like | :struct:`ReplaceSubstringOptions` | \(5) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| replace_substring_regex | Unary | Binary- or String-like | Binary- or String-like | :struct:`ReplaceSubstringOptions` | \(6) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_capitalize | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_length | Unary | String-like | Int32 or Int64 | | \(7) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_lower | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_replace_slice | Unary | String-like | String-like | :struct:`ReplaceSliceOptions` | \(4) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_reverse | Unary | String-like | String-like | | \(9) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_swapcase | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_title | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_upper | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| Function name | Arity | Input types | Output type | Options class | Notes | ++=========================+========+========================+========================+===================================+=======+ +| ascii_capitalize | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| ascii_lower | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| ascii_reverse | Unary | String-like | String-like | | \(2) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| ascii_swapcase | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| ascii_title | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| ascii_upper | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | | \(3) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| binary_replace_slice | Unary | String-like | Binary- or String-like | :struct:`ReplaceSliceOptions` | \(4) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| replace_substring | Unary | String-like | String-like | :struct:`ReplaceSubstringOptions` | \(5) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| replace_substring_regex | Unary | String-like | String-like | :struct:`ReplaceSubstringOptions` | \(6) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| str_repeat | Binary | String-like | String-like | :struct:`RepeatOptions` | | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_capitalize | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_length | Unary | String-like | Int32 or Int64 | | \(7) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_lower | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_replace_slice | Unary | String-like | String-like | :struct:`ReplaceSliceOptions` | \(4) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_reverse | Unary | String-like | String-like | | \(9) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_swapcase | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_title | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ +| utf8_upper | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+------------------------+------------------------+-----------------------------------+-------+ * \(1) Each ASCII character in the input is converted to lowercase or uppercase. Non-ASCII characters are left untouched. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 521182f8a41..a56cc9dfce0 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -270,6 +270,7 @@ String Transforms binary_replace_slice replace_substring replace_substring_regex + str_repeat utf8_capitalize utf8_length utf8_lower From 2d4fcb6fdc19f110d389d4311aef4cc14f1cdfd2 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 27 Aug 2021 22:47:32 -0400 Subject: [PATCH 09/84] fix linter error --- cpp/src/arrow/compute/api_scalar.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 0c95db809d5..1bfe9bd4a87 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -323,7 +323,7 @@ constexpr char SplitPatternOptions::kTypeName[]; RepeatOptions::RepeatOptions(int64_t repeats) : FunctionOptions(internal::kRepeatOptionsType), repeats(repeats) {} -RepeatOptions::RepeatOptions() : RepeatOptions(1){}; +RepeatOptions::RepeatOptions() : RepeatOptions(1) {} constexpr char RepeatOptions::kTypeName[]; ReplaceSliceOptions::ReplaceSliceOptions(int64_t start, int64_t stop, From dfa931d26417b59ee89cffdfb407b36590d44182 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Sat, 28 Aug 2021 03:05:52 -0400 Subject: [PATCH 10/84] add doubling approach and benchmark --- .../arrow/compute/kernels/scalar_string.cc | 20 +++++++++++++++++-- .../kernels/scalar_string_benchmark.cc | 6 ++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 79ddc95b1e2..d233501d4b7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2748,8 +2748,24 @@ struct StrRepeatTransform : public StringTransformBase { int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, uint8_t* output) { uint8_t* output_start = output; - for (int i = 0; i < options->repeats; ++i) { - output = std::copy(input, input + input_string_ncodeunits, output); + + if (options->repeats < 4) { + // Naive for-loop + for (auto i = 0; i < options->repeats; ++i) { + output = std::copy(input, input + input_string_ncodeunits, output); + } + } else { + auto N = options->repeats; + auto L = input_string_ncodeunits; + auto i = 1; + // log(N) approach + output = std::copy(input, input + L, output); + for (auto iL = L; i <= (N >> 1); i <<= 1, iL <<= 1) { + output = std::copy(output_start, output_start + iL, output); + } + // Epilogue remainder + auto rem = (N ^ i) * L; + output = std::copy(output_start, output_start + rem, output); } return output - output_start; } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index ddc3a56f00f..71061f49404 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -77,6 +77,11 @@ static void SplitPattern(benchmark::State& state) { UnaryStringBenchmark(state, "split_pattern", &options); } +static void StrRepeat(benchmark::State& state) { + RepeatOptions options(8); + UnaryStringBenchmark(state, "str_repeat", &options); +} + static void TrimSingleAscii(benchmark::State& state) { TrimOptions options("a"); UnaryStringBenchmark(state, "ascii_trim", &options); @@ -215,6 +220,7 @@ BENCHMARK(AsciiUpper); BENCHMARK(IsAlphaNumericAscii); BENCHMARK(MatchSubstring); BENCHMARK(SplitPattern); +BENCHMARK(StrRepeat); BENCHMARK(TrimSingleAscii); BENCHMARK(TrimManyAscii); #ifdef ARROW_WITH_RE2 From d32574d567c5aad14046938599b8977ba48965c0 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Sun, 29 Aug 2021 02:16:50 -0400 Subject: [PATCH 11/84] remove naive approach and add check for repeats option --- .../arrow/compute/kernels/scalar_string.cc | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index d233501d4b7..5d0f2729ee1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2748,24 +2748,21 @@ struct StrRepeatTransform : public StringTransformBase { int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, uint8_t* output) { uint8_t* output_start = output; - - if (options->repeats < 4) { - // Naive for-loop - for (auto i = 0; i < options->repeats; ++i) { - output = std::copy(input, input + input_string_ncodeunits, output); - } - } else { - auto N = options->repeats; - auto L = input_string_ncodeunits; - auto i = 1; + if (options->repeats > 0) { // log(N) approach - output = std::copy(input, input + L, output); - for (auto iL = L; i <= (N >> 1); i <<= 1, iL <<= 1) { - output = std::copy(output_start, output_start + iL, output); + std::memcpy(output, input, input_string_ncodeunits); + output += input_string_ncodeunits; + int64_t i = 1; + for (int64_t ilen = input_string_ncodeunits; i <= (options->repeats >> 1); + i <<= 1, ilen <<= 1) { + std::memcpy(output, output_start, ilen); + output += ilen; } + // Epilogue remainder - auto rem = (N ^ i) * L; - output = std::copy(output_start, output_start + rem, output); + int64_t rem = (options->repeats ^ i) * input_string_ncodeunits; + std::memcpy(output, output_start, rem); + output += rem; } return output - output_start; } From 8daa01f102c941dc148fda1a0302f1bfa46563b4 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Sun, 29 Aug 2021 21:49:57 -0400 Subject: [PATCH 12/84] add support for array of repeats --- cpp/src/arrow/compute/api_scalar.cc | 11 ++-- cpp/src/arrow/compute/api_scalar.h | 6 +- .../arrow/compute/kernels/scalar_string.cc | 62 ++++++++++++++++--- .../compute/kernels/scalar_string_test.cc | 18 +++++- 4 files changed, 81 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 1bfe9bd4a87..41e4d146dc5 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -209,8 +209,9 @@ static auto kSplitPatternOptionsType = GetFunctionOptionsType(DataMember("repeats", &RepeatOptions::repeats)); +static auto kRepeatOptionsType = GetFunctionOptionsType( + DataMember("nrepeats", &RepeatOptions::nrepeats), + DataMember("repeats", &RepeatOptions::repeats)); static auto kReplaceSliceOptionsType = GetFunctionOptionsType( DataMember("start", &ReplaceSliceOptions::start), DataMember("stop", &ReplaceSliceOptions::stop), @@ -321,8 +322,10 @@ SplitPatternOptions::SplitPatternOptions(std::string pattern, int64_t max_splits SplitPatternOptions::SplitPatternOptions() : SplitPatternOptions("", -1, false) {} constexpr char SplitPatternOptions::kTypeName[]; -RepeatOptions::RepeatOptions(int64_t repeats) - : FunctionOptions(internal::kRepeatOptionsType), repeats(repeats) {} +RepeatOptions::RepeatOptions(int64_t nrepeats) + : FunctionOptions(internal::kRepeatOptionsType), nrepeats(nrepeats), repeats({}) {} +RepeatOptions::RepeatOptions(std::vector repeats) + : FunctionOptions(internal::kRepeatOptionsType), nrepeats(0), repeats(repeats) {} RepeatOptions::RepeatOptions() : RepeatOptions(1) {} constexpr char RepeatOptions::kTypeName[]; diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 83b3917ef1a..86dd1c43c2a 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -168,12 +168,14 @@ class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { class ARROW_EXPORT RepeatOptions : public FunctionOptions { public: - explicit RepeatOptions(int64_t repeats); + explicit RepeatOptions(int64_t nrepeats); + explicit RepeatOptions(std::vector repeats); RepeatOptions(); constexpr static char const kTypeName[] = "RepeatOptions"; /// Number of repeats - int64_t repeats; + int64_t nrepeats; + std::vector repeats; }; class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 5d0f2729ee1..ebe0d0293c0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2733,34 +2733,78 @@ void AddSplit(FunctionRegistry* registry) { #endif } +template struct StrRepeatTransform : public StringTransformBase { using Options = RepeatOptions; using State = OptionsWrapper; + using ArrayType = typename TypeTraits::ArrayType; const Options* options; + std::function Transform; + std::vector::const_iterator it; + + explicit StrRepeatTransform(const Options& options) : options{&options} { + if (this->options->repeats.size()) { + // NOTE: This is an incorrect hack to iterate through the repeat values because for + // null entries, Transform() is not invoked and thus the iterator does not moves. + it = this->options->repeats.begin(); + Transform = [&](const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + auto nrepeats = *it++; + if (it == this->options->repeats.end()) { + it = this->options->repeats.begin(); + } + return this->Transform_(input, input_string_ncodeunits, output, nrepeats); + }; + } else { + Transform = [&](const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + return this->Transform_(input, input_string_ncodeunits, output, + this->options->nrepeats); + }; + } + } - explicit StrRepeatTransform(const Options& options) : options{&options} {} + Status PreExec(KernelContext*, const ExecBatch& batch, Datum*) override { + if (options->repeats.size() && batch[0].kind() == Datum::ARRAY) { + ArrayType input(batch[0].array()); + if (static_cast(options->repeats.size()) != + static_cast(input.length())) { + return Status::Invalid( + "Number of repeats and input strings are differ in length"); + } + } + return Status::OK(); + } int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { - return input_ncodeunits * std::max(options->repeats, 0); + // NOTE: Ideally, we would like to sum the values that correspond to non-null entries + // along with each inputs' size but this requires traversing the data twice. The upper + // limit is to assume that all strings are repeated the max number of times. + auto max_nrepeats = + options->repeats.size() + ? *std::max_element(options->repeats.begin(), options->repeats.end()) + : options->nrepeats; + return input_ncodeunits * std::max(max_nrepeats, 0); } - int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, - uint8_t* output) { + private: + int64_t Transform_(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output, const int64_t nrepeats) { uint8_t* output_start = output; - if (options->repeats > 0) { - // log(N) approach + if (nrepeats > 0) { + // log2(repeats) approach std::memcpy(output, input, input_string_ncodeunits); output += input_string_ncodeunits; int64_t i = 1; - for (int64_t ilen = input_string_ncodeunits; i <= (options->repeats >> 1); + for (int64_t ilen = input_string_ncodeunits; i <= (nrepeats >> 1); i <<= 1, ilen <<= 1) { std::memcpy(output, output_start, ilen); output += ilen; } // Epilogue remainder - int64_t rem = (options->repeats ^ i) * input_string_ncodeunits; + int64_t rem = (nrepeats ^ i) * input_string_ncodeunits; std::memcpy(output, output_start, rem); output += rem; } @@ -2769,7 +2813,7 @@ struct StrRepeatTransform : public StringTransformBase { }; template -using StrRepeat = StringTransformExecWithState; +using StrRepeat = StringTransformExecWithState>; // ---------------------------------------------------------------------- // Replace substring (plain, regex) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index a2987ff32bf..a2c7bab1c94 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -1059,11 +1059,27 @@ TYPED_TEST(TestStringKernels, StrRepeat) { }); for (const auto& pair : repeats_and_expected_map) { - options.repeats = pair.first; + options.nrepeats = pair.first; this->CheckUnary("str_repeat", values, this->type(), pair.second, &options); } } +TYPED_TEST(TestStringKernels, StrRepeats) { + RepeatOptions options{{-1, 2, 4, 2, 0, 1, 3, 2, 3}}; + std::string values( + R"(["aAazZæÆ&", "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"); + + std::string expected( + R"(["", "", "bbbb", "ɑɽⱤoWɑɽⱤoW", "", "ⱥⱥⱥȺ", "hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!", "$. A3$. A3", "!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])"); + this->CheckUnary("str_repeat", values, this->type(), expected, &options); + + // Test invalid data: len(repeats) != len(inputs) + options.repeats.pop_back(); + auto invalid_input = ArrayFromJSON(this->type(), "[\"b\"]"); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("differ in length"), + CallFunction("str_repeat", {invalid_input}, &options)); +} + TYPED_TEST(TestStringKernels, IsAlphaNumericUnicode) { // U+08BE (utf8: \xE0\xA2\xBE) is undefined, but utf8proc things it is // UTF8PROC_CATEGORY_LO From 2e226482869434b7742b5c54f630519e94a503f7 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 30 Aug 2021 02:59:44 -0400 Subject: [PATCH 13/84] add RepeatOptions to PyArrow --- python/pyarrow/_compute.pyx | 10 ++++++++++ python/pyarrow/compute.py | 1 + python/pyarrow/tests/test_compute.py | 1 + 3 files changed, 12 insertions(+) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index d62c9c0ee85..217738ac3a2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1161,6 +1161,16 @@ class VarianceOptions(_VarianceOptions): self._set_options(ddof, skip_nulls, min_count) +cdef class _RepeatOptions(FunctionOptions): + def _set_options(self, repeats): + self.wrapped.reset(new CRepeatOptions(repeats)) + + +class RepeatOptions(_RepeatOptions): + def __init__(self, repeats): + self._set_options(repeats) + + cdef class _SplitOptions(FunctionOptions): def _set_options(self, max_splits, reverse): self.wrapped.reset(new CSplitOptions(max_splits, reverse)) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 6e3bd7fcab3..ccaa90accca 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -47,6 +47,7 @@ PadOptions, PartitionNthOptions, QuantileOptions, + RepeatOptions, ReplaceSliceOptions, ReplaceSubstringOptions, RoundOptions, diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index be2da31b9d1..5d87ae5c9ff 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -131,6 +131,7 @@ def test_option_class_equality(): pc.PadOptions(5), pc.PartitionNthOptions(1, null_placement="at_start"), pc.QuantileOptions(), + pc.RepeatOptions(1), pc.ReplaceSliceOptions(0, 1, "a"), pc.ReplaceSubstringOptions("a", "b"), pc.RoundOptions(2, "towards_infinity"), From 8e394c6e6eb1ecd4b9982fea6d5d7629a7ca95e5 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 30 Aug 2021 03:03:55 -0400 Subject: [PATCH 14/84] add R bindings --- r/R/dplyr-functions.R | 4 ++++ r/R/expression.R | 1 + r/src/compute.cpp | 6 ++++++ 3 files changed, 11 insertions(+) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 717cdae9662..bd51a6ca431 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -409,6 +409,10 @@ nse_funcs$substr <- function(x, start, stop) { ) } +nse_funcs$str_repeat <- function(x, repeats) { + Expression$create("str_repeat", x, options = list(repeats = repeats)) +} + nse_funcs$substring <- function(text, first, last) { nse_funcs$substr(x = text, start = first, stop = last) } diff --git a/r/R/expression.R b/r/R/expression.R index b1b6635f538..28750928d9f 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -49,6 +49,7 @@ # nchar is defined in dplyr-functions.R "str_length" = "utf8_length", # str_pad is defined in dplyr-functions.R + # str_repeat is defined in dplyr-functions.R # str_sub is defined in dplyr-functions.R # str_to_lower is defined in dplyr-functions.R # str_to_title is defined in dplyr-functions.R diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 0f0ef2f7dd1..309c5888c5c 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -345,6 +345,7 @@ std::shared_ptr make_compute_options( cpp11::as_cpp(options["week_start"])); } +<<<<<<< 6fa024b0ea28d5d97efeb6b8cee073ed1f930041 if (func_name == "iso_week") { return std::make_shared( arrow::compute::WeekOptions::ISODefaults()); @@ -373,6 +374,11 @@ std::shared_ptr make_compute_options( first_week_is_fully_in_year); } + if (func_name == "str_repeat") { + using Options = arrow::compute::RepeatOptions; + return std::make_shared(cpp11::as_cpp>(options["repeats"])); + } + if (func_name == "strptime") { using Options = arrow::compute::StrptimeOptions; return std::make_shared( From 8f3c977096b33e56c4e1a5d2ef8183d7d503ed6d Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 30 Aug 2021 05:07:04 -0400 Subject: [PATCH 15/84] set repeats to std::vector --- cpp/src/arrow/compute/api_scalar.cc | 9 ++-- cpp/src/arrow/compute/api_scalar.h | 10 +++-- .../arrow/compute/kernels/scalar_string.cc | 41 +++++++------------ .../kernels/scalar_string_benchmark.cc | 2 +- .../compute/kernels/scalar_string_test.cc | 2 +- 5 files changed, 26 insertions(+), 38 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 41e4d146dc5..01efd57f33f 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -210,7 +210,6 @@ static auto kSplitPatternOptionsType = GetFunctionOptionsType( - DataMember("nrepeats", &RepeatOptions::nrepeats), DataMember("repeats", &RepeatOptions::repeats)); static auto kReplaceSliceOptionsType = GetFunctionOptionsType( DataMember("start", &ReplaceSliceOptions::start), @@ -322,11 +321,9 @@ SplitPatternOptions::SplitPatternOptions(std::string pattern, int64_t max_splits SplitPatternOptions::SplitPatternOptions() : SplitPatternOptions("", -1, false) {} constexpr char SplitPatternOptions::kTypeName[]; -RepeatOptions::RepeatOptions(int64_t nrepeats) - : FunctionOptions(internal::kRepeatOptionsType), nrepeats(nrepeats), repeats({}) {} -RepeatOptions::RepeatOptions(std::vector repeats) - : FunctionOptions(internal::kRepeatOptionsType), nrepeats(0), repeats(repeats) {} -RepeatOptions::RepeatOptions() : RepeatOptions(1) {} +RepeatOptions::RepeatOptions(std::vector repeats) + : FunctionOptions(internal::kRepeatOptionsType), repeats(repeats) {} +RepeatOptions::RepeatOptions() : RepeatOptions({1}) {} constexpr char RepeatOptions::kTypeName[]; ReplaceSliceOptions::ReplaceSliceOptions(int64_t start, int64_t stop, diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 86dd1c43c2a..989dc0c4c53 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -168,14 +168,16 @@ class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { class ARROW_EXPORT RepeatOptions : public FunctionOptions { public: - explicit RepeatOptions(int64_t nrepeats); - explicit RepeatOptions(std::vector repeats); + // NOTE: Use 'int' instead of 'int64_t' because R-cpp11 does not supports + // 'r_vector'. + explicit RepeatOptions(std::vector repeats); RepeatOptions(); constexpr static char const kTypeName[] = "RepeatOptions"; /// Number of repeats - int64_t nrepeats; - std::vector repeats; + /// A single value is applied to all inputs, otherwise repeats are apply to + /// corresponding input based on order. + std::vector repeats; }; class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index ebe0d0293c0..145f6791e81 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2741,32 +2741,24 @@ struct StrRepeatTransform : public StringTransformBase { const Options* options; std::function Transform; - std::vector::const_iterator it; + std::vector::const_iterator it; explicit StrRepeatTransform(const Options& options) : options{&options} { - if (this->options->repeats.size()) { - // NOTE: This is an incorrect hack to iterate through the repeat values because for - // null entries, Transform() is not invoked and thus the iterator does not moves. - it = this->options->repeats.begin(); - Transform = [&](const uint8_t* input, int64_t input_string_ncodeunits, - uint8_t* output) { - auto nrepeats = *it++; - if (it == this->options->repeats.end()) { - it = this->options->repeats.begin(); - } - return this->Transform_(input, input_string_ncodeunits, output, nrepeats); - }; - } else { - Transform = [&](const uint8_t* input, int64_t input_string_ncodeunits, - uint8_t* output) { - return this->Transform_(input, input_string_ncodeunits, output, - this->options->nrepeats); - }; - } + // NOTE: This is an incorrect hack to iterate through the repeat values because for + // null entries, Transform() is not invoked and thus the iterator does not moves. + it = this->options->repeats.begin(); + Transform = [&](const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + auto nrepeats = *it++; + if (it == this->options->repeats.end()) { + it = this->options->repeats.begin(); + } + return this->Transform_(input, input_string_ncodeunits, output, nrepeats); + }; } Status PreExec(KernelContext*, const ExecBatch& batch, Datum*) override { - if (options->repeats.size() && batch[0].kind() == Datum::ARRAY) { + if ((options->repeats.size() > 1) && (batch[0].kind() == Datum::ARRAY)) { ArrayType input(batch[0].array()); if (static_cast(options->repeats.size()) != static_cast(input.length())) { @@ -2781,11 +2773,8 @@ struct StrRepeatTransform : public StringTransformBase { // NOTE: Ideally, we would like to sum the values that correspond to non-null entries // along with each inputs' size but this requires traversing the data twice. The upper // limit is to assume that all strings are repeated the max number of times. - auto max_nrepeats = - options->repeats.size() - ? *std::max_element(options->repeats.begin(), options->repeats.end()) - : options->nrepeats; - return input_ncodeunits * std::max(max_nrepeats, 0); + auto max_repeat = *std::max_element(options->repeats.begin(), options->repeats.end()); + return input_ncodeunits * std::max(max_repeat, 0); } private: diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index 71061f49404..77471f218fb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -78,7 +78,7 @@ static void SplitPattern(benchmark::State& state) { } static void StrRepeat(benchmark::State& state) { - RepeatOptions options(8); + RepeatOptions options({8}); UnaryStringBenchmark(state, "str_repeat", &options); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index a2c7bab1c94..c679ce1052c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -1059,7 +1059,7 @@ TYPED_TEST(TestStringKernels, StrRepeat) { }); for (const auto& pair : repeats_and_expected_map) { - options.nrepeats = pair.first; + options.repeats = {pair.first}; this->CheckUnary("str_repeat", values, this->type(), pair.second, &options); } } From f89576906aaa85e9873c6cc7abc92cae76c6c410 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 30 Aug 2021 05:14:46 -0400 Subject: [PATCH 16/84] update pyarrow bindings and tests --- python/pyarrow/_compute.pyx | 6 ++++++ python/pyarrow/includes/libarrow.pxd | 5 +++++ python/pyarrow/tests/test_compute.py | 31 ++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 217738ac3a2..c3ac4755326 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -31,6 +31,10 @@ import pyarrow.lib as lib import numpy as np +def is_iterable(obj): + return hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes)) + + cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): """ Wrap a C++ scalar Function in a ScalarFunction object. @@ -1168,6 +1172,8 @@ cdef class _RepeatOptions(FunctionOptions): class RepeatOptions(_RepeatOptions): def __init__(self, repeats): + if not is_iterable(repeats): + repeats = [repeats] self._set_options(repeats) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 6d187eaa8c6..fd915d91375 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1904,6 +1904,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: int64_t width c_string padding + cdef cppclass CRepeatOptions \ + "arrow::compute::RepeatOptions"(CFunctionOptions): + CRepeatOptions(vector[int] repeats) + vector[int] repeats + cdef cppclass CSliceOptions \ "arrow::compute::SliceOptions"(CFunctionOptions): CSliceOptions(int64_t start, int64_t stop, int64_t step) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 5d87ae5c9ff..4d9ec082c62 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2237,3 +2237,34 @@ def test_count_distinct_options(): assert pc.count_distinct(arr, mode='only_valid').as_py() == 3 assert pc.count_distinct(arr, mode='only_null').as_py() == 1 assert pc.count_distinct(arr, mode='all').as_py() == 4 + + +def test_str_repeat(): + # Test with single value for number of repeats + values = ["æÆ&", None, "", "b", "ɑɽⱤoW", "ıI", "$. 3"] + repeat_and_expected_map = { + -1: ["", None, "", "", "", "", ""], + 0: ["", None, "", "", "", "", ""], + 1: ["æÆ&", None, "", "b", "ɑɽⱤoW", "ıI", "$. 3"], + 2: ["æÆ&æÆ&", None, "", "bb", "ɑɽⱤoWɑɽⱤoW", "ıIıI", "$. 3$. 3"], + } + for repeat, expected in repeat_and_expected_map.items(): + options = pc.RepeatOptions(repeat) + result = pc.str_repeat(values, options=options) + assert result.equals(pa.array(expected)) + + # Test with multiple values for number of repeats + values = ["a", "b"] + repeat_and_expected_map = { + (-1, 2): ["", "bb"], + (3,): ["aaa", "bbb"], + (0, 0): ["", ""], + } + for repeat, expected in repeat_and_expected_map.items(): + options = pc.RepeatOptions(repeat) + result = pc.str_repeat(values, options=options) + assert result.equals(pa.array(expected)) + + # Test with invalid number of values and repeats + with pytest.raises(ValueError, match="differ in length"): + pc.str_repeat(["a", "b"], repeats=[1, 2, 3]) From 512d5b8517550ff1334d5d5d429b4a697f9122c4 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 30 Aug 2021 05:16:07 -0400 Subject: [PATCH 17/84] fix lint error --- cpp/src/arrow/compute/api_scalar.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 01efd57f33f..c613e6d3e0d 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -209,8 +209,8 @@ static auto kSplitPatternOptionsType = GetFunctionOptionsType( - DataMember("repeats", &RepeatOptions::repeats)); +static auto kRepeatOptionsType = + GetFunctionOptionsType(DataMember("repeats", &RepeatOptions::repeats)); static auto kReplaceSliceOptionsType = GetFunctionOptionsType( DataMember("start", &ReplaceSliceOptions::start), DataMember("stop", &ReplaceSliceOptions::stop), From e2744ccbb0e9c0e3fc7fdcb2f01063981db90017 Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Mon, 30 Aug 2021 05:18:43 -0400 Subject: [PATCH 18/84] fix typo --- cpp/src/arrow/compute/kernels/scalar_string.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 145f6791e81..ff1b475adc4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2763,7 +2763,7 @@ struct StrRepeatTransform : public StringTransformBase { if (static_cast(options->repeats.size()) != static_cast(input.length())) { return Status::Invalid( - "Number of repeats and input strings are differ in length"); + "Number of repeats and input strings differ in length"); } } return Status::OK(); From 81afdafd7820b1d4b0ba599944e4ae4cc8c24a8f Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Fri, 3 Sep 2021 22:46:23 -0400 Subject: [PATCH 19/84] remove RepeatOptions --- cpp/src/arrow/compute/api_scalar.cc | 8 -------- cpp/src/arrow/compute/api_scalar.h | 14 -------------- .../arrow/compute/kernels/scalar_string_test.cc | 13 +++++-------- python/pyarrow/_compute.pyx | 16 ---------------- python/pyarrow/tests/test_compute.py | 8 +++----- r/R/dplyr-functions.R | 4 ---- r/R/expression.R | 1 - 7 files changed, 8 insertions(+), 56 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index c613e6d3e0d..e3fe1bdf73d 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -209,8 +209,6 @@ static auto kSplitPatternOptionsType = GetFunctionOptionsType(DataMember("repeats", &RepeatOptions::repeats)); static auto kReplaceSliceOptionsType = GetFunctionOptionsType( DataMember("start", &ReplaceSliceOptions::start), DataMember("stop", &ReplaceSliceOptions::stop), @@ -321,11 +319,6 @@ SplitPatternOptions::SplitPatternOptions(std::string pattern, int64_t max_splits SplitPatternOptions::SplitPatternOptions() : SplitPatternOptions("", -1, false) {} constexpr char SplitPatternOptions::kTypeName[]; -RepeatOptions::RepeatOptions(std::vector repeats) - : FunctionOptions(internal::kRepeatOptionsType), repeats(repeats) {} -RepeatOptions::RepeatOptions() : RepeatOptions({1}) {} -constexpr char RepeatOptions::kTypeName[]; - ReplaceSliceOptions::ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement) : FunctionOptions(internal::kReplaceSliceOptionsType), @@ -447,7 +440,6 @@ void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kMatchSubstringOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSplitOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kSplitPatternOptionsType)); - DCHECK_OK(registry->AddFunctionOptionsType(kRepeatOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSliceOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSubstringOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType)); diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 989dc0c4c53..4bb18b37527 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -166,20 +166,6 @@ class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { bool reverse; }; -class ARROW_EXPORT RepeatOptions : public FunctionOptions { - public: - // NOTE: Use 'int' instead of 'int64_t' because R-cpp11 does not supports - // 'r_vector'. - explicit RepeatOptions(std::vector repeats); - RepeatOptions(); - constexpr static char const kTypeName[] = "RepeatOptions"; - - /// Number of repeats - /// A single value is applied to all inputs, otherwise repeats are apply to - /// corresponding input based on order. - std::vector repeats; -}; - class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { public: explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index c679ce1052c..de662991c7f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -1044,8 +1044,7 @@ TYPED_TEST(TestStringKernels, Utf8Title) { } TYPED_TEST(TestStringKernels, StrRepeat) { - RepeatOptions options; - this->CheckUnary("str_repeat", "[]", this->type(), "[]", &options); + this->CheckUnary("str_repeat", "[]", this->type(), "[]"); std::string values( R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"); @@ -1059,25 +1058,23 @@ TYPED_TEST(TestStringKernels, StrRepeat) { }); for (const auto& pair : repeats_and_expected_map) { - options.repeats = {pair.first}; - this->CheckUnary("str_repeat", values, this->type(), pair.second, &options); + this->CheckUnary("str_repeat", values, this->type(), pair.second); } } TYPED_TEST(TestStringKernels, StrRepeats) { - RepeatOptions options{{-1, 2, 4, 2, 0, 1, 3, 2, 3}}; + std::vector repeats{-1, 2, 4, 2, 0, 1, 3, 2, 3}; std::string values( R"(["aAazZæÆ&", "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"); std::string expected( R"(["", "", "bbbb", "ɑɽⱤoWɑɽⱤoW", "", "ⱥⱥⱥȺ", "hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!", "$. A3$. A3", "!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])"); - this->CheckUnary("str_repeat", values, this->type(), expected, &options); + this->CheckUnary("str_repeat", values, this->type(), expected); // Test invalid data: len(repeats) != len(inputs) - options.repeats.pop_back(); auto invalid_input = ArrayFromJSON(this->type(), "[\"b\"]"); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("differ in length"), - CallFunction("str_repeat", {invalid_input}, &options)); + CallFunction("str_repeat", {invalid_input, repeats})); } TYPED_TEST(TestStringKernels, IsAlphaNumericUnicode) { diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c3ac4755326..d62c9c0ee85 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -31,10 +31,6 @@ import pyarrow.lib as lib import numpy as np -def is_iterable(obj): - return hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes)) - - cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): """ Wrap a C++ scalar Function in a ScalarFunction object. @@ -1165,18 +1161,6 @@ class VarianceOptions(_VarianceOptions): self._set_options(ddof, skip_nulls, min_count) -cdef class _RepeatOptions(FunctionOptions): - def _set_options(self, repeats): - self.wrapped.reset(new CRepeatOptions(repeats)) - - -class RepeatOptions(_RepeatOptions): - def __init__(self, repeats): - if not is_iterable(repeats): - repeats = [repeats] - self._set_options(repeats) - - cdef class _SplitOptions(FunctionOptions): def _set_options(self, max_splits, reverse): self.wrapped.reset(new CSplitOptions(max_splits, reverse)) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 4d9ec082c62..0997f1f845b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2249,8 +2249,7 @@ def test_str_repeat(): 2: ["æÆ&æÆ&", None, "", "bb", "ɑɽⱤoWɑɽⱤoW", "ıIıI", "$. 3$. 3"], } for repeat, expected in repeat_and_expected_map.items(): - options = pc.RepeatOptions(repeat) - result = pc.str_repeat(values, options=options) + result = pc.str_repeat(values, repeat, options=options) assert result.equals(pa.array(expected)) # Test with multiple values for number of repeats @@ -2261,10 +2260,9 @@ def test_str_repeat(): (0, 0): ["", ""], } for repeat, expected in repeat_and_expected_map.items(): - options = pc.RepeatOptions(repeat) - result = pc.str_repeat(values, options=options) + result = pc.str_repeat(values, repeat, options=options) assert result.equals(pa.array(expected)) # Test with invalid number of values and repeats with pytest.raises(ValueError, match="differ in length"): - pc.str_repeat(["a", "b"], repeats=[1, 2, 3]) + pc.str_repeat(["a", "b"], [1, 2, 3]) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index bd51a6ca431..717cdae9662 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -409,10 +409,6 @@ nse_funcs$substr <- function(x, start, stop) { ) } -nse_funcs$str_repeat <- function(x, repeats) { - Expression$create("str_repeat", x, options = list(repeats = repeats)) -} - nse_funcs$substring <- function(text, first, last) { nse_funcs$substr(x = text, start = first, stop = last) } diff --git a/r/R/expression.R b/r/R/expression.R index 28750928d9f..b1b6635f538 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -49,7 +49,6 @@ # nchar is defined in dplyr-functions.R "str_length" = "utf8_length", # str_pad is defined in dplyr-functions.R - # str_repeat is defined in dplyr-functions.R # str_sub is defined in dplyr-functions.R # str_to_lower is defined in dplyr-functions.R # str_to_title is defined in dplyr-functions.R From b652a07bd279dfa9bd639f8b03e755aa2f17a44c Mon Sep 17 00:00:00 2001 From: Eduardo Ponce Date: Sat, 4 Sep 2021 00:19:34 -0400 Subject: [PATCH 20/84] update kernel to conform to StringBinaryTransformExec --- .../arrow/compute/kernels/scalar_string.cc | 109 +++++++++++------- .../compute/kernels/scalar_string_test.cc | 46 +++++--- 2 files changed, 99 insertions(+), 56 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index ff1b475adc4..d0c06fa1cb5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2733,61 +2733,60 @@ void AddSplit(FunctionRegistry* registry) { #endif } -template +template struct StrRepeatTransform : public StringTransformBase { - using Options = RepeatOptions; - using State = OptionsWrapper; - using ArrayType = typename TypeTraits::ArrayType; - - const Options* options; - std::function Transform; - std::vector::const_iterator it; - - explicit StrRepeatTransform(const Options& options) : options{&options} { - // NOTE: This is an incorrect hack to iterate through the repeat values because for - // null entries, Transform() is not invoked and thus the iterator does not moves. - it = this->options->repeats.begin(); - Transform = [&](const uint8_t* input, int64_t input_string_ncodeunits, - uint8_t* output) { - auto nrepeats = *it++; - if (it == this->options->repeats.end()) { - it = this->options->repeats.begin(); - } - return this->Transform_(input, input_string_ncodeunits, output, nrepeats); - }; - } + using ArrayType1 = typename TypeTraits::ArrayType; + using ArrayType2 = typename TypeTraits::ArrayType; + int64_t max_nrepeats = 0; Status PreExec(KernelContext*, const ExecBatch& batch, Datum*) override { - if ((options->repeats.size() > 1) && (batch[0].kind() == Datum::ARRAY)) { - ArrayType input(batch[0].array()); - if (static_cast(options->repeats.size()) != - static_cast(input.length())) { - return Status::Invalid( - "Number of repeats and input strings differ in length"); + // Since repeat values are validated here, might as well get the maximum repeat value + // into a data member and use it for MaxCodeunits(). + if (batch[1].is_scalar()) { + max_nrepeats = static_cast( + checked_cast&>(*batch[1].scalar()).value); + } + + if (batch[0].is_array() && batch[1].is_array()) { + // TODO(edponce): Is it possible to not convert to ArrayType for these checks and + // finding max? + ArrayType1 array1(batch[0].array()); + ArrayType2 array2(batch[1].array()); + if (array1.length() != array2.length()) { + return Status::Invalid("Number of input strings and repetitions differ in length"); } + + // Note: Ideally, we would like to calculate the exact output size by iterating over + // all input strings and summing each length multiplied by the corresponding repeat + // value, but this requires traversing the data twice (now and during transform). + // The upper limit is to assume that all strings are repeated the max number of + // times. + max_nrepeats = + static_cast(**std::max_element(array2.begin(), array2.end())); + } + + if (max_nrepeats < 0) { + return Status::Invalid("Invalid string repetition value, has to be non-negative"); } return Status::OK(); } - int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { - // NOTE: Ideally, we would like to sum the values that correspond to non-null entries - // along with each inputs' size but this requires traversing the data twice. The upper - // limit is to assume that all strings are repeated the max number of times. - auto max_repeat = *std::max_element(options->repeats.begin(), options->repeats.end()); - return input_ncodeunits * std::max(max_repeat, 0); + int64_t MaxCodeunits(int64_t inputs, int64_t input_ncodeunits) override { + return input_ncodeunits * max_nrepeats; } - private: - int64_t Transform_(const uint8_t* input, int64_t input_string_ncodeunits, - uint8_t* output, const int64_t nrepeats) { + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + const std::shared_ptr& input2, uint8_t* output) { + auto nrepeats = + static_cast(checked_cast&>(*input2).value); uint8_t* output_start = output; if (nrepeats > 0) { // log2(repeats) approach std::memcpy(output, input, input_string_ncodeunits); output += input_string_ncodeunits; int64_t i = 1; - for (int64_t ilen = input_string_ncodeunits; i <= (nrepeats >> 1); - i <<= 1, ilen <<= 1) { + for (int64_t ilen = input_string_ncodeunits; i <= (nrepeats / 2); + i *= 2, ilen *= 2) { std::memcpy(output, output_start, ilen); output += ilen; } @@ -2801,8 +2800,30 @@ struct StrRepeatTransform : public StringTransformBase { } }; -template -using StrRepeat = StringTransformExecWithState>; +template +using StrRepeat = + StringBinaryTransformExec>; + +template