From c188236a8fee4fee926b49f409100c5582850905 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 11 Jun 2021 13:56:35 -0400 Subject: [PATCH 1/9] ARROW-12709: [C++] Add var_args_join --- cpp/src/arrow/compute/api_scalar.h | 19 ++ .../arrow/compute/kernels/scalar_string.cc | 232 ++++++++++++++++++ .../compute/kernels/scalar_string_test.cc | 116 +++++++++ docs/source/cpp/compute.rst | 16 +- docs/source/python/api/compute.rst | 17 ++ python/pyarrow/_compute.pyx | 31 +++ python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 16 ++ python/pyarrow/tests/test_compute.py | 25 ++ 9 files changed, 468 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 6e9a9340f2c..d6fd7075df2 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -48,6 +48,25 @@ struct ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { bool skip_nulls; }; +/// Options for var_args_join. +struct ARROW_EXPORT JoinOptions : public FunctionOptions { + /// How to handle null values. (A null separator always results in a null output.) + enum NullHandlingBehavior { + /// A null in any input results in a null in the output. + EMIT_NULL, + /// Nulls in inputs are skipped. + SKIP, + /// Nulls in inputs are replaced with the replacement string. + REPLACE, + }; + explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL, + std::string null_replacement = "") + : null_handling(null_handling), null_replacement(std::move(null_replacement)) {} + static JoinOptions Defaults() { return JoinOptions(); } + NullHandlingBehavior null_handling; + std::string null_replacement; +}; + struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false) : pattern(std::move(pattern)), ignore_case(ignore_case) {} diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index cd054fcea0e..dd741eda713 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -3367,6 +3367,237 @@ void AddBinaryJoin(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(func))); } +using VarArgsJoinState = OptionsWrapper; + +template +struct VarArgsJoin { + using ArrayType = typename TypeTraits::ArrayType; + using BuilderType = typename TypeTraits::BuilderType; + using offset_type = typename Type::offset_type; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + JoinOptions options = VarArgsJoinState::Get(ctx); + // Last argument is the separator (for consistency with binary_join) + if (std::all_of(batch.values.begin(), batch.values.end(), + [](const Datum& d) { return d.is_scalar(); })) { + return ExecOnlyScalar(ctx, options, batch, out); + } + return ExecContainingArrays(ctx, options, batch, out); + } + + static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options, + const ExecBatch& batch, Datum* out) { + BaseBinaryScalar* output = checked_cast(out->scalar().get()); + const size_t num_args = batch.values.size(); + if (num_args == 1) { + // Only separator, no values + ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0)); + output->is_valid = batch.values[0].scalar()->is_valid; + return Status::OK(); + } + + int64_t final_size = CalculateRowSize(options, batch, 0); + if (final_size < 0) { + ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0)); + output->is_valid = false; + return Status::OK(); + } + ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size)); + const auto separator = UnboxScalar::Unbox(*batch.values.back().scalar()); + uint8_t* buf = output->value->mutable_data(); + bool first = true; + for (size_t i = 0; i < num_args - 1; i++) { + const Scalar& scalar = *batch[i].scalar(); + util::string_view s; + if (scalar.is_valid) { + s = UnboxScalar::Unbox(scalar); + } else { + switch (options.null_handling) { + case JoinOptions::EMIT_NULL: + // Handled by CalculateRowSize + DCHECK(false) << "unreachable"; + break; + case JoinOptions::SKIP: + continue; + case JoinOptions::REPLACE: + s = options.null_replacement; + break; + } + } + if (!first) { + buf = std::copy(separator.begin(), separator.end(), buf); + } + first = false; + buf = std::copy(s.begin(), s.end(), buf); + } + output->is_valid = true; + return Status::OK(); + } + + static Status ExecContainingArrays(KernelContext* ctx, const JoinOptions& options, + const ExecBatch& batch, Datum* out) { + // Presize data to avoid reallocations + int64_t final_size = 0; + for (int64_t i = 0; i < batch.length; i++) { + auto size = CalculateRowSize(options, batch, i); + if (size > 0) final_size += size; + } + BuilderType builder(ctx->memory_pool()); + RETURN_NOT_OK(builder.Reserve(batch.length)); + RETURN_NOT_OK(builder.ReserveData(final_size)); + + std::vector valid_cols(batch.values.size()); + for (size_t row = 0; row < static_cast(batch.length); row++) { + size_t num_valid = 0; // Not counting separator + for (size_t col = 0; col < batch.values.size(); col++) { + bool valid = false; + if (batch[col].is_scalar()) { + valid = batch[col].scalar()->is_valid; + } else { + const ArrayData& array = *batch[col].array(); + valid = !array.MayHaveNulls() || + BitUtil::GetBit(array.buffers[0]->data(), array.offset + row); + } + if (valid) { + valid_cols[col] = &batch[col]; + if (col < batch.values.size() - 1) num_valid++; + } else { + valid_cols[col] = nullptr; + } + } + + if (!valid_cols.back()) { + // Separator is null + builder.UnsafeAppendNull(); + continue; + } else if (batch.values.size() == 1) { + // Only given separator + builder.UnsafeAppendEmptyValue(); + continue; + } else if (num_valid < batch.values.size() - 1) { + // We had some nulls + if (options.null_handling == JoinOptions::EMIT_NULL) { + builder.UnsafeAppendNull(); + continue; + } + } + const auto separator = Lookup(*valid_cols.back(), row); + bool first = true; + for (size_t col = 0; col < batch.values.size() - 1; col++) { + const Datum* datum = valid_cols[col]; + util::string_view value; + if (!datum) { + switch (options.null_handling) { + case JoinOptions::EMIT_NULL: + DCHECK(false) << "unreachable"; + break; + case JoinOptions::SKIP: + continue; + case JoinOptions::REPLACE: + value = options.null_replacement; + break; + } + } else { + value = Lookup(*datum, row); + } + if (first) { + builder.UnsafeAppend(value); + first = false; + continue; + } + builder.UnsafeExtendCurrent(separator); + builder.UnsafeExtendCurrent(value); + } + } + + std::shared_ptr string_array; + RETURN_NOT_OK(builder.Finish(&string_array)); + *out = *string_array->data(); + out->mutable_array()->type = batch[0].type(); + DCHECK_EQ(batch.length, out->array()->length); + return Status::OK(); + } + + // Unbox a scalar or the given element of an array. + static util::string_view Lookup(const Datum& datum, size_t row) { + if (datum.is_scalar()) { + return UnboxScalar::Unbox(*datum.scalar()); + } + const ArrayData& array = *datum.array(); + const offset_type* offsets = array.GetValues(1); + const uint8_t* data = array.GetValues(2, /*absolute_offset=*/0); + const int64_t length = offsets[row + 1] - offsets[row]; + return util::string_view(reinterpret_cast(data + offsets[row]), length); + } + + // Compute the length of the output for the given position, or -1 if it would be null. + static int64_t CalculateRowSize(const JoinOptions& options, const ExecBatch& batch, + const int64_t index) { + const auto num_args = batch.values.size(); + int64_t final_size = 0; + int64_t num_non_null_args = 0; + for (size_t i = 0; i < num_args; i++) { + int64_t element_size = 0; + bool valid = true; + if (batch[i].is_scalar()) { + const Scalar& scalar = *batch[i].scalar(); + valid = scalar.is_valid; + element_size = UnboxScalar::Unbox(scalar).size(); + } else { + const ArrayData& array = *batch[i].array(); + valid = !array.MayHaveNulls() || + BitUtil::GetBit(array.buffers[0]->data(), array.offset + index); + const offset_type* offsets = array.GetValues(1); + element_size = offsets[index + 1] - offsets[index]; + } + if (i == num_args - 1) { + if (!valid) return -1; + if (num_non_null_args > 1) { + // Add separator size (only if there were values to join) + final_size += (num_non_null_args - 1) * element_size; + } + break; + } + if (!valid) { + switch (options.null_handling) { + case JoinOptions::EMIT_NULL: + return -1; + case JoinOptions::SKIP: + continue; + case JoinOptions::REPLACE: + element_size = options.null_replacement.size(); + break; + } + } + num_non_null_args++; + final_size += element_size; + } + return final_size; + } +}; + +const FunctionDoc var_args_join_doc( + "Join string arguments into one, using the last argument as the separator", + ("Insert the last argument of `strings` between the rest of the elements, " + "and concatenate them.\n" + "Any null separator element emits a null output. Null elements either " + "emit a null (the default), are skipped, or replaced with a given string.\n"), + {"*strings"}, "JoinOptions"); + +const auto kDefaultJoinOptions = JoinOptions::Defaults(); + +void AddVarArgsJoin(FunctionRegistry* registry) { + auto func = + std::make_shared("var_args_join", Arity::VarArgs(/*min_args=*/1), + &var_args_join_doc, &kDefaultJoinOptions); + for (const auto& ty : BaseBinaryTypes()) { + DCHECK_OK(func->AddKernel({InputType(ty)}, ty, + GenerateTypeAgnosticVarBinaryBase(ty), + VarArgsJoinState::Init)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); +} + template