From 4458a3de09e770e7f42e3c97c99b7a4ae9372e4b Mon Sep 17 00:00:00 2001
From: "Maarten A. Breddels"
Date: Tue, 22 Dec 2020 13:21:43 +0100
Subject: [PATCH 1/5] ARROW-10959: [C++] Add scalar string join kernel

---
Reviewer notes (below the three-dash line, so not part of the commit message):

* builder_binary.h: BaseBinaryBuilder::AppendCurrent() appends bytes to the
  value data without adding a new offset; a non-positive length is a no-op.
* function.cc: Function::Validate() only gets a reworded error message.
* scalar_string.cc: registers "binary_join" for (list array, scalar),
  (list array, array) and (list scalar, scalar) inputs. The array path emits
  null for a null list, a list containing a null element, or a null
  separator, and an empty value for an empty list.
* NOTE(review): the scalar path returns without assigning `*out` when the
  list or the separator is null; presumably `out` is pre-initialised to a
  null scalar by the caller. Confirm.
* NOTE(review): the cast target `int64_t` in the scalar-path `Append` lambda
  is assumed. Confirm against upstream.

 cpp/src/arrow/array/builder_binary.h          |  19 +++
 cpp/src/arrow/compute/function.cc             |   5 +-
 .../arrow/compute/kernels/scalar_string.cc    | 140 ++++++++++++++++++
 .../compute/kernels/scalar_string_test.cc     |  29 ++++
 cpp/src/arrow/compute/kernels/test_util.cc    |  78 +++++++---
 cpp/src/arrow/compute/kernels/test_util.h     |   5 +
 docs/source/cpp/compute.rst                   |  18 ++-
 python/pyarrow/tests/test_compute.py          |  11 ++
 8 files changed, 281 insertions(+), 24 deletions(-)

diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h
index bc49c7d6787..16ca1694b9f 100644
--- a/cpp/src/arrow/array/builder_binary.h
+++ b/cpp/src/arrow/array/builder_binary.h
@@ -77,6 +77,25 @@ class BaseBinaryBuilder : public ArrayBuilder {
     return Append(value.data(), static_cast<offset_type>(value.size()));
   }
 
+  /// AppendCurrent does not add a new offset
+  Status AppendCurrent(const uint8_t* value, offset_type length) {
+    // Safety check for UBSAN.
+    if (ARROW_PREDICT_TRUE(length > 0)) {
+      ARROW_RETURN_NOT_OK(ValidateOverflow(length));
+      ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length));
+    }
+
+    return Status::OK();
+  }
+
+  Status AppendCurrent(const char* value, offset_type length) {
+    return AppendCurrent(reinterpret_cast<const uint8_t*>(value), length);
+  }
+
+  Status AppendCurrent(util::string_view value) {
+    return AppendCurrent(value.data(), static_cast<offset_type>(value.size()));
+  }
+
   Status AppendNulls(int64_t length) final {
     const int64_t num_bytes = value_data_builder_.length();
     ARROW_RETURN_NOT_OK(Reserve(length));
diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc
index f74bb245d77..0f94baaedfc 100644
--- a/cpp/src/arrow/compute/function.cc
+++ b/cpp/src/arrow/compute/function.cc
@@ -210,8 +210,9 @@ Status Function::Validate() const {
     if (arity_.is_varargs && arg_count == arity_.num_args + 1) {
       return Status::OK();
     }
-    return Status::Invalid("In function '", name_,
-                           "': ", "number of argument names != function arity");
+    return Status::Invalid(
+        "In function '", name_,
+        "': ", "number of argument names for function documentation != function arity");
   }
   return Status::OK();
 }
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 1d87bd86c67..96852c91f66 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -2637,6 +2637,145 @@ void AddUtf8Length(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::move(func)));
 }
 
