From b462b8c29942f7d34ec69f121f3f2e5b3e0407dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Saint-Jacques?= Date: Fri, 1 Mar 2019 10:00:57 -0500 Subject: [PATCH] ARROW-3123: [C++] Implement Count aggregate kernel --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/kernel.h | 13 ++ .../arrow/compute/kernels/aggregate-test.cc | 75 ++++++++++++ cpp/src/arrow/compute/kernels/count.cc | 115 ++++++++++++++++++ cpp/src/arrow/compute/kernels/count.h | 88 ++++++++++++++ 5 files changed, 292 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/count.cc create mode 100644 cpp/src/arrow/compute/kernels/count.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 06c1b335e8a..2c3f00d7322 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -146,6 +146,7 @@ if(ARROW_COMPUTE) compute/kernels/aggregate.cc compute/kernels/boolean.cc compute/kernels/cast.cc + compute/kernels/count.cc compute/kernels/hash.cc compute/kernels/mean.cc compute/kernels/sum.cc diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 58d7dc0323a..387715d3277 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -95,6 +95,19 @@ struct ARROW_EXPORT Datum { Datum(const std::shared_ptr& value) // NOLINT implicit conversion : Datum(std::shared_ptr(value)) {} + // Convenience constructors + explicit Datum(bool value) : value(std::make_shared(value)) {} + explicit Datum(int8_t value) : value(std::make_shared(value)) {} + explicit Datum(uint8_t value) : value(std::make_shared(value)) {} + explicit Datum(int16_t value) : value(std::make_shared(value)) {} + explicit Datum(uint16_t value) : value(std::make_shared(value)) {} + explicit Datum(int32_t value) : value(std::make_shared(value)) {} + explicit Datum(uint32_t value) : value(std::make_shared(value)) {} + explicit Datum(int64_t value) : value(std::make_shared(value)) {} + explicit Datum(uint64_t value) : value(std::make_shared(value)) {} + explicit Datum(float value) : value(std::make_shared(value)) {} + explicit Datum(double value) : value(std::make_shared(value)) {} + ~Datum() {} Datum(const Datum& other) noexcept { this->value = other.value; } diff --git a/cpp/src/arrow/compute/kernels/aggregate-test.cc b/cpp/src/arrow/compute/kernels/aggregate-test.cc index bdf50f5ac7b..cbe91a29607 100644 --- a/cpp/src/arrow/compute/kernels/aggregate-test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate-test.cc @@ -25,6 +25,7 @@ #include "arrow/array.h" #include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/count.h" #include "arrow/compute/kernels/mean.h" #include "arrow/compute/kernels/sum-internal.h" #include "arrow/compute/kernels/sum.h" @@ -43,6 +44,10 @@ using std::vector; namespace arrow { namespace compute { +/// +/// Sum +/// + template using SumResult = std::pair::Type::c_type, size_t>; @@ -162,6 +167,10 @@ TYPED_TEST(TestRandomNumericSumKernel, RandomSliceArraySum) { } } +/// +/// Mean +/// + template static Datum NaiveMean(const Array& array) { using MeanScalarType = typename TypeTraits::ScalarType; @@ -237,5 +246,71 @@ TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMean) { } } +/// +/// Count +/// +// +using CountPair = std::pair; + +static CountPair NaiveCount(const Array& array) { + CountPair count; + + count.first = array.length() - array.null_count(); + count.second = array.null_count(); + + return count; +} + +void ValidateCount(FunctionContext* ctx, const Array& input, CountPair expected) { + CountOptions all = CountOptions(CountOptions::COUNT_ALL); + CountOptions nulls = CountOptions(CountOptions::COUNT_NULL); + Datum result; + + ASSERT_OK(Count(ctx, all, input, &result)); + DatumEqual::EnsureEqual(result, Datum(expected.first)); + + ASSERT_OK(Count(ctx, nulls, input, &result)); + DatumEqual::EnsureEqual(result, Datum(expected.second)); +} + +template +void ValidateCount(FunctionContext* ctx, const char* json, CountPair expected) { + auto array = ArrayFromJSON(TypeTraits::type_singleton(), json); + ValidateCount(ctx, *array, expected); +} + +void ValidateCount(FunctionContext* ctx, const Array& input) { + ValidateCount(ctx, input, NaiveCount(input)); +} + +template +class TestCountKernel : public ComputeFixture, public TestBase {}; + +TYPED_TEST_CASE(TestCountKernel, NumericArrowTypes); +TYPED_TEST(TestCountKernel, SimpleCount) { + ValidateCount(&this->ctx_, "[]", {0, 0}); + ValidateCount(&this->ctx_, "[null]", {0, 1}); + ValidateCount(&this->ctx_, "[1, null, 2]", {2, 1}); + ValidateCount(&this->ctx_, "[null, null, null]", {0, 3}); + ValidateCount(&this->ctx_, "[1, 2, 3, 4, 5, 6, 7, 8, 9]", {9, 0}); +} + +template +class TestRandomNumericCountKernel : public ComputeFixture, public TestBase {}; + +TYPED_TEST_CASE(TestRandomNumericCountKernel, NumericArrowTypes); +TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { + auto rand = random::RandomArrayGenerator(0x1205643); + for (size_t i = 3; i < 14; i++) { + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto length_adjust : {-2, -1, 0, 1, 2}) { + int64_t length = (1UL << i) + length_adjust; + auto array = rand.Numeric(length, 0, 100, null_probability); + ValidateCount(&this->ctx_, *array); + } + } + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/count.cc b/cpp/src/arrow/compute/kernels/count.cc new file mode 100644 index 00000000000..44ba9d52299 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/count.cc @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/count.h" + +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/aggregate.h" + +namespace arrow { +namespace compute { + +struct CountState { + CountState() : non_nulls(0), nulls(0) {} + CountState(int64_t non_nulls, int64_t nulls) : non_nulls(non_nulls), nulls(nulls) {} + + CountState operator+(const CountState& rhs) const { + return CountState(this->non_nulls + rhs.non_nulls, this->nulls + rhs.nulls); + } + + CountState& operator+=(const CountState& rhs) { + this->non_nulls += rhs.non_nulls; + this->nulls += rhs.nulls; + return *this; + } + + std::shared_ptr NonNullsAsScalar() const { + using ScalarType = typename CTypeTraits::ScalarType; + return std::make_shared(non_nulls); + } + + std::shared_ptr NullsAsScalar() const { + using ScalarType = typename CTypeTraits::ScalarType; + return std::make_shared(nulls); + } + + int64_t non_nulls = 0; + int64_t nulls = 0; +}; + +class CountAggregateFunction final : public AggregateFunctionStaticState { + public: + explicit CountAggregateFunction(const CountOptions& options) : options_(options) {} + + Status Consume(const Array& input, CountState* state) const override { + const int64_t length = input.length(); + const int64_t nulls = input.null_count(); + + state->nulls = nulls; + state->non_nulls = length - nulls; + + return Status::OK(); + } + + Status Merge(const CountState& src, CountState* dst) const override { + *dst += src; + return Status::OK(); + } + + Status Finalize(const CountState& src, Datum* output) const override { + switch (options_.count_mode) { + case CountOptions::COUNT_ALL: + *output = src.NonNullsAsScalar(); + break; + case CountOptions::COUNT_NULL: + *output = src.NullsAsScalar(); + break; + default: + return Status::Invalid("Unknown CountOptions encountered"); + } + + return Status::OK(); + } + + std::shared_ptr out_type() const override { return int64(); } + + private: + CountOptions options_; +}; + +std::shared_ptr MakeCountAggregateFunction( + FunctionContext* context, const CountOptions& options) { + return std::make_shared(options); +} + +Status Count(FunctionContext* context, const CountOptions& options, const Datum& value, + Datum* out) { + if (!value.is_array()) return Status::Invalid("Count is expecting an array datum."); + + auto aggregate = MakeCountAggregateFunction(context, options); + auto kernel = std::make_shared(aggregate); + + return kernel->Call(context, value, out); +} + +Status Count(FunctionContext* context, const CountOptions& options, const Array& array, + Datum* out) { + return Count(context, options, array.data(), out); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/count.h b/cpp/src/arrow/compute/kernels/count.h new file mode 100644 index 00000000000..c33ac48665a --- /dev/null +++ b/cpp/src/arrow/compute/kernels/count.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; +class DataType; + +namespace compute { + +struct Datum; +class FunctionContext; +class AggregateFunction; + +/// \class CountOptions +/// +/// The user control the Count kernel behavior with this class. By default, the +/// it will count all non-null values. +struct ARROW_EXPORT CountOptions { + enum mode { + // Count all non-null values. + COUNT_ALL = 0, + // Count all null values. + COUNT_NULL, + }; + + explicit CountOptions(enum mode count_mode) : count_mode(count_mode) {} + + enum mode count_mode = COUNT_ALL; +}; + +/// \brief Return Count function aggregate +ARROW_EXPORT +std::shared_ptr MakeCount(FunctionContext* context, + const CountOptions& options); + +/// \brief Count non-null (or null) values in an array. +/// +/// \param[in] context the FunctionContext +/// \param[in] options counting options, see CountOptions for more information +/// \param[in] datum to count +/// \param[out] out resulting datum +/// +/// \since 0.13.0 +/// \note API not yet finalized +ARROW_EXPORT +Status Count(FunctionContext* context, const CountOptions& options, const Datum& datum, + Datum* out); + +/// \brief Count non-null (or null) values in an array. +/// +/// \param[in] context the FunctionContext +/// \param[in] options counting options, see CountOptions for more information +/// \param[in] array to count +/// \param[out] out resulting datum +/// +/// \since 0.13.0 +/// \note API not yet finalized +ARROW_EXPORT +Status Count(FunctionContext* context, const CountOptions& options, const Array& array, + Datum* out); + +} // namespace compute +} // namespace arrow