Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,54 @@ struct AsciiLength {

using TransformFunc = std::function<void(const uint8_t*, int64_t, uint8_t*)>;

// Transform a buffer of offsets to one which begins with 0 and has same
// value lengths.
template <typename T>
Status GetShiftedOffsets(KernelContext* ctx, const Buffer& input_buffer, int64_t offset,
int64_t length, std::shared_ptr<Buffer>* out) {
ARROW_ASSIGN_OR_RAISE(*out, ctx->Allocate((length + 1) * sizeof(T)));
const T* input_offsets = reinterpret_cast<const T*>(input_buffer.data()) + offset;
T* out_offsets = reinterpret_cast<T*>((*out)->mutable_data());
T first_offset = *input_offsets;
for (int64_t i = 0; i < length; ++i) {
*out_offsets++ = input_offsets[i] - first_offset;
}
*out_offsets = input_offsets[length] - first_offset;
return Status::OK();
}

// Apply `transform` to input character data- this function cannot change the
// length
template <typename Type>
void StringDataTransform(KernelContext* ctx, const ExecBatch& batch,
TransformFunc transform, Datum* out) {
using ArrayType = typename TypeTraits<Type>::ArrayType;
using offset_type = typename Type::offset_type;

if (batch[0].kind() == Datum::ARRAY) {
const ArrayData& input = *batch[0].array();
ArrayType input_boxed(batch[0].array());

ArrayData* out_arr = out->mutable_array();
// Reuse offsets from input
out_arr->buffers[1] = input.buffers[1];
int64_t data_nbytes = input.buffers[2]->size();

if (input.offset == 0) {
// We can reuse offsets from input
out_arr->buffers[1] = input.buffers[1];
} else {
DCHECK(input.buffers[1]);
// We must allocate new space for the offsets and shift the existing offsets
KERNEL_RETURN_IF_ERROR(
ctx, GetShiftedOffsets<offset_type>(ctx, *input.buffers[1], input.offset,
input.length, &out_arr->buffers[1]));
}

// Allocate space for output data
int64_t data_nbytes = input_boxed.total_values_length();
KERNEL_RETURN_IF_ERROR(ctx, ctx->Allocate(data_nbytes).Value(&out_arr->buffers[2]));
transform(input.buffers[2]->data(), data_nbytes, out_arr->buffers[2]->mutable_data());
if (input.length > 0) {
transform(input.buffers[2]->data() + input_boxed.value_offset(0), data_nbytes,
out_arr->buffers[2]->mutable_data());
}
} else {
const auto& input = checked_cast<const BaseBinaryScalar&>(*batch[0].scalar());
auto result = checked_pointer_cast<BaseBinaryScalar>(MakeNullScalar(out->type()));
Expand All @@ -77,9 +115,12 @@ void TransformAsciiUpper(const uint8_t* input, int64_t length, uint8_t* output)
}
}

void AsciiUpperExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
StringDataTransform(ctx, batch, TransformAsciiUpper, out);
}
template <typename Type>
struct AsciiUpper {
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
StringDataTransform<Type>(ctx, batch, TransformAsciiUpper, out);
}
};

void TransformAsciiLower(const uint8_t* input, int64_t length, uint8_t* output) {
for (int64_t i = 0; i < length; ++i) {
Expand All @@ -91,9 +132,12 @@ void TransformAsciiLower(const uint8_t* input, int64_t length, uint8_t* output)
}
}

void AsciiLowerExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
StringDataTransform(ctx, batch, TransformAsciiLower, out);
}
template <typename Type>
struct AsciiLower {
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
StringDataTransform<Type>(ctx, batch, TransformAsciiLower, out);
}
};

void AddAsciiLength(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("ascii_length", Arity::Unary());
Expand Down Expand Up @@ -151,19 +195,21 @@ void AddStrptime(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

void MakeUnaryStringBatchKernel(std::string name, ArrayKernelExec exec,
FunctionRegistry* registry) {
template <template <typename> class ExecFunctor>
void MakeUnaryStringBatchKernel(std::string name, FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>(name, Arity::Unary());
DCHECK_OK(func->AddKernel({utf8()}, utf8(), exec));
DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), exec));
auto exec_32 = ExecFunctor<StringType>::Exec;
auto exec_64 = ExecFunctor<LargeStringType>::Exec;
DCHECK_OK(func->AddKernel({utf8()}, utf8(), exec_32));
DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), exec_64));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

} // namespace

void RegisterScalarStringAscii(FunctionRegistry* registry) {
MakeUnaryStringBatchKernel("ascii_upper", AsciiUpperExec, registry);
MakeUnaryStringBatchKernel("ascii_lower", AsciiLowerExec, registry);
MakeUnaryStringBatchKernel<AsciiUpper>("ascii_upper", registry);
MakeUnaryStringBatchKernel<AsciiLower>("ascii_lower", registry);
AddAsciiLength(registry);
AddStrptime(registry);
}
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,16 @@ TYPED_TEST(TestStringKernels, AsciiLength) {
"[3, null, 0, 1]");
}

TYPED_TEST(TestStringKernels, DISABLED_AsciiUpper) {
this->CheckUnary("ascii_upper", "[\"aAazZæÆ&\", null, \"\", \"b\"]",
this->string_type(), "[\"AAAZZæÆ&\", null, \"\", \"B\"]");
TYPED_TEST(TestStringKernels, AsciiUpper) {
this->CheckUnary("ascii_upper", "[]", this->string_type(), "[]");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this implicitly test sliced arrays?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does now -- @bkietz added this in ARROW-971

this->CheckUnary("ascii_upper", "[\"aAazZæÆ&\", null, \"\", \"bbb\"]",
this->string_type(), "[\"AAAZZæÆ&\", null, \"\", \"BBB\"]");
}

TYPED_TEST(TestStringKernels, DISABLED_AsciiLower) {
this->CheckUnary("ascii_lower", "[\"aAazZæÆ&\", null, \"\", \"b\"]",
this->string_type(), "[\"aaazzæÆ&\", null, \"\", \"b\"]");
TYPED_TEST(TestStringKernels, AsciiLower) {
this->CheckUnary("ascii_lower", "[]", this->string_type(), "[]");
this->CheckUnary("ascii_lower", "[\"aAazZæÆ&\", null, \"\", \"BBB\"]",
this->string_type(), "[\"aaazzæÆ&\", null, \"\", \"bbb\"]");
}

TYPED_TEST(TestStringKernels, Strptime) {
Expand Down