+// binary join
+
+template <typename Type>
+struct BinaryJoin {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using ListArrayType = ListArray;
+  using offset_type = typename Type::offset_type;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    if (batch[0].kind() == Datum::SCALAR) {
+      const ListScalar& list = checked_cast<const ListScalar&>(*batch[0].scalar());
+      if (!list.is_valid) {
+        return Status::OK();
+      }
+      if (batch[1].kind() == Datum::SCALAR) {
+        const BaseBinaryScalar& separator_scalar =
+            checked_cast<const BaseBinaryScalar&>(*batch[1].scalar());
+        if (!separator_scalar.is_valid) {
+          return Status::OK();
+        }
+        util::string_view separator(*separator_scalar.value);
+
+        TypedBufferBuilder<uint8_t> builder(ctx->memory_pool());
+        auto Append = [&](util::string_view value) {
+          return builder.Append(reinterpret_cast<const uint8_t*>(value.data()),
+                                static_cast<int64_t>(value.size()));
+        };
+
+        const ArrayType* strings = static_cast<const ArrayType*>(list.value.get());
+        if (strings->null_count() > 0) {
+          // since the input list is not null, the out datum needs to be assigned to
+          *out = MakeNullScalar(list.value->type());
+          return Status::OK();
+        }
+        if (strings->length() > 0) {
+          RETURN_NOT_OK(Append(strings->GetView(0)));
+          for (int64_t j = 1; j < strings->length(); j++) {
+            RETURN_NOT_OK(Append(separator));
+            RETURN_NOT_OK(Append(strings->GetView(j)));
+          }
+        }
+        std::shared_ptr<Buffer> string_buffer;
+        RETURN_NOT_OK(builder.Finish(&string_buffer));
+        ARROW_ASSIGN_OR_RAISE(auto scalar_right_type,
+                              MakeScalar<std::shared_ptr<Buffer>>(
+                                  list.value->type(), std::move(string_buffer)));
+        *out = scalar_right_type;
+      }
+      // XXX do we want to support scalar[list[str]] with array[str] ?
+    } else {
+      const ListArrayType list(batch[0].array());
+      ArrayData* output = out->mutable_array();
+
+      BuilderType builder(ctx->memory_pool());
+      RETURN_NOT_OK(builder.Reserve(list.length()));
+      if (batch[1].kind() == Datum::ARRAY) {
+        ArrayType separator_array(batch[1].array());
+        for (int64_t i = 0; i < list.length(); ++i) {
+          const std::shared_ptr<Array> slice = list.value_slice(i);
+          const ArrayType* strings = static_cast<const ArrayType*>(slice.get());
+          if ((strings->null_count() > 0) || (list.IsNull(i)) ||
+              separator_array.IsNull(i)) {
+            RETURN_NOT_OK(builder.AppendNull());
+          } else {
+            const auto separator = separator_array.GetView(i);
+            if (strings->length() > 0) {
+              RETURN_NOT_OK(builder.Append(strings->GetView(0)));
+              for (int64_t j = 1; j < strings->length(); j++) {
+                RETURN_NOT_OK(builder.AppendCurrent(separator));
+                RETURN_NOT_OK(builder.AppendCurrent(strings->GetView(j)));
+              }
+            } else {
+              RETURN_NOT_OK(builder.AppendEmptyValue());
+            }
+          }
+        }
+      } else if (batch[1].kind() == Datum::SCALAR) {
+        const auto& separator_scalar =
+            checked_cast<const BaseBinaryScalar&>(*batch[1].scalar());
+        if (!separator_scalar.is_valid) {
+          ARROW_ASSIGN_OR_RAISE(
+              auto nulls,
+              MakeArrayOfNull(list.value_type(), list.length(), ctx->memory_pool()));
+          *output = *nulls->data();
+          output->type = list.value_type();
+          return Status::OK();
+        }
+        util::string_view separator(*separator_scalar.value);
+
+        for (int64_t i = 0; i < list.length(); ++i) {
+          const std::shared_ptr<Array> slice = list.value_slice(i);
+          const ArrayType* strings = static_cast<const ArrayType*>(slice.get());
+          if ((strings->null_count() > 0) || (list.IsNull(i))) {
+            RETURN_NOT_OK(builder.AppendNull());
+          } else {
+            if (strings->length() > 0) {
+              RETURN_NOT_OK(builder.Append(strings->GetView(0)));
+              for (int64_t j = 1; j < strings->length(); j++) {
+                RETURN_NOT_OK(builder.AppendCurrent(separator));
+                RETURN_NOT_OK(builder.AppendCurrent(strings->GetView(j)));
+              }
+            } else {
+              RETURN_NOT_OK(builder.AppendEmptyValue());
+            }
+          }
+        }
+      }
+      std::shared_ptr<Array> string_array;
+      RETURN_NOT_OK(builder.Finish(&string_array));
+      *output = *string_array->data();
+      // correct the output type based on the input
+      output->type = list.value_type();
+    }
+    return Status::OK();
+  }
+};
+
+const FunctionDoc binary_join_doc(
+    "Join a list of strings together with a `separator` to form a single string",
+    ("Insert `separator` between each list element, and concatenate them."),
+    {"list", "separator"});
+
+void AddJoin(FunctionRegistry* registry) {
+  auto func =
+      std::make_shared<ScalarFunction>("binary_join", Arity::Binary(), &binary_join_doc);
+  for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
+    auto exec = GenerateTypeAgnosticVarBinaryBase<BinaryJoin>(*ty);
+    // TODO add large_list inputs
+    DCHECK_OK(
+        func->AddKernel({InputType::Array(list(ty)), InputType::Scalar(ty)}, ty, exec));
+    DCHECK_OK(
+        func->AddKernel({InputType::Array(list(ty)), InputType::Array(ty)}, ty, exec));
+    DCHECK_OK(
+        func->AddKernel({InputType::Scalar(list(ty)), InputType::Scalar(ty)}, ty, exec));
+  }
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
 template