From e53c0b8c34f39e1feef14bdc519d486b352a3aab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Percy=20Camilo=20Trive=C3=B1o=20Aucahuasi?= Date: Mon, 27 Sep 2021 19:43:04 -0500 Subject: [PATCH 1/3] ARROW-14035: Implement count distinct kernel finish basic implementation & add tests and docs add count_distinct_doc add R binding for count_distinct kernel support more types & add more tests minor change reduce code use type matchers to reduce even more code minor change Update docs/source/cpp/compute.rst Co-authored-by: Ian Cook minor changes add tests for scalar types and fix an issue for scalars --- cpp/src/arrow/compute/kernel.cc | 22 +++ cpp/src/arrow/compute/kernel.h | 3 + .../arrow/compute/kernels/aggregate_basic.cc | 145 +++++++++++++++++ .../arrow/compute/kernels/aggregate_test.cc | 147 ++++++++++++++++++ docs/source/cpp/compute.rst | 2 + docs/source/python/api/compute.rst | 1 + python/pyarrow/tests/test_compute.py | 17 ++ r/src/compute.cpp | 2 +- r/tests/testthat/test-dplyr-summarize.R | 16 +- 9 files changed, 353 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 5d47175f260..666b73e415c 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -249,10 +249,32 @@ class LargeBinaryLikeMatcher : public TypeMatcher { std::string ToString() const override { return "large-binary-like"; } }; +class FixedSizeBinaryLikeMatcher : public TypeMatcher { + public: + FixedSizeBinaryLikeMatcher() {} + + bool Matches(const DataType& type) const override { + return is_fixed_size_binary(type.id()); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + return casted != nullptr; + } + std::string ToString() const override { return "fixed-size-binary-like"; } +}; + std::shared_ptr LargeBinaryLike() { return std::make_shared(); } +std::shared_ptr FixedSizeBinaryLike() { + return std::make_shared(); +} + } // namespace match // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 01750d1f359..27fb831636c 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -134,6 +134,9 @@ ARROW_EXPORT std::shared_ptr BinaryLike(); // Match types using 64-bit varbinary representation ARROW_EXPORT std::shared_ptr LargeBinaryLike(); +// Match any fixed binary type +ARROW_EXPORT std::shared_ptr FixedSizeBinaryLike(); + // \brief Match any primitive type (boolean or any type representable as a C // Type) ARROW_EXPORT std::shared_ptr Primitive(); diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 08efa9eec66..38b91251dd8 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -21,6 +21,7 @@ #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/util/cpu_info.h" +#include "arrow/util/hashing.h" #include "arrow/util/make_unique.h" namespace arrow { @@ -121,6 +122,138 @@ Result> CountInit(KernelContext*, static_cast(*args.options)); } +// ---------------------------------------------------------------------- +// Distinct Count implementation + +template +struct CountDistinctImpl : public ScalarAggregator { + using MemoTable = typename arrow::internal::HashTraits::MemoTableType; + + explicit CountDistinctImpl(MemoryPool* memory_pool, CountOptions options) + : options(std::move(options)), memo_table_(new MemoTable(memory_pool, 0)) {} + + Status Consume(KernelContext*, const ExecBatch& batch) override { + if (batch[0].is_array()) { + const ArrayData& arr = *batch[0].array(); + auto visit_null = []() { return Status::OK(); }; + auto visit_value = [&](VisitorArgType arg) { + int y; + return memo_table_->GetOrInsert(arg, &y); + }; + RETURN_NOT_OK(VisitArrayDataInline(arr, visit_value, visit_null)); + this->non_nulls += memo_table_->size(); + this->has_nulls = arr.GetNullCount() > 0; + } else { + const Scalar& input = *batch[0].scalar(); + this->has_nulls = !input.is_valid; + if (input.is_valid) { + this->non_nulls += batch.length; + } + } + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + const auto& other_state = checked_cast(src); + this->non_nulls += other_state.non_nulls; + this->has_nulls = this->has_nulls || other_state.has_nulls; + return Status::OK(); + } + + Status Finalize(KernelContext* ctx, Datum* out) override { + const auto& state = checked_cast(*ctx->state()); + const int64_t nulls = state.has_nulls ? 1 : 0; + switch (state.options.mode) { + case CountOptions::ONLY_VALID: + *out = Datum(state.non_nulls); + break; + case CountOptions::ALL: + *out = Datum(state.non_nulls + nulls); + break; + case CountOptions::ONLY_NULL: + *out = Datum(nulls); + break; + default: + DCHECK(false) << "unreachable"; + } + return Status::OK(); + } + + const CountOptions options; + int64_t non_nulls = 0; + bool has_nulls = false; + std::unique_ptr memo_table_; +}; + +template +Result> CountDistinctInit(KernelContext* ctx, + const KernelInitArgs& args) { + return ::arrow::internal::make_unique>( + ctx->memory_pool(), static_cast(*args.options)); +} + +template +void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) { + AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())), + aggregate::CountDistinctInit, func); +} + +template +struct CountDistinctKernel { + static void Add(InputType type, ScalarAggregateFunction* func) { + using PhysicalType = typename Type::PhysicalType; + AddCountDistinctKernel(type, func); + } +}; + +template <> +struct CountDistinctKernel { + static void Add(InputType type, ScalarAggregateFunction* func) { + AddCountDistinctKernel(type, func); + } +}; + +void AddCountDistinctKernels(ScalarAggregateFunction* func) { + // Boolean + aggregate::CountDistinctKernel::Add(boolean(), func); + // Number + aggregate::CountDistinctKernel::Add(int8(), func); + aggregate::CountDistinctKernel::Add(int16(), func); + aggregate::CountDistinctKernel::Add(int32(), func); + aggregate::CountDistinctKernel::Add(int64(), func); + aggregate::CountDistinctKernel::Add(uint8(), func); + aggregate::CountDistinctKernel::Add(uint16(), func); + aggregate::CountDistinctKernel::Add(uint32(), func); + aggregate::CountDistinctKernel::Add(uint64(), func); + aggregate::CountDistinctKernel::Add(float16(), func); + aggregate::CountDistinctKernel::Add(float32(), func); + aggregate::CountDistinctKernel::Add(float64(), func); + // Date + aggregate::CountDistinctKernel::Add(date32(), func); + aggregate::CountDistinctKernel::Add(date64(), func); + // Time + aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::TIME32), func); + aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::TIME64), func); + // Timestamp & Duration + aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::TIMESTAMP), + func); + aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::DURATION), + func); + // Interval + aggregate::CountDistinctKernel::Add(month_interval(), func); + aggregate::CountDistinctKernel::Add(day_time_interval(), func); + aggregate::CountDistinctKernel::Add(month_day_nano_interval(), + func); + // Binary & String + aggregate::CountDistinctKernel::Add(match::BinaryLike(), + func); + aggregate::CountDistinctKernel::Add( + match::LargeBinaryLike(), func); + // Fixed binary & Decimal + aggregate::CountDistinctKernel::Add( + match::FixedSizeBinaryLike(), func); +} + // ---------------------------------------------------------------------- // Sum implementation @@ -674,6 +807,12 @@ const FunctionDoc count_doc{"Count the number of null / non-null values", {"array"}, "CountOptions"}; +const FunctionDoc count_distinct_doc{"Count the number of unique values", + ("By default, only non-null values are counted.\n" + "This can be changed through CountOptions."), + {"array"}, + "CountOptions"}; + const FunctionDoc sum_doc{ "Compute the sum of a numeric array", ("Null values are ignored by default. Minimum count of non-null\n" @@ -754,6 +893,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { aggregate::CountInit, func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared( + "count_distinct", Arity::Unary(), &count_distinct_doc, &default_count_options); + // Takes any input, outputs int64 scalar + aggregate::AddCountDistinctKernels(func.get()); + DCHECK_OK(registry->AddFunction(std::move(func))); + func = std::make_shared("sum", Arity::Unary(), &sum_doc, &default_scalar_aggregate_options); aggregate::AddArrayScalarAggKernels(aggregate::SumInit, {boolean()}, uint64(), diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 7db3f292269..458e324826c 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -873,6 +874,152 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { } } +// +// Count Distinct +// + +class TestCountDistinctKernel : public ::testing::Test { + protected: + void SetUp() override { + only_valid = CountOptions(CountOptions::ONLY_VALID); + only_null = CountOptions(CountOptions::ONLY_NULL); + all = CountOptions(CountOptions::ALL); + } + + Datum Expected(int64_t value) { return MakeScalar(static_cast(value)); } + + void Check(Datum input, int64_t expected_all, bool has_nulls = true) { + int64_t expected_valid = has_nulls ? expected_all - 1 : expected_all; + int64_t expected_null = has_nulls ? 1 : 0; + CheckScalar("count_distinct", {input}, Expected(expected_valid), &only_valid); + CheckScalar("count_distinct", {input}, Expected(expected_null), &only_null); + CheckScalar("count_distinct", {input}, Expected(expected_all), &all); + } + + void Check(const std::shared_ptr& type, util::string_view json, + int64_t expected_all, bool has_nulls = true) { + Check(ArrayFromJSON(type, json), expected_all, has_nulls); + } + + void Check(const std::shared_ptr& type, util::string_view json) { + auto input = ScalarFromJSON(type, json); + auto zero = ResultWith(Expected(0)); + auto one = ResultWith(Expected(1)); + // non null scalar + EXPECT_THAT(CallFunction("count_distinct", {input}, &only_valid), one); + EXPECT_THAT(CallFunction("count_distinct", {input}, &only_null), zero); + EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); + // null scalar + input = MakeNullScalar(input->type); + EXPECT_THAT(CallFunction("count_distinct", {input}, &only_valid), zero); + EXPECT_THAT(CallFunction("count_distinct", {input}, &only_null), one); + EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); + } + + CountOptions only_valid; + CountOptions only_null; + CountOptions all; +}; + +TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) { + // Boolean + Check(boolean(), "[true, null, false, null, false, true]", 3); + // Number + for (auto ty : NumericTypes()) { + Check(ty, "[1, 1, null, 2, 5, 8, 9, 9, null, 10, 6, 6]", 8); + Check(ty, "[1, 1, 8, 2, 5, 8, 9, 9, 10, 10, 6, 6]", 7, false); + } + // Date + Check(date32(), "[0, 11016, 0, null, 14241, 14241, null]", 4); + Check(date64(), "[0, null, 0, null, 0, 0, 1262217600000]", 3); + // Time + Check(time32(TimeUnit::SECOND), "[0, 11, 0, null, 14, 14, null]", 4); + Check(time32(TimeUnit::MILLI), "[0, 11000, 0, null, 11000, 11000]", 3); + Check(time64(TimeUnit::MICRO), "[84203999999, 0, null, 84203999999, 0]", 3); + Check(time64(TimeUnit::NANO), "[11715003000000, 0, null, 0, 0]", 3); + // Timestamp & Duration + for (auto u : TimeUnit::values()) { + Check(duration(u), "[123456789, null, 987654321, 123456789, null]", 3); + Check(duration(u), "[123456789, 987654321, 123456789, 123456789]", 2, false); + auto ts = R"(["2009-12-31T04:20:20", "2020-01-01", null, "2009-12-31T04:20:20"])"; + Check(timestamp(u), ts, 3); + Check(timestamp(u, "Pacific/Marquesas"), ts, 3); + } + // Interval + Check(month_interval(), "[9012, 5678, null, 9012, 5678, null, 9012]", 3); + Check(day_time_interval(), "[[0, 1], [0, 1], null, [0, 1], [1234, 5678]]", 3); + Check(month_day_nano_interval(), "[[0, 1, 2], [0, 1, 2], null, [0, 1, 2]]", 2); + // Binary & String & Fixed binary + auto samples = R"([null, "abc", null, "abc", "abc", "cba", "bca", "cba", null])"; + Check(binary(), samples, 4); + Check(large_binary(), samples, 4); + Check(utf8(), samples, 4); + Check(large_utf8(), samples, 4); + Check(fixed_size_binary(3), samples, 4); + // Decimal + samples = R"(["12345.679", "98765.421", null, "12345.679", "98765.421"])"; + Check(decimal128(21, 3), samples, 3); + Check(decimal256(13, 3), samples, 3); +} + +TEST_F(TestCountDistinctKernel, AllScalarTypesWithNulls) { + // Boolean + Check(boolean(), "true"); + // Number + for (auto ty : NumericTypes()) { + Check(ty, "91"); + } + // Date + Check(date32(), "11016"); + Check(date64(), "1262217600000"); + // Time + Check(time32(TimeUnit::SECOND), "14"); + Check(time32(TimeUnit::MILLI), "11000"); + Check(time64(TimeUnit::MICRO), "84203999999"); + Check(time64(TimeUnit::NANO), "11715003000000"); + // Timestamp & Duration + for (auto u : TimeUnit::values()) { + Check(duration(u), "987654321"); + Check(duration(u), "123456789"); + auto ts = R"("2009-12-31T04:20:20")"; + Check(timestamp(u), ts); + Check(timestamp(u, "Pacific/Marquesas"), ts); + } + // Interval + Check(month_interval(), "5678"); + Check(day_time_interval(), "[1234, 5678]"); + Check(month_day_nano_interval(), "[0, 1, 2]"); + // Binary & String & Fixed binary + auto sample = R"("cba")"; + Check(binary(), sample); + Check(large_binary(), sample); + Check(utf8(), sample); + Check(large_utf8(), sample); + Check(fixed_size_binary(3), sample); + // Decimal + sample = R"("98765.421")"; + Check(decimal128(21, 3), sample); + Check(decimal256(13, 3), sample); +} + +TEST_F(TestCountDistinctKernel, RandomValidsStdMap) { + UInt32Builder builder; + std::unordered_set memo; + auto visit_null = []() { return Status::OK(); }; + auto visit_value = [&](uint32_t arg) { + const bool inserted = memo.insert(arg).second; + if (inserted) { + return builder.Append(arg); + } + return Status::OK(); + }; + auto rand = random::RandomArrayGenerator(0x1205643); + auto arr = rand.Numeric(1024, 0, 100, 0.0)->data(); + auto r = VisitArrayDataInline(*arr, visit_value, visit_null); + auto input = builder.Finish().ValueOrDie(); + Check(input, memo.size(), false); +} + // // Mean // diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index f699698e97b..590356e6489 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -199,6 +199,8 @@ the input to a single output value. +--------------------+-------+------------------+------------------------+----------------------------------+-------+ | count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) | +--------------------+-------+------------------+------------------------+----------------------------------+-------+ +| count_distinct | Unary | Non-nested types | Scalar Int64 | :struct:`CountOptions` | \(2) | ++--------------------+-------+------------------+------------------------+----------------------------------+-------+ | index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | | +--------------------+-------+------------------+------------------------+----------------------------------+-------+ | max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | | diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 2e5187259c3..521182f8a41 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -31,6 +31,7 @@ Aggregations any approximate_median count + count_distinct index max mean diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index fd5475a5864..24cf2e9570e 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2220,3 +2220,20 @@ def test_list_element(): result = pa.compute.list_element(lists, index) expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type) assert result.equals(expected) + + +def test_count_distinct(): + seed = datetime.now() + samples = [seed.replace(year=y) for y in range(1992, 2092)] + arr = pa.array(samples, pa.timestamp("ns")) + result = pa.compute.count_distinct(arr) + expected = pa.scalar(len(samples), type=pa.int64()) + assert result.equals(expected) + + +def test_count_distinct_options(): + arr = pa.array([1, 2, 3, None, None]) + assert pc.count_distinct(arr).as_py() == 3 + assert pc.count_distinct(arr, mode='only_valid').as_py() == 3 + assert pc.count_distinct(arr, mode='only_null').as_py() == 1 + assert pc.count_distinct(arr, mode='all').as_py() == 4 diff --git a/r/src/compute.cpp b/r/src/compute.cpp index de956550aba..f3cf514b885 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -208,7 +208,7 @@ std::shared_ptr make_compute_options( return out; } - if (func_name == "hash_count_distinct") { + if (func_name == "count_distinct" || func_name == "hash_count_distinct") { using Options = arrow::compute::CountOptions; auto out = std::make_shared(Options::Defaults()); out->mode = diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 6471e15e23d..aa2bf2374d6 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -212,7 +212,8 @@ test_that("Group by any/all", { ) }) -test_that("Group by n_distinct() on dataset", { +test_that("n_distinct() on dataset", { + # With groupby expect_dplyr_equal( input %>% group_by(some_grouping) %>% @@ -227,6 +228,19 @@ test_that("Group by n_distinct() on dataset", { collect(), tbl ) + # Without groupby + expect_dplyr_equal( + input %>% + summarize(distinct = n_distinct(lgl, na.rm = FALSE)) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + summarize(distinct = n_distinct(lgl, na.rm = TRUE)) %>% + collect(), + tbl + ) }) test_that("median()", { From 9761c385b11dd4b7d58741508380a8b0316e4796 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 5 Oct 2021 14:58:13 +0200 Subject: [PATCH 2/3] Simplify namespaces and function registration --- .../arrow/compute/kernels/aggregate_basic.cc | 209 ++++++++---------- .../compute/kernels/aggregate_basic_avx2.cc | 4 +- .../compute/kernels/aggregate_basic_avx512.cc | 7 +- .../kernels/aggregate_basic_internal.h | 6 +- .../compute/kernels/aggregate_internal.h | 12 +- .../arrow/compute/kernels/aggregate_test.cc | 2 + .../compute/kernels/aggregate_var_std.cc | 7 +- 7 files changed, 109 insertions(+), 138 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 38b91251dd8..25697f7d33b 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -26,6 +26,7 @@ namespace arrow { namespace compute { +namespace internal { namespace { @@ -62,7 +63,7 @@ void AddAggKernel(std::shared_ptr sig, KernelInit init, DCHECK_OK(func->AddKernel(std::move(kernel))); } -namespace aggregate { +namespace { // ---------------------------------------------------------------------- // Count implementation @@ -192,65 +193,46 @@ Result> CountDistinctInit(KernelContext* ctx, ctx->memory_pool(), static_cast(*args.options)); } -template +template void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) { AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())), - aggregate::CountDistinctInit, func); + CountDistinctInit, func); } -template -struct CountDistinctKernel { - static void Add(InputType type, ScalarAggregateFunction* func) { - using PhysicalType = typename Type::PhysicalType; - AddCountDistinctKernel(type, func); - } -}; - -template <> -struct CountDistinctKernel { - static void Add(InputType type, ScalarAggregateFunction* func) { - AddCountDistinctKernel(type, func); - } -}; - void AddCountDistinctKernels(ScalarAggregateFunction* func) { // Boolean - aggregate::CountDistinctKernel::Add(boolean(), func); + AddCountDistinctKernel(boolean(), func); // Number - aggregate::CountDistinctKernel::Add(int8(), func); - aggregate::CountDistinctKernel::Add(int16(), func); - aggregate::CountDistinctKernel::Add(int32(), func); - aggregate::CountDistinctKernel::Add(int64(), func); - aggregate::CountDistinctKernel::Add(uint8(), func); - aggregate::CountDistinctKernel::Add(uint16(), func); - aggregate::CountDistinctKernel::Add(uint32(), func); - aggregate::CountDistinctKernel::Add(uint64(), func); - aggregate::CountDistinctKernel::Add(float16(), func); - aggregate::CountDistinctKernel::Add(float32(), func); - aggregate::CountDistinctKernel::Add(float64(), func); + AddCountDistinctKernel(int8(), func); + AddCountDistinctKernel(int16(), func); + AddCountDistinctKernel(int32(), func); + AddCountDistinctKernel(int64(), func); + AddCountDistinctKernel(uint8(), func); + AddCountDistinctKernel(uint16(), func); + AddCountDistinctKernel(uint32(), func); + AddCountDistinctKernel(uint64(), func); + AddCountDistinctKernel(float16(), func); + AddCountDistinctKernel(float32(), func); + AddCountDistinctKernel(float64(), func); // Date - aggregate::CountDistinctKernel::Add(date32(), func); - aggregate::CountDistinctKernel::Add(date64(), func); + AddCountDistinctKernel(date32(), func); + AddCountDistinctKernel(date64(), func); // Time - aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::TIME32), func); - aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::TIME64), func); + AddCountDistinctKernel(match::SameTypeId(Type::TIME32), func); + AddCountDistinctKernel(match::SameTypeId(Type::TIME64), func); // Timestamp & Duration - aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::TIMESTAMP), - func); - aggregate::CountDistinctKernel::Add(match::SameTypeId(Type::DURATION), - func); + AddCountDistinctKernel(match::SameTypeId(Type::TIMESTAMP), func); + AddCountDistinctKernel(match::SameTypeId(Type::DURATION), func); // Interval - aggregate::CountDistinctKernel::Add(month_interval(), func); - aggregate::CountDistinctKernel::Add(day_time_interval(), func); - aggregate::CountDistinctKernel::Add(month_day_nano_interval(), - func); + AddCountDistinctKernel(month_interval(), func); + AddCountDistinctKernel(day_time_interval(), func); + AddCountDistinctKernel(month_day_nano_interval(), func); // Binary & String - aggregate::CountDistinctKernel::Add(match::BinaryLike(), - func); - aggregate::CountDistinctKernel::Add( - match::LargeBinaryLike(), func); + AddCountDistinctKernel(match::BinaryLike(), func); + AddCountDistinctKernel(match::LargeBinaryLike(), + func); // Fixed binary & Decimal - aggregate::CountDistinctKernel::Add( + AddCountDistinctKernel( match::FixedSizeBinaryLike(), func); } @@ -736,6 +718,8 @@ struct IndexInit { } }; +} // namespace + void AddBasicAggKernels(KernelInit init, const std::vector>& types, std::shared_ptr out_ty, ScalarAggregateFunction* func, @@ -769,12 +753,16 @@ void AddArrayScalarAggKernels(KernelInit init, AddScalarAggKernels(init, types, out_ty, func); } +namespace { + Result MinMaxType(KernelContext*, const std::vector& descrs) { // any[T] -> scalar[struct] auto ty = descrs.front().type; return ValueDescr::Scalar(struct_({field("min", ty), field("max", ty)})); } +} // namespace + void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id, ScalarAggregateFunction* func, SimdLevel::type simd_level) { auto sig = KernelSignature::Make({InputType(get_id.id)}, OutputType(MinMaxType)); @@ -789,6 +777,8 @@ void AddMinMaxKernels(KernelInit init, } } +namespace { + Result ScalarFirstType(KernelContext*, const std::vector& descrs) { ValueDescr result = descrs.front(); @@ -796,11 +786,6 @@ Result ScalarFirstType(KernelContext*, return result; } -} // namespace aggregate - -namespace internal { -namespace { - const FunctionDoc count_doc{"Count the number of null / non-null values", ("By default, only non-null values are counted.\n" "This can be changed through CountOptions."), @@ -889,92 +874,86 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { // Takes any input, outputs int64 scalar InputType any_input; - AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), - aggregate::CountInit, func.get()); + AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), CountInit, + func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared( "count_distinct", Arity::Unary(), &count_distinct_doc, &default_count_options); // Takes any input, outputs int64 scalar - aggregate::AddCountDistinctKernels(func.get()); + AddCountDistinctKernels(func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("sum", Arity::Unary(), &sum_doc, &default_scalar_aggregate_options); - aggregate::AddArrayScalarAggKernels(aggregate::SumInit, {boolean()}, uint64(), - func.get()); - AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL128)}, - OutputType(aggregate::ScalarFirstType)), - aggregate::SumInit, func.get(), SimdLevel::NONE); - AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL256)}, - OutputType(aggregate::ScalarFirstType)), - aggregate::SumInit, func.get(), SimdLevel::NONE); - aggregate::AddArrayScalarAggKernels(aggregate::SumInit, SignedIntTypes(), int64(), - func.get()); - aggregate::AddArrayScalarAggKernels(aggregate::SumInit, UnsignedIntTypes(), uint64(), - func.get()); - aggregate::AddArrayScalarAggKernels(aggregate::SumInit, FloatingPointTypes(), float64(), - func.get()); + AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get()); + AddAggKernel( + KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)), + SumInit, func.get(), SimdLevel::NONE); + AddAggKernel( + KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)), + SumInit, func.get(), SimdLevel::NONE); + AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get()); + AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get()); + AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get()); // Add the SIMD variants for sum #if defined(ARROW_HAVE_RUNTIME_AVX2) || defined(ARROW_HAVE_RUNTIME_AVX512) auto cpu_info = arrow::internal::CpuInfo::GetInstance(); #endif #if defined(ARROW_HAVE_RUNTIME_AVX2) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - aggregate::AddSumAvx2AggKernels(func.get()); + AddSumAvx2AggKernels(func.get()); } #endif #if defined(ARROW_HAVE_RUNTIME_AVX512) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - aggregate::AddSumAvx512AggKernels(func.get()); + AddSumAvx512AggKernels(func.get()); } #endif DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("mean", Arity::Unary(), &mean_doc, &default_scalar_aggregate_options); - aggregate::AddArrayScalarAggKernels(aggregate::MeanInit, {boolean()}, float64(), - func.get()); - aggregate::AddArrayScalarAggKernels(aggregate::MeanInit, NumericTypes(), float64(), - func.get()); - AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL128)}, - OutputType(aggregate::ScalarFirstType)), - aggregate::MeanInit, func.get(), SimdLevel::NONE); - AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL256)}, - OutputType(aggregate::ScalarFirstType)), - aggregate::MeanInit, func.get(), SimdLevel::NONE); + AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get()); + AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get()); + AddAggKernel( + KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)), + MeanInit, func.get(), SimdLevel::NONE); + AddAggKernel( + KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)), + MeanInit, func.get(), SimdLevel::NONE); // Add the SIMD variants for mean #if defined(ARROW_HAVE_RUNTIME_AVX2) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - aggregate::AddMeanAvx2AggKernels(func.get()); + AddMeanAvx2AggKernels(func.get()); } #endif #if defined(ARROW_HAVE_RUNTIME_AVX512) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - aggregate::AddMeanAvx512AggKernels(func.get()); + AddMeanAvx512AggKernels(func.get()); } #endif DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared( "min_max", Arity::Unary(), &min_max_doc, &default_scalar_aggregate_options); - aggregate::AddMinMaxKernels(aggregate::MinMaxInit, {null(), boolean()}, func.get()); - aggregate::AddMinMaxKernels(aggregate::MinMaxInit, NumericTypes(), func.get()); - aggregate::AddMinMaxKernels(aggregate::MinMaxInit, TemporalTypes(), func.get()); - aggregate::AddMinMaxKernels(aggregate::MinMaxInit, BaseBinaryTypes(), func.get()); - aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::FIXED_SIZE_BINARY, func.get()); - aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::INTERVAL_MONTHS, func.get()); - aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::DECIMAL128, func.get()); - aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::DECIMAL256, func.get()); + AddMinMaxKernels(MinMaxInit, {null(), boolean()}, func.get()); + AddMinMaxKernels(MinMaxInit, NumericTypes(), func.get()); + AddMinMaxKernels(MinMaxInit, TemporalTypes(), func.get()); + AddMinMaxKernels(MinMaxInit, BaseBinaryTypes(), func.get()); + AddMinMaxKernel(MinMaxInit, Type::FIXED_SIZE_BINARY, func.get()); + AddMinMaxKernel(MinMaxInit, Type::INTERVAL_MONTHS, func.get()); + AddMinMaxKernel(MinMaxInit, Type::DECIMAL128, func.get()); + AddMinMaxKernel(MinMaxInit, Type::DECIMAL256, func.get()); // Add the SIMD variants for min max #if defined(ARROW_HAVE_RUNTIME_AVX2) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { - aggregate::AddMinMaxAvx2AggKernels(func.get()); + AddMinMaxAvx2AggKernels(func.get()); } #endif #if defined(ARROW_HAVE_RUNTIME_AVX512) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX512)) { - aggregate::AddMinMaxAvx512AggKernels(func.get()); + AddMinMaxAvx512AggKernels(func.get()); } #endif @@ -984,54 +963,46 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { // Add min/max as convenience functions func = std::make_shared("min", Arity::Unary(), &min_or_max_doc, &default_scalar_aggregate_options); - aggregate::AddMinOrMaxAggKernel(func.get(), min_max_func); + AddMinOrMaxAggKernel(func.get(), min_max_func); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("max", Arity::Unary(), &min_or_max_doc, &default_scalar_aggregate_options); - aggregate::AddMinOrMaxAggKernel(func.get(), min_max_func); + AddMinOrMaxAggKernel(func.get(), min_max_func); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared( "product", Arity::Unary(), &product_doc, &default_scalar_aggregate_options); - aggregate::AddArrayScalarAggKernels(aggregate::ProductInit::Init, {boolean()}, uint64(), - func.get()); - aggregate::AddArrayScalarAggKernels(aggregate::ProductInit::Init, SignedIntTypes(), - int64(), func.get()); - aggregate::AddArrayScalarAggKernels(aggregate::ProductInit::Init, UnsignedIntTypes(), - uint64(), func.get()); - aggregate::AddArrayScalarAggKernels(aggregate::ProductInit::Init, FloatingPointTypes(), - float64(), func.get()); - AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL128)}, - OutputType(aggregate::ScalarFirstType)), - aggregate::ProductInit::Init, func.get(), SimdLevel::NONE); - AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL256)}, - OutputType(aggregate::ScalarFirstType)), - aggregate::ProductInit::Init, func.get(), SimdLevel::NONE); + AddArrayScalarAggKernels(ProductInit::Init, {boolean()}, uint64(), func.get()); + AddArrayScalarAggKernels(ProductInit::Init, SignedIntTypes(), int64(), func.get()); + AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get()); + AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(), + func.get()); + AddAggKernel( + KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)), + ProductInit::Init, func.get(), SimdLevel::NONE); + AddAggKernel( + KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)), + ProductInit::Init, func.get(), SimdLevel::NONE); DCHECK_OK(registry->AddFunction(std::move(func))); // any func = std::make_shared("any", Arity::Unary(), &any_doc, &default_scalar_aggregate_options); - aggregate::AddArrayScalarAggKernels(aggregate::AnyInit, {boolean()}, boolean(), - func.get()); + AddArrayScalarAggKernels(AnyInit, {boolean()}, boolean(), func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); // all func = std::make_shared("all", Arity::Unary(), &all_doc, &default_scalar_aggregate_options); - aggregate::AddArrayScalarAggKernels(aggregate::AllInit, {boolean()}, boolean(), - func.get()); + AddArrayScalarAggKernels(AllInit, {boolean()}, boolean(), func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); // index func = std::make_shared("index", Arity::Unary(), &index_doc); - aggregate::AddBasicAggKernels(aggregate::IndexInit::Init, BaseBinaryTypes(), int64(), - func.get()); - aggregate::AddBasicAggKernels(aggregate::IndexInit::Init, PrimitiveTypes(), int64(), - func.get()); - aggregate::AddBasicAggKernels(aggregate::IndexInit::Init, TemporalTypes(), int64(), - func.get()); + AddBasicAggKernels(IndexInit::Init, BaseBinaryTypes(), int64(), func.get()); + AddBasicAggKernels(IndexInit::Init, PrimitiveTypes(), int64(), func.get()); + AddBasicAggKernels(IndexInit::Init, TemporalTypes(), int64(), func.get()); DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc index dd12d5244f5..00e3e2e5fd4 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc @@ -19,7 +19,7 @@ namespace arrow { namespace compute { -namespace aggregate { +namespace internal { // ---------------------------------------------------------------------- // Sum implementation @@ -83,6 +83,6 @@ void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func) { AddMinMaxKernel(MinMaxInitAvx2, Type::INTERVAL_MONTHS, func, SimdLevel::AVX2); } -} // namespace aggregate +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc index ebe748d685d..8c10eb19b07 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc @@ -19,7 +19,7 @@ namespace arrow { namespace compute { -namespace aggregate { +namespace internal { // ---------------------------------------------------------------------- // Sum implementation @@ -72,8 +72,7 @@ void AddSumAvx512AggKernels(ScalarAggregateFunction* func) { } void AddMeanAvx512AggKernels(ScalarAggregateFunction* func) { - aggregate::AddBasicAggKernels(MeanInitAvx512, NumericTypes(), float64(), func, - SimdLevel::AVX512); + AddBasicAggKernels(MeanInitAvx512, NumericTypes(), float64(), func, SimdLevel::AVX512); } void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) { @@ -86,6 +85,6 @@ void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) { AddMinMaxKernel(MinMaxInitAvx512, Type::INTERVAL_MONTHS, func, SimdLevel::AVX512); } -} // namespace aggregate +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index f5ea9a0d65a..156e908eadf 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -30,7 +30,7 @@ namespace arrow { namespace compute { -namespace aggregate { +namespace internal { void AddBasicAggKernels(KernelInit init, const std::vector>& types, @@ -83,7 +83,7 @@ struct SumImpl : public ScalarAggregator { if (is_boolean_type::value) { this->sum += static_cast(BooleanArray(data).true_count()); } else { - this->sum += arrow::compute::detail::SumArray(*data); + this->sum += SumArray(*data); } } else { const auto& data = *batch[0].scalar(); @@ -621,6 +621,6 @@ struct MinMaxInitState { } }; -} // namespace aggregate +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h index 33ccefd4cbd..22a54558f4e 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h @@ -25,6 +25,7 @@ namespace arrow { namespace compute { +namespace internal { // Find the largest compatible primitive type for a primitive type. template @@ -110,10 +111,6 @@ void AddAggKernel(std::shared_ptr sig, KernelInit init, ScalarAggregateFinalize finalize, ScalarAggregateFunction* func, SimdLevel::type simd_level = SimdLevel::NONE); -namespace detail { - -using arrow::internal::VisitSetBitRunsVoid; - // SumArray must be parameterized with the SIMD level since it's called both from // translation units with and without vectorization. Normally it gets inlined but // if not, without the parameter, we'll have multiple definitions of the same @@ -125,6 +122,8 @@ template enable_if_t::value, SumType> SumArray( const ArrayData& data, ValueFunc&& func) { + using arrow::internal::VisitSetBitRunsVoid; + const int64_t data_size = data.length - data.GetNullCount(); if (data_size == 0) { return 0; @@ -200,6 +199,8 @@ template enable_if_t::value, SumType> SumArray( const ArrayData& data, ValueFunc&& func) { + using arrow::internal::VisitSetBitRunsVoid; + SumType sum = 0; const ValueType* values = data.GetValues(1); VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length, @@ -217,7 +218,6 @@ SumType SumArray(const ArrayData& data) { data, [](ValueType v) { return static_cast(v); }); } -} // namespace detail - +} // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 458e324826c..1937383aaef 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -54,6 +54,8 @@ using internal::checked_pointer_cast; namespace compute { +using internal::FindAccumulatorType; + // // Sum // diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index 42ac655877c..d0d3c514fae 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -54,12 +54,11 @@ struct VarStdState { using SumType = typename std::conditional::value, double, int128_t>::type; - SumType sum = - arrow::compute::detail::SumArray(*array.data()); + SumType sum = SumArray(*array.data()); const double mean = static_cast(sum) / count; - const double m2 = arrow::compute::detail::SumArray( - *array.data(), [mean](CType value) { + const double m2 = + SumArray(*array.data(), [mean](CType value) { const double v = static_cast(value); return (v - mean) * (v - mean); }); From 5de1a5cf4b7c408a7217af3912e5903a1533090b Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 5 Oct 2021 15:10:13 +0200 Subject: [PATCH 3/3] Some nits --- .../arrow/compute/kernels/aggregate_test.cc | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 1937383aaef..992f7369864 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -882,12 +882,6 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) { class TestCountDistinctKernel : public ::testing::Test { protected: - void SetUp() override { - only_valid = CountOptions(CountOptions::ONLY_VALID); - only_null = CountOptions(CountOptions::ONLY_NULL); - all = CountOptions(CountOptions::ALL); - } - Datum Expected(int64_t value) { return MakeScalar(static_cast(value)); } void Check(Datum input, int64_t expected_all, bool has_nulls = true) { @@ -918,18 +912,19 @@ class TestCountDistinctKernel : public ::testing::Test { EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); } - CountOptions only_valid; - CountOptions only_null; - CountOptions all; + CountOptions only_valid{CountOptions::ONLY_VALID}; + CountOptions only_null{CountOptions::ONLY_NULL}; + CountOptions all{CountOptions::ALL}; }; TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) { // Boolean + Check(boolean(), "[]", 0, /*has_nulls=*/false); Check(boolean(), "[true, null, false, null, false, true]", 3); // Number for (auto ty : NumericTypes()) { Check(ty, "[1, 1, null, 2, 5, 8, 9, 9, null, 10, 6, 6]", 8); - Check(ty, "[1, 1, 8, 2, 5, 8, 9, 9, 10, 10, 6, 6]", 7, false); + Check(ty, "[1, 1, 8, 2, 5, 8, 9, 9, 10, 10, 6, 6]", 7, /*has_nulls=*/false); } // Date Check(date32(), "[0, 11016, 0, null, 14241, 14241, null]", 4); @@ -942,7 +937,8 @@ TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) { // Timestamp & Duration for (auto u : TimeUnit::values()) { Check(duration(u), "[123456789, null, 987654321, 123456789, null]", 3); - Check(duration(u), "[123456789, 987654321, 123456789, 123456789]", 2, false); + Check(duration(u), "[123456789, 987654321, 123456789, 123456789]", 2, + /*has_nulls=*/false); auto ts = R"(["2009-12-31T04:20:20", "2020-01-01", null, "2009-12-31T04:20:20"])"; Check(timestamp(u), ts, 3); Check(timestamp(u, "Pacific/Marquesas"), ts, 3); @@ -1004,7 +1000,7 @@ TEST_F(TestCountDistinctKernel, AllScalarTypesWithNulls) { Check(decimal256(13, 3), sample); } -TEST_F(TestCountDistinctKernel, RandomValidsStdMap) { +TEST_F(TestCountDistinctKernel, Random) { UInt32Builder builder; std::unordered_set memo; auto visit_null = []() { return Status::OK(); };