cpp/src/arrow/compute/kernels/scalar_string.cc (252 changes: 237 additions & 15 deletions)
@@ -306,10 +306,6 @@ struct StringTransformBase {
virtual Status InvalidStatus() {
return Status::Invalid("Invalid UTF8 sequence in input");
  }
};

template <typename Type, typename StringTransform>
@@ -319,36 +315,38 @@ struct StringTransformExecBase {

static Status Execute(KernelContext* ctx, StringTransform* transform,
const ExecBatch& batch, Datum* out) {
if (batch.num_values() != 1) {
return Status::Invalid("Invalid arity for unary string transform");
}

if (batch[0].kind() == Datum::ARRAY) {
return ExecArray(ctx, transform, batch[0].array(), out);
} else if (batch[0].kind() == Datum::SCALAR) {
return ExecScalar(ctx, transform, batch[0].scalar(), out);
}
return Status::Invalid("Invalid ExecBatch kind for unary string transform");
}

static Status ExecArray(KernelContext* ctx, StringTransform* transform,
const std::shared_ptr<ArrayData>& data, Datum* out) {
ArrayType input(data);

const int64_t input_ncodeunits = input.total_values_length();
const int64_t input_nstrings = input.length();

const int64_t output_ncodeunits_max =
transform->MaxCodeunits(input_nstrings, input_ncodeunits);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}

ArrayData* output = out->mutable_array();
ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max));
output->buffers[2] = values_buffer;

// String offsets are preallocated
offset_type* output_string_offsets = output->GetMutableValues<offset_type>(1);
uint8_t* output_str = output->buffers[2]->mutable_data();
offset_type output_ncodeunits = 0;

output_string_offsets[0] = 0;
for (int64_t i = 0; i < input_nstrings; i++) {
if (!input.IsNull(i)) {
@@ -375,16 +373,16 @@ struct StringTransformExecBase {
if (!input.is_valid) {
return Status::OK();
}
const int64_t data_nbytes = static_cast<int64_t>(input.value->size());

const int64_t output_ncodeunits_max = transform->MaxCodeunits(1, data_nbytes);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}

ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
result->is_valid = true;
result->value = value_buffer;
auto encoded_nbytes = static_cast<offset_type>(transform->Transform(
input.value->data(), data_nbytes, value_buffer->mutable_data()));
@@ -394,6 +392,10 @@ struct StringTransformExecBase {
DCHECK_LE(encoded_nbytes, output_ncodeunits_max);
return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true);
}

// Unary derived classes should define this method:
// int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
// uint8_t* output);
};

template <typename Type, typename StringTransform>
@@ -420,6 +422,228 @@ struct StringTransformExecWithState
}
};
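
A note for orientation, not part of the diff: a unary transform derives from StringTransformBase, supplies the Transform method described above, and overrides MaxCodeunits only when the output can be larger than the input. The following is a minimal hypothetical sketch; the names are illustrative, and it assumes the unary StringTransformExec wrapper this file already defines (collapsed in this diff) and that the unary base's default MaxCodeunits returns the input size, as StringBinaryTransformBase below does.

// Hypothetical sketch: swap ASCII letter case byte-by-byte. The default
// MaxCodeunits bound (output size == input size) is sufficient, since the
// output has exactly as many codeunits as the input.
struct AsciiSwapCaseSketchTransform : public StringTransformBase {
  int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
                    uint8_t* output) {
    for (int64_t i = 0; i < input_string_ncodeunits; ++i) {
      const uint8_t c = input[i];
      if (c >= 'a' && c <= 'z') {
        output[i] = c - ('a' - 'A');
      } else if (c >= 'A' && c <= 'Z') {
        output[i] = c + ('a' - 'A');
      } else {
        output[i] = c;
      }
    }
    // Return the number of codeunits written; a negative value would be
    // reported as InvalidStatus() by the exec loop.
    return input_string_ncodeunits;
  }
};

// Hypothetical wiring through the unary exec wrapper:
template <typename Type>
using AsciiSwapCaseSketch = StringTransformExec<Type, AsciiSwapCaseSketchTransform>;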

struct StringBinaryTransformBase {
  virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    return Status::OK();
  }

  // Default error status returned when a Transform call reports failure
  // (mirrors StringTransformBase::InvalidStatus above).
  virtual Status InvalidStatus() {
    return Status::Invalid("Invalid UTF8 sequence in input");
  }

  // Return an upper bound on the total size of the output in codeunits
  // (i.e. bytes), given the number of input strings, their total size in
  // codeunits, and the scalar second argument.
virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits,
const std::shared_ptr<Scalar>& input2) {
return input_ncodeunits;
}

  // Same as above, but for an array-valued second argument.
virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits,
const std::shared_ptr<ArrayData>& data2) {
return input_ncodeunits;
}
};
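
For illustration only (not part of the diff): the two overloads let a derived transform bound the output size differently for a scalar versus an array second argument. A hypothetical "repeat the string N times" transform, assuming an int64 second argument and the headers this file already includes, might bound the output like this:

// Hypothetical sketch: output bounds for a repeat-style transform. With a
// scalar count every string grows by the same factor; with an array of counts
// a conservative bound uses the largest count present.
struct RepeatBoundsSketch : public StringBinaryTransformBase {
  int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits,
                       const std::shared_ptr<Scalar>& input2) override {
    const int64_t repeats = checked_cast<const Int64Scalar&>(*input2).value;
    return input_ncodeunits * std::max<int64_t>(repeats, 0);
  }

  int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits,
                       const std::shared_ptr<ArrayData>& data2) override {
    Int64Array repeats(data2);
    int64_t max_repeats = 0;
    for (int64_t i = 0; i < repeats.length(); ++i) {
      if (!repeats.IsNull(i)) {
        max_repeats = std::max(max_repeats, repeats.Value(i));
      }
    }
    return input_ncodeunits * max_repeats;
  }
};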

/// Kernel exec generator for binary string transforms.
/// The first argument is expected to always be of string type, while the second
/// argument is generic. Supported execution forms:
/// * (Scalar, Scalar)
/// * (Array, Scalar)  - the scalar is broadcast and paired with every array element
/// * (Array, Array)   - the arrays are processed element-wise
/// * (Scalar, Array)  - not supported by default (see ExecScalarArray)
template <typename Type1, typename Type2, typename StringTransform>
struct StringBinaryTransformExecBase {
using offset_type = typename Type1::offset_type;
using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
using ArrayType2 = typename TypeTraits<Type2>::ArrayType;

static Status Execute(KernelContext* ctx, StringTransform* transform,
const ExecBatch& batch, Datum* out) {
if (batch.num_values() != 2) {
return Status::Invalid("Invalid arity for binary string transform");
}

if (batch[0].is_array()) {
if (batch[1].is_array()) {
return ExecArrayArray(ctx, transform, batch[0].array(), batch[1].array(), out);
} else if (batch[1].is_scalar()) {
return ExecArrayScalar(ctx, transform, batch[0].array(), batch[1].scalar(), out);
}
} else if (batch[0].is_scalar()) {
if (batch[1].is_array()) {
return ExecScalarArray(ctx, transform, batch[0].scalar(), batch[1].array(), out);
} else if (batch[1].is_scalar()) {
return ExecScalarScalar(ctx, transform, batch[0].scalar(), batch[1].scalar(),
out);
}
}
return Status::Invalid("Invalid ExecBatch kind for binary string transform");
}

private:
static Status ExecScalarScalar(KernelContext* ctx, StringTransform* transform,
const std::shared_ptr<Scalar>& scalar1,
const std::shared_ptr<Scalar>& scalar2, Datum* out) {
if (!scalar1->is_valid || !scalar2->is_valid) {
return Status::OK();
}

const auto& input1 = checked_cast<const BaseBinaryScalar&>(*scalar1);
auto input_ncodeunits = input1.value->size();
auto input_nstrings = 1;
auto output_ncodeunits_max =
transform->MaxCodeunits(input_nstrings, input_ncodeunits, scalar2);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}

ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
auto result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
result->is_valid = true;
result->value = value_buffer;
auto output_str = value_buffer->mutable_data();

auto input1_string = input1.value->data();
auto encoded_nbytes =
transform->Transform(input1_string, input_ncodeunits, scalar2, output_str);
if (encoded_nbytes < 0) {
return transform->InvalidStatus();
}
DCHECK_LE(encoded_nbytes, output_ncodeunits_max);
return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true);
}

static Status ExecArrayScalar(KernelContext* ctx, StringTransform* transform,
const std::shared_ptr<ArrayData>& data1,
const std::shared_ptr<Scalar>& scalar2, Datum* out) {
if (!scalar2->is_valid) {
return Status::OK();
}

ArrayType1 input1(data1);
auto input1_ncodeunits = input1.total_values_length();
auto input1_nstrings = input1.length();
auto output_ncodeunits_max =
transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, scalar2);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}

ArrayData* output = out->mutable_array();
ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max));
output->buffers[2] = values_buffer;

// String offsets are preallocated
auto output_string_offsets = output->GetMutableValues<offset_type>(1);
auto output_str = output->buffers[2]->mutable_data();
output_string_offsets[0] = 0;

offset_type output_ncodeunits = 0;
for (int64_t i = 0; i < input1_nstrings; ++i) {
if (!input1.IsNull(i)) {
offset_type input1_string_ncodeunits;
auto input1_string = input1.GetValue(i, &input1_string_ncodeunits);
auto encoded_nbytes =
transform->Transform(input1_string, input1_string_ncodeunits, scalar2,
output_str + output_ncodeunits);
if (encoded_nbytes < 0) {
return transform->InvalidStatus();
}
output_ncodeunits += encoded_nbytes;
}
output_string_offsets[i + 1] = output_ncodeunits;
}
DCHECK_LE(output_ncodeunits, output_ncodeunits_max);

// Trim the codepoint buffer, since we allocated too much
return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true);
}

static Status ExecScalarArray(KernelContext* ctx, StringTransform* transform,
const std::shared_ptr<Scalar>& scalar1,
const std::shared_ptr<ArrayData>& data2, Datum* out) {
return Status::NotImplemented(
"Binary string transforms with (scalar, array) inputs are not supported for the "
"general case");
}

static Status ExecArrayArray(KernelContext* ctx, StringTransform* transform,
const std::shared_ptr<ArrayData>& data1,
const std::shared_ptr<ArrayData>& data2, Datum* out) {
ArrayType1 input1(data1);
ArrayType2 input2(data2);

auto input1_ncodeunits = input1.total_values_length();
auto input1_nstrings = input1.length();
auto output_ncodeunits_max =
transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, data2);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}

ArrayData* output = out->mutable_array();
ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max));
output->buffers[2] = values_buffer;

// String offsets are preallocated
auto output_string_offsets = output->GetMutableValues<offset_type>(1);
auto output_str = output->buffers[2]->mutable_data();
output_string_offsets[0] = 0;

offset_type output_ncodeunits = 0;
for (int64_t i = 0; i < input1_nstrings; ++i) {
      if (!input1.IsNull(i) && !input2.IsNull(i)) {
offset_type input1_string_ncodeunits;
auto input1_string = input1.GetValue(i, &input1_string_ncodeunits);
auto scalar2 = *input2.GetScalar(i);
auto encoded_nbytes =
transform->Transform(input1_string, input1_string_ncodeunits, scalar2,
output_str + output_ncodeunits);
if (encoded_nbytes < 0) {
return transform->InvalidStatus();
}
output_ncodeunits += encoded_nbytes;
}
output_string_offsets[i + 1] = output_ncodeunits;
}
DCHECK_LE(output_ncodeunits, output_ncodeunits_max);

// Trim the codepoint buffer, since we allocated too much
return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true);
}

  // Binary derived classes should define this method:
  //   int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
  //                     const std::shared_ptr<Scalar>& input2, uint8_t* output);
};

template <typename Type1, typename Type2, typename StringTransform>
struct StringBinaryTransformExec
: public StringBinaryTransformExecBase<Type1, Type2, StringTransform> {
using StringBinaryTransformExecBase<Type1, Type2, StringTransform>::Execute;

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
StringTransform transform;
RETURN_NOT_OK(transform.PreExec(ctx, batch, out));
return Execute(ctx, &transform, batch, out);
}
};

template <typename Type1, typename Type2, typename StringTransform>
struct StringBinaryTransformExecWithState
: public StringBinaryTransformExecBase<Type1, Type2, StringTransform> {
using State = typename StringTransform::State;
using StringBinaryTransformExecBase<Type1, Type2, StringTransform>::Execute;

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
StringTransform transform(State::Get(ctx));
RETURN_NOT_OK(transform.PreExec(ctx, batch, out));
return Execute(ctx, &transform, batch, out);
}
};
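
To close the loop, a hypothetical end-to-end sketch (not part of the diff; names are illustrative and it assumes the headers this file already includes): a binary transform that truncates each string to at most N codeunits, where N is an int64 scalar or array element. The inherited MaxCodeunits defaults already give a valid upper bound, because the output is never longer than the input, so only Transform is needed:

// Hypothetical sketch: byte-based truncation (a real utf8 kernel would also
// respect codepoint boundaries).
struct TruncateSketchTransform : public StringBinaryTransformBase {
  int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits,
                    const std::shared_ptr<Scalar>& input2, uint8_t* output) {
    const int64_t n = checked_cast<const Int64Scalar&>(*input2).value;
    if (n < 0) {
      return -1;  // surfaced as InvalidStatus() by the exec loops above
    }
    const int64_t ncopy = std::min(n, input_string_ncodeunits);
    std::memcpy(output, input, static_cast<size_t>(ncopy));
    return ncopy;  // codeunits written
  }
};

// The matching exec function for (utf8, int64) arguments would then be
//   StringBinaryTransformExec<StringType, Int64Type, TruncateSketchTransform>::Exec
// and could be added to a ScalarFunction much like the unary kernels registered
// in RegisterScalarStringAscii below.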

#ifdef ARROW_WITH_UTF8PROC

struct FunctionalCaseMappingTransform : public StringTransformBase {
@@ -4231,7 +4455,6 @@ const FunctionDoc utf8_reverse_doc(
"clusters. Hence, it will not correctly reverse grapheme clusters\n"
"composed of multiple codepoints."),
{"strings"});

} // namespace

void RegisterScalarStringAscii(FunctionRegistry* registry) {
@@ -4255,7 +4478,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
&ascii_rtrim_whitespace_doc);
MakeUnaryStringBatchKernel<AsciiReverse>("ascii_reverse", registry, &ascii_reverse_doc);
MakeUnaryStringBatchKernel<Utf8Reverse>("utf8_reverse", registry, &utf8_reverse_doc);

MakeUnaryStringBatchKernelWithState<AsciiCenter>("ascii_center", registry,
&ascii_center_doc);
MakeUnaryStringBatchKernelWithState<AsciiLPad>("ascii_lpad", registry, &ascii_lpad_doc);