From d7f22300504c138d5b28556b0e712426b5f43613 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 17 Jun 2021 10:30:09 -0400 Subject: [PATCH 1/3] ARROW-13220: [C++] Implement 'choose' function --- .../arrow/compute/kernels/scalar_if_else.cc | 256 ++++++++++++++++++ .../kernels/scalar_if_else_benchmark.cc | 34 +++ .../compute/kernels/scalar_if_else_test.cc | 187 +++++++++++++ docs/source/cpp/compute.rst | 49 ++-- docs/source/python/api/compute.rst | 1 + 5 files changed, 507 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index ff308a673a3..e4bd660d3bd 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1182,6 +1182,20 @@ void CopyOneArrayValue(const DataType& type, const uint8_t* in_valid, out_offset); } +template +void CopyOneValue(const Datum& in_values, const int64_t in_offset, uint8_t* out_valid, + uint8_t* out_values, const int64_t out_offset) { + if (in_values.is_array()) { + const ArrayData& array = *in_values.array(); + CopyOneArrayValue(*array.type, array.GetValues(0, 0), + array.GetValues(1, 0), array.offset + in_offset, + out_valid, out_values, out_offset); + } else { + CopyValues(in_values, in_offset, /*length=*/1, out_valid, out_values, + out_offset); + } +} + struct CaseWhenFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -1606,6 +1620,203 @@ struct CoalesceFunctor> { } }; +template +Status ExecScalarChoose(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& index_scalar = *batch[0].scalar(); + if (!index_scalar.is_valid) { + if (out->is_array()) { + auto source = MakeNullScalar(out->type()); + ArrayData* output = out->mutable_array(); + CopyValues(source, /*row=*/0, batch.length, + output->GetMutableValues(0, /*absolute_offset=*/0), + output->GetMutableValues(1, /*absolute_offset=*/0), + output->offset); + } + return Status::OK(); + } + auto index = UnboxScalar::Unbox(index_scalar); + if (index < 0 || static_cast(index + 1) >= batch.values.size()) { + return Status::IndexError("choose: index ", index, " out of range"); + } + auto source = batch.values[index + 1]; + if (out->is_scalar()) { + *out = source; + } else { + ArrayData* output = out->mutable_array(); + CopyValues(source, /*row=*/0, batch.length, + output->GetMutableValues(0, /*absolute_offset=*/0), + output->GetMutableValues(1, /*absolute_offset=*/0), + output->offset); + } + return Status::OK(); +} + +template +Status ExecArrayChoose(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ArrayData* output = out->mutable_array(); + const int64_t out_offset = output->offset; + // Need a null bitmap if any input has nulls + uint8_t* out_valid = nullptr; + if (std::any_of(batch.values.begin(), batch.values.end(), + [](const Datum& d) { return d.null_count() > 0; })) { + out_valid = output->buffers[0]->mutable_data(); + } else { + BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), out_offset, batch.length, + true); + } + uint8_t* out_values = output->buffers[1]->mutable_data(); + int64_t row = 0; + return VisitArrayValuesInline( + *batch[0].array(), + [&](int64_t index) { + if (index < 0 || static_cast(index + 1) >= batch.values.size()) { + return Status::IndexError("choose: index ", index, " out of range"); + } + const auto& source = batch.values[index + 1]; + CopyOneValue(source, row, out_valid, out_values, out_offset + row); + row++; + return Status::OK(); + }, + [&]() { + // Index is null, but we should still initialize the output with some value + const auto& source = batch.values[1]; + CopyOneValue(source, row, out_valid, out_values, out_offset + row); + BitUtil::ClearBit(out_valid, out_offset + row); + row++; + return Status::OK(); + }); +} + +template +struct ChooseFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch.values[0].is_scalar()) { + return ExecScalarChoose(ctx, batch, out); + } + return ExecArrayChoose(ctx, batch, out); + } +}; + +template <> +struct ChooseFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); + } +}; + +template +struct ChooseFunctor> { + using offset_type = typename Type::offset_type; + using BuilderType = typename TypeTraits::BuilderType; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch.values[0].is_scalar()) { + const auto& index_scalar = *batch[0].scalar(); + if (!index_scalar.is_valid) { + if (out->is_array()) { + auto null_scalar = MakeNullScalar(out->type()); + ARROW_ASSIGN_OR_RAISE( + auto temp_array, + MakeArrayFromScalar(*null_scalar, batch.length, ctx->memory_pool())); + *out->mutable_array() = *temp_array->data(); + } + return Status::OK(); + } + auto index = UnboxScalar::Unbox(index_scalar); + if (index < 0 || static_cast(index + 1) >= batch.values.size()) { + return Status::IndexError("choose: index ", index, " out of range"); + } + auto source = batch.values[index + 1]; + if (source.is_scalar() && out->is_array()) { + ARROW_ASSIGN_OR_RAISE( + auto temp_array, + MakeArrayFromScalar(*source.scalar(), batch.length, ctx->memory_pool())); + *out->mutable_array() = *temp_array->data(); + } else { + *out = source; + } + return Status::OK(); + } + + // Row-wise implementation + BuilderType builder(out->type(), ctx->memory_pool()); + RETURN_NOT_OK(builder.Reserve(batch.length)); + int64_t reserve_data = 0; + for (const auto& value : batch.values) { + if (value.is_scalar()) { + if (!value.scalar()->is_valid) continue; + const auto row_length = + checked_cast(*value.scalar()).value->size(); + reserve_data = std::max(reserve_data, batch.length * row_length); + } + const ArrayData& arr = *value.array(); + const offset_type* offsets = arr.GetValues(1); + const offset_type values_length = offsets[arr.length] - offsets[0]; + reserve_data = std::max(reserve_data, values_length); + } + RETURN_NOT_OK(builder.ReserveData(reserve_data)); + int64_t row = 0; + RETURN_NOT_OK(VisitArrayValuesInline( + *batch[0].array(), + [&](int64_t index) { + if (index < 0 || static_cast(index + 1) >= batch.values.size()) { + return Status::IndexError("choose: index ", index, " out of range"); + } + const auto& source = batch.values[index + 1]; + return CopyValue(source, &builder, row++); + }, + [&]() { + row++; + return builder.AppendNull(); + })); + auto actual_type = out->type(); + std::shared_ptr temp_output; + RETURN_NOT_OK(builder.Finish(&temp_output)); + ArrayData* output = out->mutable_array(); + *output = *temp_output->data(); + // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase + output->type = std::move(actual_type); + return Status::OK(); + } + + static Status CopyValue(const Datum& datum, BuilderType* builder, int64_t row) { + if (datum.is_scalar()) { + return builder->AppendScalar(*datum.scalar()); + } + const ArrayData& source = *datum.array(); + if (!source.MayHaveNulls() || + BitUtil::GetBit(source.buffers[0]->data(), source.offset + row)) { + const uint8_t* data = source.buffers[2]->data(); + const offset_type* offsets = source.GetValues(1); + const offset_type offset0 = offsets[row]; + const offset_type offset1 = offsets[row + 1]; + return builder->Append(data + offset0, offset1 - offset0); + } + return builder->AppendNull(); + } +}; + +struct ChooseFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + // The first argument is always int64 or promoted to it. The kernel is dispatched + // based on the type of the rest of the arguments. + RETURN_NOT_OK(CheckArity(*values)); + EnsureDictionaryDecoded(values); + if (values->front().type->id() != Type::INT64) { + values->front().type = int64(); + } + if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { + for (auto it = values->begin() + 1; it != values->end(); it++) { + it->type = type; + } + } + if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + Result LastType(KernelContext*, const std::vector& descrs) { ValueDescr result = descrs.back(); result.shape = GetBroadcastShape(descrs); @@ -1652,6 +1863,26 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr& scalar_f } } +void AddChooseKernel(const std::shared_ptr& scalar_function, + detail::GetTypeId get_id, ArrayKernelExec exec) { + ScalarKernel kernel( + KernelSignature::Make({Type::INT64, InputType(get_id.id)}, OutputType(LastType), + /*is_varargs=*/true), + exec); + kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; + kernel.mem_allocation = MemAllocation::PREALLOCATE; + kernel.can_write_into_slices = is_fixed_width(get_id.id); + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); +} + +void AddPrimitiveChooseKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { + for (auto&& type : types) { + auto exec = GenerateTypeAgnosticPrimitive(*type); + AddChooseKernel(scalar_function, type, std::move(exec)); + } +} + const FunctionDoc if_else_doc{"Choose values based on a condition", ("`cond` must be a Boolean scalar/ array. \n`left` or " "`right` must be of the same type scalar/ array.\n" @@ -1679,6 +1910,15 @@ const FunctionDoc coalesce_doc{ "for which the value is not null. If all inputs are null in a row, the output " "will be null."), {"*values"}}; + +const FunctionDoc choose_doc{ + "Given indices and arrays, choose the value from the corresponding array for each " + "index", + ("For each row, the value of the first argument is used as a 0-based index into the " + "rest of the arguments (i.e. index 0 selects the second argument). The output value " + "is the corresponding value of the selected argument.\n" + "If an index is null, the output will be null."), + {"indices", "*values"}}; } // namespace void RegisterScalarIfElse(FunctionRegistry* registry) { @@ -1723,6 +1963,22 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } + { + auto func = std::make_shared("choose", Arity::VarArgs(/*min_args=*/2), + &choose_doc); + AddPrimitiveChooseKernels(func, NumericTypes()); + AddPrimitiveChooseKernels(func, TemporalTypes()); + AddPrimitiveChooseKernels(func, + {boolean(), null(), day_time_interval(), month_interval()}); + AddChooseKernel(func, Type::FIXED_SIZE_BINARY, + ChooseFunctor::Exec); + AddChooseKernel(func, Type::DECIMAL128, ChooseFunctor::Exec); + AddChooseKernel(func, Type::DECIMAL256, ChooseFunctor::Exec); + for (const auto& ty : BaseBinaryTypes()) { + AddChooseKernel(func, ty, GenerateTypeAgnosticVarBinaryBase(ty)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); + } } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index a63492987eb..98137ac702a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -282,6 +282,37 @@ static void CoalesceNonNullBench64(benchmark::State& state) { return CoalesceBench(state); } +template +static void ChooseBench(benchmark::State& state) { + constexpr int kNumChoices = 5; + using CType = typename Type::c_type; + auto type = TypeTraits::type_singleton(); + + int64_t len = state.range(0); + int64_t offset = state.range(1); + + random::RandomArrayGenerator rand(/*seed=*/0); + + std::vector arguments; + arguments.emplace_back( + rand.Int64(len, /*min=*/0, /*max=*/kNumChoices - 1, /*null_probability=*/0.1) + ->Slice(offset)); + for (int i = 0; i < kNumChoices; i++) { + arguments.emplace_back( + rand.ArrayOf(type, len, /*null_probability=*/0.25)->Slice(offset)); + } + + for (auto _ : state) { + ABORT_NOT_OK(CallFunction("choose", arguments)); + } + + state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); +} + +static void ChooseBench64(benchmark::State& state) { + return ChooseBench(state); +} + BENCHMARK(IfElseBench32)->Args({elems, 0}); BENCHMARK(IfElseBench64)->Args({elems, 0}); @@ -312,5 +343,8 @@ BENCHMARK(CoalesceBench64)->Args({elems, 99}); BENCHMARK(CoalesceNonNullBench64)->Args({elems, 0}); BENCHMARK(CoalesceNonNullBench64)->Args({elems, 99}); +BENCHMARK(ChooseBench64)->Args({elems, 0}); +BENCHMARK(ChooseBench64)->Args({elems, 99}); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 48b0cdb457d..2c15dccfb3a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1040,5 +1040,192 @@ TEST(TestCoalesce, FixedSizeBinary) { ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])")); } +template +class TestChooseNumeric : public ::testing::Test {}; +template +class TestChooseBinary : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestChooseNumeric, NumericBasedTypes); +TYPED_TEST_SUITE(TestChooseBinary, BinaryTypes); + +TYPED_TEST(TestChooseNumeric, FixedSize) { + auto type = default_type_instance(); + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto values1 = ArrayFromJSON(type, "[10, 11, null, null, 14]"); + auto values2 = ArrayFromJSON(type, "[20, 21, null, null, 24]"); + auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); + CheckScalar("choose", {indices1, values1, values2}, + ArrayFromJSON(type, "[10, 21, null, null, null]")); + // Mixed scalar and array (note CheckScalar checks all-scalar cases for us) + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar1 = ScalarFromJSON(type, "42"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2}, + *MakeArrayFromScalar(*scalar1, 5)); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2}, + *MakeArrayOfNull(type, 5)); +} + +TYPED_TEST(TestChooseBinary, Basics) { + auto type = default_type_instance(); + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto values1 = ArrayFromJSON(type, R"(["a", "bc", null, null, "def"])"); + auto values2 = ArrayFromJSON(type, R"(["ghij", "klmno", null, null, "pqrstu"])"); + auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); + CheckScalar("choose", {indices1, values1, values2}, + ArrayFromJSON(type, R"(["a", "klmno", null, null, null])")); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar1 = ScalarFromJSON(type, R"("abcd")"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2}, + *MakeArrayFromScalar(*scalar1, 5)); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2}, + *MakeArrayOfNull(type, 5)); +} + +TEST(TestChoose, Null) { + auto type = null(); + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto nulls = *MakeArrayOfNull(type, 5); + CheckScalar("choose", {indices1, nulls, nulls}, nulls); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), nulls, nulls}, nulls); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), nulls, nulls}, nulls); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), nulls, nulls}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, nulls}, nulls); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar_null, nulls}, nulls); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), nulls, nulls}, nulls); +} + +TEST(TestChoose, Boolean) { + auto type = boolean(); + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto values1 = ArrayFromJSON(type, "[true, true, null, null, true]"); + auto values2 = ArrayFromJSON(type, "[false, false, null, null, false]"); + auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); + CheckScalar("choose", {indices1, values1, values2}, + ArrayFromJSON(type, "[true, false, null, null, null]")); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar1 = ScalarFromJSON(type, "true"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2}, + *MakeArrayFromScalar(*scalar1, 5)); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2}, + *MakeArrayOfNull(type, 5)); +} + +TEST(TestChoose, DayTimeInterval) { + auto type = day_time_interval(); + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto values1 = ArrayFromJSON(type, "[[10, 1], [10, 1], null, null, [10, 1]]"); + auto values2 = ArrayFromJSON(type, "[[2, 20], [2, 20], null, null, [2, 20]]"); + auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); + CheckScalar("choose", {indices1, values1, values2}, + ArrayFromJSON(type, "[[10, 1], [2, 20], null, null, null]")); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar1 = ScalarFromJSON(type, "[10, 1]"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2}, + *MakeArrayFromScalar(*scalar1, 5)); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2}, + *MakeArrayOfNull(type, 5)); +} + +TEST(TestChoose, Decimal) { + for (const auto& type : {decimal128(3, 2), decimal256(3, 2)}) { + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto values1 = ArrayFromJSON(type, R"(["1.23", "1.24", null, null, "1.25"])"); + auto values2 = ArrayFromJSON(type, R"(["4.56", "4.57", null, null, "4.58"])"); + auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); + CheckScalar("choose", {indices1, values1, values2}, + ArrayFromJSON(type, R"(["1.23", "4.57", null, null, null])")); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar1 = ScalarFromJSON(type, R"("1.23")"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2}, + *MakeArrayFromScalar(*scalar1, 5)); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2}, + *MakeArrayOfNull(type, 5)); + } +} + +TEST(TestChoose, FixedSizeBinary) { + auto type = fixed_size_binary(3); + auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); + auto values1 = ArrayFromJSON(type, R"(["abc", "abd", null, null, "abe"])"); + auto values2 = ArrayFromJSON(type, R"(["def", "deg", null, null, "deh"])"); + auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); + CheckScalar("choose", {indices1, values1, values2}, + ArrayFromJSON(type, R"(["abc", "deg", null, null, null])")); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar1 = ScalarFromJSON(type, R"("abc")"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar1, values2}, + *MakeArrayFromScalar(*scalar1, 5)); + CheckScalar("choose", {ScalarFromJSON(int64(), "1"), scalar1, values2}, values2); + CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); + auto scalar_null = ScalarFromJSON(type, "null"); + CheckScalar("choose", {ScalarFromJSON(int64(), "0"), scalar_null, values2}, + *MakeArrayOfNull(type, 5)); +} + +TEST(TestChooseKernel, DispatchBest) { + ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction("choose")); + auto Check = [&](std::vector original_values) { + auto values = original_values; + ARROW_EXPECT_OK(function->DispatchBest(&values)); + return values; + }; + + // Since DispatchBest for this kernel pulls tricks, we can't compare it to DispatchExact + // as CheckDispatchBest does + for (auto ty : + {int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(), uint64()}) { + // Index always promoted to int64 + EXPECT_EQ((std::vector{int64(), ty}), Check({ty, ty})); + EXPECT_EQ((std::vector{int64(), int64(), int64()}), + Check({ty, ty, int64()})); + } + // Other arguments promoted separately from index + EXPECT_EQ((std::vector{int64(), int32(), int32()}), + Check({int8(), int32(), uint8()})); +} + +TEST(TestChooseKernel, Errors) { + ASSERT_RAISES(Invalid, CallFunction("choose", {})); + ASSERT_RAISES(Invalid, CallFunction("choose", {ArrayFromJSON(int64(), "[]")})); + ASSERT_RAISES(Invalid, CallFunction("choose", {ArrayFromJSON(utf8(), "[\"a\"]"), + ArrayFromJSON(int64(), "[0]")})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + IndexError, ::testing::HasSubstr("choose: index 1 out of range"), + CallFunction("choose", + {ArrayFromJSON(int64(), "[1]"), ArrayFromJSON(int32(), "[0]")})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + IndexError, ::testing::HasSubstr("choose: index -1 out of range"), + CallFunction("choose", + {ArrayFromJSON(int64(), "[-1]"), ArrayFromJSON(int32(), "[0]")})); +} + } // namespace compute } // namespace arrow diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 01dc1d92e17..be9f7789dbe 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -891,25 +891,27 @@ Structural transforms +==========================+============+===================================================+=====================+=========+ | case_when | Varargs | Struct of Boolean (Arg 0), Any fixed-width (rest) | Input type | \(1) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| coalesce | Varargs | Any | Input type | \(2) | +| choose | Varargs | Integral (Arg 0); Fixed-width/Binary-like (rest) | Input type | \(2) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(3) | +| coalesce | Varargs | Any | Input type | \(3) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(4) | +| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(4) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| is_finite | Unary | Float, Double | Boolean | \(5) | +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(5) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| is_inf | Unary | Float, Double | Boolean | \(6) | +| is_finite | Unary | Float, Double | Boolean | \(6) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| is_nan | Unary | Float, Double | Boolean | \(7) | +| is_inf | Unary | Float, Double | Boolean | \(7) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| is_null | Unary | Any | Boolean | \(8) | +| is_nan | Unary | Float, Double | Boolean | \(8) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| is_valid | Unary | Any | Boolean | \(9) | +| is_null | Unary | Any | Boolean | \(9) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| list_value_length | Unary | List-like | Int32 or Int64 | \(10) | +| is_valid | Unary | Any | Boolean | \(10) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ -| make_struct | Varargs | Any | Struct | \(11) | +| list_value_length | Unary | List-like | Int32 or Int64 | \(11) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| make_struct | Varargs | Any | Struct | \(12) | +--------------------------+------------+---------------------------------------------------+---------------------+---------+ * \(1) This function acts like a SQL 'case when' statement or switch-case. The @@ -921,14 +923,21 @@ Structural transforms the first value datum for which the corresponding Boolean is true, or the corresponding value from the 'default' input, or null otherwise. -* \(2) Each row of the output will be the corresponding value of the first +* \(2) The first input must be an integral type. The rest of the arguments can be + any type, but must all be the same type or promotable to a common type. Each + value of the first input (the 'index') is used as a zero-based index into the + remaining arguments (i.e. index 0 is the second argument, index 1 is the third + argument, etc.), and the value of the output for that row will be the + corresponding value of the selected input at that row. + +* \(3) Each row of the output will be the corresponding value of the first input which is non-null for that row, otherwise null. -* \(3) First input must be an array, second input a scalar of the same type. +* \(4) First input must be an array, second input a scalar of the same type. Output is an array of the same type as the inputs, and with the same values as the first input, except for nulls replaced with the second input value. -* \(4) First input must be a Boolean scalar or array. Second and third inputs +* \(5) First input must be a Boolean scalar or array. Second and third inputs could be scalars or arrays and must be of the same type. Output is an array (or scalar if all inputs are scalar) of the same type as the second/ third input. If the nulls present on the first input, they will be promoted to the @@ -936,21 +945,21 @@ Structural transforms Also see: :ref:`replace_with_mask `. -* \(5) Output is true iff the corresponding input element is finite (not Infinity, +* \(6) Output is true iff the corresponding input element is finite (not Infinity, -Infinity, or NaN). -* \(6) Output is true iff the corresponding input element is Infinity/-Infinity. +* \(7) Output is true iff the corresponding input element is Infinity/-Infinity. -* \(7) Output is true iff the corresponding input element is NaN. +* \(8) Output is true iff the corresponding input element is NaN. -* \(8) Output is true iff the corresponding input element is null. +* \(9) Output is true iff the corresponding input element is null. -* \(9) Output is true iff the corresponding input element is non-null. +* \(10) Output is true iff the corresponding input element is non-null. -* \(10) Each output element is the length of the corresponding input element +* \(11) Each output element is the length of the corresponding input element (null if input is null). Output type is Int32 for List, Int64 for LargeList. -* \(11) The output struct's field types are the types of its arguments. The +* \(12) The output struct's field types are the types of its arguments. The field names are specified using an instance of :struct:`MakeStructOptions`. The output shape will be scalar if all inputs are scalar, otherwise any scalars will be broadcast to arrays. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index c503cba319c..790b9ba5d23 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -352,6 +352,7 @@ Structural Transforms binary_length case_when + choose coalesce fill_null if_else From a024bb0186e0b4c49b00ea10f45ee557c6731eb4 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 27 Jul 2021 10:43:37 -0400 Subject: [PATCH 2/3] ARROW-13220: [C++] Address PR feedback --- .../arrow/compute/kernels/codegen_internal.cc | 9 ++++ .../arrow/compute/kernels/codegen_internal.h | 3 ++ .../arrow/compute/kernels/scalar_if_else.cc | 35 +++++++++----- .../kernels/scalar_if_else_benchmark.cc | 46 +++++++++---------- .../compute/kernels/scalar_if_else_test.cc | 13 ++++++ .../arrow/compute/kernels/vector_replace.cc | 5 +- docs/source/cpp/compute.rst | 3 +- 7 files changed, 76 insertions(+), 38 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index bab8e7000cd..f8b90085010 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -47,6 +47,7 @@ std::vector> g_floating_types; std::vector> g_numeric_types; std::vector> g_base_binary_types; std::vector> g_temporal_types; +std::vector> g_interval_types; std::vector> g_primitive_types; std::vector g_decimal_type_ids; static std::once_flag codegen_static_initialized; @@ -91,6 +92,9 @@ static void InitStaticData() { timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)}; + // Interval types + g_interval_types = {day_time_interval(), month_interval()}; + // Base binary types (without FixedSizeBinary) g_base_binary_types = {binary(), utf8(), large_binary(), large_utf8()}; @@ -157,6 +161,11 @@ const std::vector>& TemporalTypes() { return g_temporal_types; } +const std::vector>& IntervalTypes() { + std::call_once(codegen_static_initialized, InitStaticData); + return g_interval_types; +} + const std::vector>& PrimitiveTypes() { std::call_once(codegen_static_initialized, InitStaticData); return g_primitive_types; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index f432c93daac..9c8b2cef198 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -442,6 +442,9 @@ const std::vector>& NumericTypes(); // Temporal types including time and timestamps for each unit const std::vector>& TemporalTypes(); +// Interval types +const std::vector>& IntervalTypes(); + // Integer, floating point, base binary, and temporal const std::vector>& PrimitiveTypes(); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index e4bd660d3bd..35055609956 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1182,6 +1182,15 @@ void CopyOneArrayValue(const DataType& type, const uint8_t* in_valid, out_offset); } +template +void CopyOneScalarValue(const Scalar& scalar, uint8_t* out_valid, uint8_t* out_values, + const int64_t out_offset) { + if (out_valid) { + BitUtil::SetBitTo(out_valid, out_offset, scalar.is_valid); + } + CopyFixedWidth::CopyScalar(scalar, /*length=*/1, out_values, out_offset); +} + template void CopyOneValue(const Datum& in_values, const int64_t in_offset, uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) { @@ -1191,8 +1200,7 @@ void CopyOneValue(const Datum& in_values, const int64_t in_offset, uint8_t* out_ array.GetValues(1, 0), array.offset + in_offset, out_valid, out_values, out_offset); } else { - CopyValues(in_values, in_offset, /*length=*/1, out_valid, out_values, - out_offset); + CopyOneScalarValue(*in_values.scalar(), out_valid, out_values, out_offset); } } @@ -1714,10 +1722,9 @@ struct ChooseFunctor> { const auto& index_scalar = *batch[0].scalar(); if (!index_scalar.is_valid) { if (out->is_array()) { - auto null_scalar = MakeNullScalar(out->type()); ARROW_ASSIGN_OR_RAISE( auto temp_array, - MakeArrayFromScalar(*null_scalar, batch.length, ctx->memory_pool())); + MakeArrayOfNull(out->type(), batch.length, ctx->memory_pool())); *out->mutable_array() = *temp_array->data(); } return Status::OK(); @@ -1748,6 +1755,7 @@ struct ChooseFunctor> { const auto row_length = checked_cast(*value.scalar()).value->size(); reserve_data = std::max(reserve_data, batch.length * row_length); + continue; } const ArrayData& arr = *value.array(); const offset_type* offsets = arr.GetValues(1); @@ -1781,7 +1789,9 @@ struct ChooseFunctor> { static Status CopyValue(const Datum& datum, BuilderType* builder, int64_t row) { if (datum.is_scalar()) { - return builder->AppendScalar(*datum.scalar()); + const auto& scalar = checked_cast(*datum.scalar()); + if (!scalar.value) return builder->AppendNull(); + return builder->Append(scalar.value->data(), scalar.value->size()); } const ArrayData& source = *datum.array(); if (!source.MayHaveNulls() || @@ -1928,7 +1938,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveIfElseKernels(func, NumericTypes()); AddPrimitiveIfElseKernels(func, TemporalTypes()); - AddPrimitiveIfElseKernels(func, {boolean(), day_time_interval(), month_interval()}); + AddPrimitiveIfElseKernels(func, IntervalTypes()); + AddPrimitiveIfElseKernels(func, {boolean()}); AddNullIfElseKernel(func); AddBinaryIfElseKernels(func, BaseBinaryTypes()); AddFSBinaryIfElseKernel(func); @@ -1939,8 +1950,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc); AddPrimitiveCaseWhenKernels(func, NumericTypes()); AddPrimitiveCaseWhenKernels(func, TemporalTypes()); - AddPrimitiveCaseWhenKernels( - func, {boolean(), null(), day_time_interval(), month_interval()}); + AddPrimitiveCaseWhenKernels(func, IntervalTypes()); + AddPrimitiveCaseWhenKernels(func, {boolean(), null()}); AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor::Exec); @@ -1952,8 +1963,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { "coalesce", Arity::VarArgs(/*min_args=*/1), &coalesce_doc); AddPrimitiveCoalesceKernels(func, NumericTypes()); AddPrimitiveCoalesceKernels(func, TemporalTypes()); - AddPrimitiveCoalesceKernels( - func, {boolean(), null(), day_time_interval(), month_interval()}); + AddPrimitiveCoalesceKernels(func, IntervalTypes()); + AddPrimitiveCoalesceKernels(func, {boolean(), null()}); AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY, CoalesceFunctor::Exec); AddCoalesceKernel(func, Type::DECIMAL128, CoalesceFunctor::Exec); @@ -1968,8 +1979,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { &choose_doc); AddPrimitiveChooseKernels(func, NumericTypes()); AddPrimitiveChooseKernels(func, TemporalTypes()); - AddPrimitiveChooseKernels(func, - {boolean(), null(), day_time_interval(), month_interval()}); + AddPrimitiveChooseKernels(func, IntervalTypes()); + AddPrimitiveChooseKernels(func, {boolean(), null()}); AddChooseKernel(func, Type::FIXED_SIZE_BINARY, ChooseFunctor::Exec); AddChooseKernel(func, Type::DECIMAL128, ChooseFunctor::Exec); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 98137ac702a..9b59d54c3da 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -26,7 +26,7 @@ namespace arrow { namespace compute { -const int64_t elems = 1024 * 1024; +const int64_t kNumItems = 1024 * 1024; template struct SetBytesProcessed {}; @@ -313,38 +313,38 @@ static void ChooseBench64(benchmark::State& state) { return ChooseBench(state); } -BENCHMARK(IfElseBench32)->Args({elems, 0}); -BENCHMARK(IfElseBench64)->Args({elems, 0}); +BENCHMARK(IfElseBench32)->Args({kNumItems, 0}); +BENCHMARK(IfElseBench64)->Args({kNumItems, 0}); -BENCHMARK(IfElseBench32)->Args({elems, 99}); -BENCHMARK(IfElseBench64)->Args({elems, 99}); +BENCHMARK(IfElseBench32)->Args({kNumItems, 99}); +BENCHMARK(IfElseBench64)->Args({kNumItems, 99}); -BENCHMARK(IfElseBench32Contiguous)->Args({elems, 0}); -BENCHMARK(IfElseBench64Contiguous)->Args({elems, 0}); +BENCHMARK(IfElseBench32Contiguous)->Args({kNumItems, 0}); +BENCHMARK(IfElseBench64Contiguous)->Args({kNumItems, 0}); -BENCHMARK(IfElseBench32Contiguous)->Args({elems, 99}); -BENCHMARK(IfElseBench64Contiguous)->Args({elems, 99}); +BENCHMARK(IfElseBench32Contiguous)->Args({kNumItems, 99}); +BENCHMARK(IfElseBench64Contiguous)->Args({kNumItems, 99}); -BENCHMARK(IfElseBenchString32)->Args({elems, 0}); -BENCHMARK(IfElseBenchString64)->Args({elems, 0}); +BENCHMARK(IfElseBenchString32)->Args({kNumItems, 0}); +BENCHMARK(IfElseBenchString64)->Args({kNumItems, 0}); -BENCHMARK(IfElseBenchString32Contiguous)->Args({elems, 99}); -BENCHMARK(IfElseBenchString64Contiguous)->Args({elems, 99}); +BENCHMARK(IfElseBenchString32Contiguous)->Args({kNumItems, 99}); +BENCHMARK(IfElseBenchString64Contiguous)->Args({kNumItems, 99}); -BENCHMARK(CaseWhenBench64)->Args({elems, 0}); -BENCHMARK(CaseWhenBench64)->Args({elems, 99}); +BENCHMARK(CaseWhenBench64)->Args({kNumItems, 0}); +BENCHMARK(CaseWhenBench64)->Args({kNumItems, 99}); -BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0}); -BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99}); +BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 0}); +BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 99}); -BENCHMARK(CoalesceBench64)->Args({elems, 0}); -BENCHMARK(CoalesceBench64)->Args({elems, 99}); +BENCHMARK(CoalesceBench64)->Args({kNumItems, 0}); +BENCHMARK(CoalesceBench64)->Args({kNumItems, 99}); -BENCHMARK(CoalesceNonNullBench64)->Args({elems, 0}); -BENCHMARK(CoalesceNonNullBench64)->Args({elems, 99}); +BENCHMARK(CoalesceNonNullBench64)->Args({kNumItems, 0}); +BENCHMARK(CoalesceNonNullBench64)->Args({kNumItems, 99}); -BENCHMARK(ChooseBench64)->Args({elems, 0}); -BENCHMARK(ChooseBench64)->Args({elems, 99}); +BENCHMARK(ChooseBench64)->Args({kNumItems, 0}); +BENCHMARK(ChooseBench64)->Args({kNumItems, 99}); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 2c15dccfb3a..f06a6822a0f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -1056,6 +1056,8 @@ TYPED_TEST(TestChooseNumeric, FixedSize) { auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); CheckScalar("choose", {indices1, values1, values2}, ArrayFromJSON(type, "[10, 21, null, null, null]")); + CheckScalar("choose", {indices1, ScalarFromJSON(type, "1"), values1}, + ArrayFromJSON(type, "[1, 11, 1, null, null]")); // Mixed scalar and array (note CheckScalar checks all-scalar cases for us) CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); @@ -1078,6 +1080,8 @@ TYPED_TEST(TestChooseBinary, Basics) { auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); CheckScalar("choose", {indices1, values1, values2}, ArrayFromJSON(type, R"(["a", "klmno", null, null, null])")); + CheckScalar("choose", {indices1, ScalarFromJSON(type, R"("foo")"), values1}, + ArrayFromJSON(type, R"(["foo", "bc", "foo", null, null])")); CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); @@ -1096,6 +1100,7 @@ TEST(TestChoose, Null) { auto indices1 = ArrayFromJSON(int64(), "[0, 1, 0, 1, null]"); auto nulls = *MakeArrayOfNull(type, 5); CheckScalar("choose", {indices1, nulls, nulls}, nulls); + CheckScalar("choose", {indices1, MakeNullScalar(type), nulls}, nulls); CheckScalar("choose", {ScalarFromJSON(int64(), "0"), nulls, nulls}, nulls); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), nulls, nulls}, nulls); CheckScalar("choose", {ScalarFromJSON(int64(), "null"), nulls, nulls}, nulls); @@ -1113,6 +1118,8 @@ TEST(TestChoose, Boolean) { auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); CheckScalar("choose", {indices1, values1, values2}, ArrayFromJSON(type, "[true, false, null, null, null]")); + CheckScalar("choose", {indices1, ScalarFromJSON(type, "false"), values1}, + ArrayFromJSON(type, "[false, true, false, null, null]")); CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); @@ -1134,6 +1141,8 @@ TEST(TestChoose, DayTimeInterval) { auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); CheckScalar("choose", {indices1, values1, values2}, ArrayFromJSON(type, "[[10, 1], [2, 20], null, null, null]")); + CheckScalar("choose", {indices1, ScalarFromJSON(type, "[1, 2]"), values1}, + ArrayFromJSON(type, "[[1, 2], [10, 1], [1, 2], null, null]")); CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); @@ -1155,6 +1164,8 @@ TEST(TestChoose, Decimal) { auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); CheckScalar("choose", {indices1, values1, values2}, ArrayFromJSON(type, R"(["1.23", "4.57", null, null, null])")); + CheckScalar("choose", {indices1, ScalarFromJSON(type, R"("2.34")"), values1}, + ArrayFromJSON(type, R"(["2.34", "1.24", "2.34", null, null])")); CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); @@ -1177,6 +1188,8 @@ TEST(TestChoose, FixedSizeBinary) { auto nulls = ArrayFromJSON(type, "[null, null, null, null, null]"); CheckScalar("choose", {indices1, values1, values2}, ArrayFromJSON(type, R"(["abc", "deg", null, null, null])")); + CheckScalar("choose", {indices1, ScalarFromJSON(type, R"("xyz")"), values1}, + ArrayFromJSON(type, R"(["xyz", "abd", "xyz", null, null])")); CheckScalar("choose", {ScalarFromJSON(int64(), "0"), values1, values2}, values1); CheckScalar("choose", {ScalarFromJSON(int64(), "1"), values1, values2}, values2); CheckScalar("choose", {ScalarFromJSON(int64(), "null"), values1, values2}, nulls); diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 644aec2a4e9..450f99d7826 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -520,10 +520,11 @@ void RegisterVectorReplace(FunctionRegistry* registry) { for (const auto& ty : TemporalTypes()) { add_primitive_kernel(ty); } + for (const auto& ty : IntervalTypes()) { + add_primitive_kernel(ty); + } add_primitive_kernel(null()); add_primitive_kernel(boolean()); - add_primitive_kernel(day_time_interval()); - add_primitive_kernel(month_interval()); add_kernel(Type::FIXED_SIZE_BINARY, ReplaceWithMaskFunctor::Exec); add_kernel(Type::DECIMAL128, ReplaceWithMaskFunctor::Exec); add_kernel(Type::DECIMAL256, ReplaceWithMaskFunctor::Exec); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index be9f7789dbe..46523cad3dd 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -928,7 +928,8 @@ Structural transforms value of the first input (the 'index') is used as a zero-based index into the remaining arguments (i.e. index 0 is the second argument, index 1 is the third argument, etc.), and the value of the output for that row will be the - corresponding value of the selected input at that row. + corresponding value of the selected input at that row. If the index is null, + then the output will also be null. * \(3) Each row of the output will be the corresponding value of the first input which is non-null for that row, otherwise null. From d4d1ecf966d2821215e014581f91ae216c7ffb38 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 27 Jul 2021 11:03:42 -0400 Subject: [PATCH 3/3] ARROW-13220: [C++] Add cast --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 35055609956..cb261ec59a7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1791,7 +1791,8 @@ struct ChooseFunctor> { if (datum.is_scalar()) { const auto& scalar = checked_cast(*datum.scalar()); if (!scalar.value) return builder->AppendNull(); - return builder->Append(scalar.value->data(), scalar.value->size()); + return builder->Append(scalar.value->data(), + static_cast(scalar.value->size())); } const ArrayData& source = *datum.array(); if (!source.MayHaveNulls() ||