diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index aa953119d47..240386dc81d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -306,10 +306,6 @@ struct StringTransformBase { virtual Status InvalidStatus() { return Status::Invalid("Invalid UTF8 sequence in input"); } - - // Derived classes should also define this method: - // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, - // uint8_t* output); }; template @@ -319,21 +315,23 @@ struct StringTransformExecBase { static Status Execute(KernelContext* ctx, StringTransform* transform, const ExecBatch& batch, Datum* out) { + if (batch.num_values() != 1) { + return Status::Invalid("Invalid arity for unary string transform"); + } + if (batch[0].kind() == Datum::ARRAY) { return ExecArray(ctx, transform, batch[0].array(), out); + } else if (batch[0].kind() == Datum::SCALAR) { + return ExecScalar(ctx, transform, batch[0].scalar(), out); } - DCHECK_EQ(batch[0].kind(), Datum::SCALAR); - return ExecScalar(ctx, transform, batch[0].scalar(), out); + return Status::Invalid("Invalid ExecBatch kind for unary string transform"); } static Status ExecArray(KernelContext* ctx, StringTransform* transform, const std::shared_ptr& data, Datum* out) { ArrayType input(data); - ArrayData* output = out->mutable_array(); - const int64_t input_ncodeunits = input.total_values_length(); const int64_t input_nstrings = input.length(); - const int64_t output_ncodeunits_max = transform->MaxCodeunits(input_nstrings, input_ncodeunits); if (output_ncodeunits_max > std::numeric_limits::max()) { @@ -341,6 +339,7 @@ struct StringTransformExecBase { "Result might not fit in a 32bit utf8 array, convert to large_utf8"); } + ArrayData* output = out->mutable_array(); ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); output->buffers[2] = values_buffer; @@ -348,7 +347,6 @@ struct StringTransformExecBase { offset_type* output_string_offsets = output->GetMutableValues(1); uint8_t* output_str = output->buffers[2]->mutable_data(); offset_type output_ncodeunits = 0; - output_string_offsets[0] = 0; for (int64_t i = 0; i < input_nstrings; i++) { if (!input.IsNull(i)) { @@ -375,16 +373,16 @@ struct StringTransformExecBase { if (!input.is_valid) { return Status::OK(); } - auto* result = checked_cast(out->scalar().get()); - result->is_valid = true; const int64_t data_nbytes = static_cast(input.value->size()); - const int64_t output_ncodeunits_max = transform->MaxCodeunits(1, data_nbytes); if (output_ncodeunits_max > std::numeric_limits::max()) { return Status::CapacityError( "Result might not fit in a 32bit utf8 array, convert to large_utf8"); } + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); + auto* result = checked_cast(out->scalar().get()); + result->is_valid = true; result->value = value_buffer; auto encoded_nbytes = static_cast(transform->Transform( input.value->data(), data_nbytes, value_buffer->mutable_data())); @@ -394,6 +392,10 @@ struct StringTransformExecBase { DCHECK_LE(encoded_nbytes, output_ncodeunits_max); return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); } + + // Unary derived classes should define this method: + // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + // uint8_t* output); }; template @@ -420,6 +422,228 @@ struct StringTransformExecWithState } }; +struct StringBinaryTransformBase { + virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); + } + + // Return the maximum total size of the output in codeunits (i.e. bytes) + // given input characteristics. + virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits, + const std::shared_ptr& input2) { + return input_ncodeunits; + } + + // Return the maximum total size of the output in codeunits (i.e. bytes) + // given input characteristics. + virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits, + const std::shared_ptr& data2) { + return input_ncodeunits; + } +}; + +/// Kernel exec generator for binary string transforms. +/// The first parameter is expected to always be a string type while the second parameter +/// is generic. It supports executions of the form: +/// * Scalar, Scalar +/// * Array, Scalar - scalar is broadcasted and paired with all values of array +/// * Array, Array - arrays are processed element-wise +/// * Scalar, Array - not supported by default +template +struct StringBinaryTransformExecBase { + using offset_type = typename Type1::offset_type; + using ArrayType1 = typename TypeTraits::ArrayType; + using ArrayType2 = typename TypeTraits::ArrayType; + + static Status Execute(KernelContext* ctx, StringTransform* transform, + const ExecBatch& batch, Datum* out) { + if (batch.num_values() != 2) { + return Status::Invalid("Invalid arity for binary string transform"); + } + + if (batch[0].is_array()) { + if (batch[1].is_array()) { + return ExecArrayArray(ctx, transform, batch[0].array(), batch[1].array(), out); + } else if (batch[1].is_scalar()) { + return ExecArrayScalar(ctx, transform, batch[0].array(), batch[1].scalar(), out); + } + } else if (batch[0].is_scalar()) { + if (batch[1].is_array()) { + return ExecScalarArray(ctx, transform, batch[0].scalar(), batch[1].array(), out); + } else if (batch[1].is_scalar()) { + return ExecScalarScalar(ctx, transform, batch[0].scalar(), batch[1].scalar(), + out); + } + } + return Status::Invalid("Invalid ExecBatch kind for binary string transform"); + } + + private: + static Status ExecScalarScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar1, + const std::shared_ptr& scalar2, Datum* out) { + if (!scalar1->is_valid || !scalar2->is_valid) { + return Status::OK(); + } + + const auto& input1 = checked_cast(*scalar1); + auto input_ncodeunits = input1.value->size(); + auto input_nstrings = 1; + auto output_ncodeunits_max = + transform->MaxCodeunits(input_nstrings, input_ncodeunits, scalar2); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); + auto result = checked_cast(out->scalar().get()); + result->is_valid = true; + result->value = value_buffer; + auto output_str = value_buffer->mutable_data(); + + auto input1_string = input1.value->data(); + auto encoded_nbytes = + transform->Transform(input1_string, input_ncodeunits, scalar2, output_str); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); + } + DCHECK_LE(encoded_nbytes, output_ncodeunits_max); + return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); + } + + static Status ExecArrayScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data1, + const std::shared_ptr& scalar2, Datum* out) { + if (!scalar2->is_valid) { + return Status::OK(); + } + + ArrayType1 input1(data1); + auto input1_ncodeunits = input1.total_values_length(); + auto input1_nstrings = input1.length(); + auto output_ncodeunits_max = + transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, scalar2); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ArrayData* output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); + output->buffers[2] = values_buffer; + + // String offsets are preallocated + auto output_string_offsets = output->GetMutableValues(1); + auto output_str = output->buffers[2]->mutable_data(); + output_string_offsets[0] = 0; + + offset_type output_ncodeunits = 0; + for (int64_t i = 0; i < input1_nstrings; ++i) { + if (!input1.IsNull(i)) { + offset_type input1_string_ncodeunits; + auto input1_string = input1.GetValue(i, &input1_string_ncodeunits); + auto encoded_nbytes = + transform->Transform(input1_string, input1_string_ncodeunits, scalar2, + output_str + output_ncodeunits); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); + } + output_ncodeunits += encoded_nbytes; + } + output_string_offsets[i + 1] = output_ncodeunits; + } + DCHECK_LE(output_ncodeunits, output_ncodeunits_max); + + // Trim the codepoint buffer, since we allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + return Status::OK(); + } + + static Status ExecScalarArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar1, + const std::shared_ptr& data2, Datum* out) { + return Status::NotImplemented( + "Binary string transforms with (scalar, array) inputs are not supported for the " + "general case"); + } + + static Status ExecArrayArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data1, + const std::shared_ptr& data2, Datum* out) { + ArrayType1 input1(data1); + ArrayType2 input2(data2); + + auto input1_ncodeunits = input1.total_values_length(); + auto input1_nstrings = input1.length(); + auto output_ncodeunits_max = + transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, data2); + if (output_ncodeunits_max > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + + ArrayData* output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); + output->buffers[2] = values_buffer; + + // String offsets are preallocated + auto output_string_offsets = output->GetMutableValues(1); + auto output_str = output->buffers[2]->mutable_data(); + output_string_offsets[0] = 0; + + offset_type output_ncodeunits = 0; + for (int64_t i = 0; i < input1_nstrings; ++i) { + if (!input1.IsNull(i) || !input2.IsNull(i)) { + offset_type input1_string_ncodeunits; + auto input1_string = input1.GetValue(i, &input1_string_ncodeunits); + auto scalar2 = *input2.GetScalar(i); + auto encoded_nbytes = + transform->Transform(input1_string, input1_string_ncodeunits, scalar2, + output_str + output_ncodeunits); + if (encoded_nbytes < 0) { + return transform->InvalidStatus(); + } + output_ncodeunits += encoded_nbytes; + } + output_string_offsets[i + 1] = output_ncodeunits; + } + DCHECK_LE(output_ncodeunits, output_ncodeunits_max); + + // Trim the codepoint buffer, since we allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + } + + // Binary derived classes should define this method: + // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, const + // std::shared_ptr& input2, uint8_t* output); +}; + +template +struct StringBinaryTransformExec + : public StringBinaryTransformExecBase { + using StringBinaryTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform; + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + +template +struct StringBinaryTransformExecWithState + : public StringBinaryTransformExecBase { + using State = typename StringTransform::State; + using StringBinaryTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform(State::Get(ctx)); + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + #ifdef ARROW_WITH_UTF8PROC struct FunctionalCaseMappingTransform : public StringTransformBase { @@ -4231,7 +4455,6 @@ const FunctionDoc utf8_reverse_doc( "clusters. Hence, it will not correctly reverse grapheme clusters\n" "composed of multiple codepoints."), {"strings"}); - } // namespace void RegisterScalarStringAscii(FunctionRegistry* registry) { @@ -4255,7 +4478,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { &ascii_rtrim_whitespace_doc); MakeUnaryStringBatchKernel("ascii_reverse", registry, &ascii_reverse_doc); MakeUnaryStringBatchKernel("utf8_reverse", registry, &utf8_reverse_doc); - MakeUnaryStringBatchKernelWithState("ascii_center", registry, &ascii_center_doc); MakeUnaryStringBatchKernelWithState("ascii_lpad", registry, &ascii_lpad_doc);