diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index 880424e97f8..d8cda022de8 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -42,14 +42,17 @@ class ExecContext; /// \brief Control general scalar aggregate kernel behavior /// -/// By default, null values are ignored +/// By default, null values are ignored (skip_nulls = true). class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { public: explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1); constexpr static char const kTypeName[] = "ScalarAggregateOptions"; static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; } + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. bool skip_nulls; + /// If less than this many non-null values are observed, emit null. uint32_t min_count; }; diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 548a008c5ce..a3fafaae75d 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -162,6 +162,13 @@ struct ProductImpl : public ScalarAggregator { if (batch[0].is_array()) { const auto& data = batch[0].array(); this->count += data->length - data->GetNullCount(); + this->nulls_observed = this->nulls_observed || data->GetNullCount(); + + if (!options.skip_nulls && this->nulls_observed) { + // Short-circuit + return Status::OK(); + } + VisitArrayDataInline( *data, [&](typename TypeTraits::CType value) { @@ -172,6 +179,7 @@ struct ProductImpl : public ScalarAggregator { } else { const auto& data = *batch[0].scalar(); this->count += data.is_valid * batch.length; + this->nulls_observed = this->nulls_observed || !data.is_valid; if (data.is_valid) { for (int64_t i = 0; i < batch.length; i++) { auto value = internal::UnboxScalar::Unbox(data); @@ -188,11 +196,13 @@ struct ProductImpl : public ScalarAggregator { this->count += other.count; this->product = static_cast(to_unsigned(this->product) * to_unsigned(other.product)); + this->nulls_observed = this->nulls_observed || other.nulls_observed; return Status::OK(); } Status Finalize(KernelContext*, Datum* out) override { - if (this->count < options.min_count) { + if ((!options.skip_nulls && this->nulls_observed) || + (this->count < options.min_count)) { out->value = std::make_shared(); } else { out->value = MakeScalar(this->product); @@ -201,6 +211,7 @@ struct ProductImpl : public ScalarAggregator { } size_t count = 0; + bool nulls_observed = false; typename AccType::c_type product = 1; ScalarAggregateOptions options; }; @@ -268,17 +279,19 @@ struct BooleanAnyImpl : public ScalarAggregator { Status Consume(KernelContext*, const ExecBatch& batch) override { // short-circuit if seen a True already - if (this->any == true) { + if (this->any == true && this->count >= options.min_count) { return Status::OK(); } if (batch[0].is_scalar()) { const auto& scalar = *batch[0].scalar(); this->has_nulls = !scalar.is_valid; this->any = scalar.is_valid && checked_cast(scalar).value; + this->count += scalar.is_valid; return Status::OK(); } const auto& data = *batch[0].array(); this->has_nulls = data.GetNullCount() > 0; + this->count += data.length - data.GetNullCount(); arrow::internal::OptionalBinaryBitBlockCounter counter( data.buffers[0], data.offset, data.buffers[1], data.offset, data.length); int64_t position = 0; @@ -297,11 +310,13 @@ struct BooleanAnyImpl : public ScalarAggregator { const auto& other = checked_cast(src); this->any |= other.any; this->has_nulls |= other.has_nulls; + this->count += other.count; return Status::OK(); } Status Finalize(KernelContext* ctx, Datum* out) override { - if (!options.skip_nulls && !this->any && this->has_nulls) { + if ((!options.skip_nulls && !this->any && this->has_nulls) || + this->count < options.min_count) { out->value = std::make_shared(); } else { out->value = std::make_shared(this->any); @@ -311,6 +326,7 @@ struct BooleanAnyImpl : public ScalarAggregator { bool any = false; bool has_nulls = false; + int64_t count = 0; ScalarAggregateOptions options; }; @@ -329,7 +345,7 @@ struct BooleanAllImpl : public ScalarAggregator { Status Consume(KernelContext*, const ExecBatch& batch) override { // short-circuit if seen a false already - if (this->all == false) { + if (this->all == false && this->count >= options.min_count) { return Status::OK(); } // short-circuit if seen a null already @@ -339,11 +355,13 @@ struct BooleanAllImpl : public ScalarAggregator { if (batch[0].is_scalar()) { const auto& scalar = *batch[0].scalar(); this->has_nulls = !scalar.is_valid; + this->count += scalar.is_valid; this->all = !scalar.is_valid || checked_cast(scalar).value; return Status::OK(); } const auto& data = *batch[0].array(); this->has_nulls = data.GetNullCount() > 0; + this->count += data.length - data.GetNullCount(); arrow::internal::OptionalBinaryBitBlockCounter counter( data.buffers[1], data.offset, data.buffers[0], data.offset, data.length); int64_t position = 0; @@ -363,11 +381,13 @@ struct BooleanAllImpl : public ScalarAggregator { const auto& other = checked_cast(src); this->all &= other.all; this->has_nulls |= other.has_nulls; + this->count += other.count; return Status::OK(); } Status Finalize(KernelContext*, Datum* out) override { - if (!options.skip_nulls && this->all && this->has_nulls) { + if ((!options.skip_nulls && this->all && this->has_nulls) || + this->count < options.min_count) { out->value = std::make_shared(); } else { out->value = std::make_shared(this->all); @@ -377,6 +397,7 @@ struct BooleanAllImpl : public ScalarAggregator { bool all = true; bool has_nulls = false; + int64_t count = 0; ScalarAggregateOptions options; }; diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index eb314011229..b355a2e1b75 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -68,6 +68,13 @@ struct SumImpl : public ScalarAggregator { if (batch[0].is_array()) { const auto& data = batch[0].array(); this->count += data->length - data->GetNullCount(); + this->nulls_observed = this->nulls_observed || data->GetNullCount(); + + if (!options.skip_nulls && this->nulls_observed) { + // Short-circuit + return Status::OK(); + } + if (is_boolean_type::value) { this->sum += static_cast(BooleanArray(data).true_count()); @@ -79,6 +86,7 @@ struct SumImpl : public ScalarAggregator { } else { const auto& data = *batch[0].scalar(); this->count += data.is_valid * batch.length; + this->nulls_observed = this->nulls_observed || !data.is_valid; if (data.is_valid) { this->sum += internal::UnboxScalar::Unbox(data) * batch.length; } @@ -90,11 +98,13 @@ struct SumImpl : public ScalarAggregator { const auto& other = checked_cast(src); this->count += other.count; this->sum += other.sum; + this->nulls_observed = this->nulls_observed || other.nulls_observed; return Status::OK(); } Status Finalize(KernelContext*, Datum* out) override { - if (this->count < options.min_count) { + if ((!options.skip_nulls && this->nulls_observed) || + (this->count < options.min_count)) { out->value = std::make_shared(); } else { out->value = MakeScalar(this->sum); @@ -103,6 +113,7 @@ struct SumImpl : public ScalarAggregator { } size_t count = 0; + bool nulls_observed = false; typename SumType::c_type sum = 0; ScalarAggregateOptions options; }; @@ -110,7 +121,8 @@ struct SumImpl : public ScalarAggregator { template struct MeanImpl : public SumImpl { Status Finalize(KernelContext*, Datum* out) override { - if (this->count < options.min_count) { + if ((!options.skip_nulls && this->nulls_observed) || + (this->count < options.min_count)) { out->value = std::make_shared(); } else { const double mean = static_cast(this->sum) / this->count; diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index aca693601c3..9893923a097 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -139,14 +139,12 @@ template void ValidateBooleanAgg(const std::string& json, const std::shared_ptr& expected, const ScalarAggregateOptions& options) { + SCOPED_TRACE(json); auto array = ArrayFromJSON(boolean(), json); ASSERT_OK_AND_ASSIGN(Datum result, Op(array, options, nullptr)); - const auto& exp = Datum(expected); - const auto& res = checked_pointer_cast(result.scalar()); - if (!(std::isnan((double)res->value) && std::isnan((double)expected->value))) { - ASSERT_TRUE(result.Equals(exp)); - } + auto equal_options = EqualOptions::Defaults().nans_equal(true); + AssertScalarsEqual(*expected, *result.scalar(), /*verbose=*/true, equal_options); } TEST(TestBooleanAggregation, Sum) { @@ -167,13 +165,23 @@ TEST(TestBooleanAggregation, Sum) { ValidateBooleanAgg("[null]", std::make_shared(0), options_min_count_zero); - const char* json = "[true, null, false, null]"; + std::string json = "[true, null, false, null]"; ValidateBooleanAgg(json, std::make_shared(1), ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)); ValidateBooleanAgg(json, std::make_shared(1), ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)); ValidateBooleanAgg(json, std::make_shared(), ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)); + ValidateBooleanAgg("[]", std::make_shared(0), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateBooleanAgg(json, std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateBooleanAgg("[]", std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); + ValidateBooleanAgg(json, std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); + + json = "[true, false]"; ValidateBooleanAgg(json, std::make_shared(1), ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); ValidateBooleanAgg(json, std::make_shared(1), @@ -187,6 +195,12 @@ TEST(TestBooleanAggregation, Sum) { ResultWith(Datum(std::make_shared(0)))); EXPECT_THAT(Sum(MakeNullScalar(boolean())), ResultWith(Datum(MakeNullScalar(uint64())))); + EXPECT_THAT(Sum(MakeNullScalar(boolean()), + ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)), + ResultWith(ScalarFromJSON(uint64(), "0"))); + EXPECT_THAT(Sum(MakeNullScalar(boolean()), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)), + ResultWith(ScalarFromJSON(uint64(), "null"))); } TEST(TestBooleanAggregation, Product) { @@ -219,11 +233,14 @@ TEST(TestBooleanAggregation, Product) { json, std::make_shared(), ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)); ValidateBooleanAgg( - json, std::make_shared(1), - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); + "[]", std::make_shared(1), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); ValidateBooleanAgg( - json, std::make_shared(1), - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2)); + json, std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateBooleanAgg( + "[]", std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); ValidateBooleanAgg( json, std::make_shared(), ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); @@ -234,6 +251,12 @@ TEST(TestBooleanAggregation, Product) { ResultWith(Datum(std::make_shared(0)))); EXPECT_THAT(Product(MakeNullScalar(boolean())), ResultWith(Datum(MakeNullScalar(uint64())))); + EXPECT_THAT(Product(MakeNullScalar(boolean()), + ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)), + ResultWith(ScalarFromJSON(uint64(), "1"))); + EXPECT_THAT(Product(MakeNullScalar(boolean()), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)), + ResultWith(ScalarFromJSON(uint64(), "null"))); } TEST(TestBooleanAggregation, Mean) { @@ -264,17 +287,27 @@ TEST(TestBooleanAggregation, Mean) { ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)); ValidateBooleanAgg(json, std::make_shared(), ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)); - ValidateBooleanAgg(json, std::make_shared(0.5), - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); - ValidateBooleanAgg(json, std::make_shared(0.5), - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2)); + ValidateBooleanAgg("[]", std::make_shared(NAN), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateBooleanAgg(json, std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateBooleanAgg("[]", std::make_shared(), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); ValidateBooleanAgg(json, std::make_shared(), ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); - EXPECT_THAT(Mean(MakeScalar(true)), ResultWith(Datum(MakeScalar(1.0)))); - EXPECT_THAT(Mean(MakeScalar(false)), ResultWith(Datum(MakeScalar(0.0)))); + EXPECT_THAT(Mean(MakeScalar(true)), ResultWith(ScalarFromJSON(float64(), "1.0"))); + EXPECT_THAT(Mean(MakeScalar(false)), ResultWith(ScalarFromJSON(float64(), "0.0"))); EXPECT_THAT(Mean(MakeNullScalar(boolean())), - ResultWith(Datum(MakeNullScalar(float64())))); + ResultWith(ScalarFromJSON(float64(), "null"))); + ASSERT_OK_AND_ASSIGN( + auto result, Mean(MakeNullScalar(boolean()), + ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0))); + AssertDatumsApproxEqual(result, ScalarFromJSON(float64(), "NaN"), /*detailed=*/true, + EqualOptions::Defaults().nans_equal(true)); + EXPECT_THAT(Mean(MakeNullScalar(boolean()), + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)), + ResultWith(ScalarFromJSON(float64(), "null"))); } template @@ -306,22 +339,23 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) { ValidateSum(chunks, Datum(std::make_shared(static_cast(5 * 6 / 2)))); - const ScalarAggregateOptions& options = - ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0); - + ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0); ValidateSum("[]", Datum(std::make_shared(static_cast(0))), options); - ValidateSum("[null]", Datum(std::make_shared(static_cast(0))), options); - chunks = {}; ValidateSum(chunks, Datum(std::make_shared(static_cast(0))), options); - const T expected_result = static_cast(14); + options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0); + ValidateSum("[]", Datum(std::make_shared(static_cast(0))), + options); + ValidateSum("[null]", Datum(std::make_shared()), options); ValidateSum("[1, null, 3, null, 3, null, 7]", - Datum(std::make_shared(expected_result)), options); + Datum(std::make_shared()), options); + ValidateSum("[1, null, 3, null, 3, null, 7]", + Datum(std::make_shared(14))); EXPECT_THAT(Sum(Datum(std::make_shared(static_cast(5)))), ResultWith(Datum(std::make_shared(static_cast(5))))); @@ -355,12 +389,10 @@ TYPED_TEST(TestNumericSumKernel, ScalarAggregateOptions) { ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)); ValidateSum("[null]", null_result, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)); - ValidateSum(json, result, - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); - ValidateSum(json, result, - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)); + ValidateSum("[]", zero_result, + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); ValidateSum(json, null_result, - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)); + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); EXPECT_THAT(Sum(Datum(std::make_shared(static_cast(5))), ScalarAggregateOptions(/*skip_nulls=*/false)), @@ -490,16 +522,20 @@ TYPED_TEST(TestNumericProductKernel, SimpleProduct) { chunks = ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3, 4, 5]"}); EXPECT_THAT(Product(chunks), ResultWith(Datum(static_cast(120)))); - const ScalarAggregateOptions& options = - ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0); - + ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0); EXPECT_THAT(Product(ArrayFromJSON(ty, "[]"), options), Datum(static_cast(1))); EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]"), options), Datum(static_cast(1))); chunks = ChunkedArrayFromJSON(ty, {}); EXPECT_THAT(Product(chunks, options), Datum(static_cast(1))); + options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0); + EXPECT_THAT(Product(ArrayFromJSON(ty, "[]"), options), + ResultWith(Datum(static_cast(1)))); + EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]"), options), ResultWith(null_result)); EXPECT_THAT(Product(ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]"), options), + ResultWith(null_result)); + EXPECT_THAT(Product(ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]")), Datum(static_cast(63))); EXPECT_THAT(Product(Datum(static_cast(5))), @@ -538,11 +574,10 @@ TYPED_TEST(TestNumericProductKernel, ScalarAggregateOptions) { ResultWith(null_result)); EXPECT_THAT(Product(null, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)), ResultWith(null_result)); - EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)), - ResultWith(result)); - EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)), - ResultWith(result)); - EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)), + EXPECT_THAT( + Product(empty, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)), + ResultWith(one_result)); + EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)), ResultWith(null_result)); EXPECT_THAT( @@ -667,7 +702,7 @@ void ValidateMean(const Array& input, Datum expected, template void ValidateMean( - const char* json, Datum expected, + const std::string& json, Datum expected, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) { auto array = ArrayFromJSON(TypeTraits::type_singleton(), json); ValidateMean(*array, expected, options); @@ -680,10 +715,10 @@ void ValidateMean(const Array& array, const ScalarAggregateOptions& options = } template -class TestMeanKernelNumeric : public ::testing::Test {}; +class TestNumericMeanKernel : public ::testing::Test {}; -TYPED_TEST_SUITE(TestMeanKernelNumeric, NumericArrowTypes); -TYPED_TEST(TestMeanKernelNumeric, SimpleMean) { +TYPED_TEST_SUITE(TestNumericMeanKernel, NumericArrowTypes); +TYPED_TEST(TestNumericMeanKernel, SimpleMean) { using ScalarType = typename TypeTraits::ScalarType; using InputScalarType = typename TypeTraits::ScalarType; using T = typename TypeParam::c_type; @@ -716,14 +751,14 @@ TYPED_TEST(TestMeanKernelNumeric, SimpleMean) { ResultWith(Datum(MakeNullScalar(float64())))); } -TYPED_TEST(TestMeanKernelNumeric, ScalarAggregateOptions) { +TYPED_TEST(TestNumericMeanKernel, ScalarAggregateOptions) { using ScalarType = typename TypeTraits::ScalarType; using InputScalarType = typename TypeTraits::ScalarType; using T = typename TypeParam::c_type; auto expected_result = Datum(std::make_shared(3)); auto null_result = Datum(std::make_shared()); auto nan_result = Datum(std::make_shared(NAN)); - const char* json = "[1, null, 2, 2, null, 7]"; + std::string json = "[1, null, 2, 2, null, 7]"; ValidateMean("[]", nan_result, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)); @@ -744,20 +779,25 @@ TYPED_TEST(TestMeanKernelNumeric, ScalarAggregateOptions) { ValidateMean("[]", nan_result, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); - ValidateMean("[null]", nan_result, + ValidateMean("[null]", null_result, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateMean(json, null_result, + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); + ValidateMean("[]", null_result, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); ValidateMean("[null]", null_result, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); + ValidateMean(json, null_result, + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); + + json = "[1, 2, 2, 7]"; ValidateMean(json, expected_result, - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)); - ValidateMean(json, expected_result, - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)); + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); ValidateMean(json, expected_result, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)); ValidateMean(json, null_result, - ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/15)); + ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)); EXPECT_THAT(Mean(Datum(std::make_shared(static_cast(5))), ScalarAggregateOptions(/*skip_nulls=*/false)), @@ -1234,19 +1274,21 @@ TYPED_TEST(TestRandomNumericMinMaxKernel, RandomArrayMinMax) { // Any // -class TestPrimitiveAnyKernel : public ::testing::Test { +class TestAnyKernel : public ::testing::Test { public: void AssertAnyIs(const Datum& array, const std::shared_ptr& expected, const ScalarAggregateOptions& options) { + SCOPED_TRACE(options.ToString()); ASSERT_OK_AND_ASSIGN(Datum out, Any(array, options, nullptr)); const BooleanScalar& out_any = out.scalar_as(); - ASSERT_EQ(out_any, *expected); + AssertScalarsEqual(*expected, out_any, /*verbose=*/true); } void AssertAnyIs( const std::string& json, const std::shared_ptr& expected, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) { - auto array = ArrayFromJSON(type_singleton(), json); + SCOPED_TRACE(json); + auto array = ArrayFromJSON(boolean(), json); AssertAnyIs(array, expected, options); } @@ -1254,17 +1296,11 @@ class TestPrimitiveAnyKernel : public ::testing::Test { const std::vector& json, const std::shared_ptr& expected, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) { - auto array = ChunkedArrayFromJSON(type_singleton(), json); + auto array = ChunkedArrayFromJSON(boolean(), json); AssertAnyIs(array, expected, options); } - - std::shared_ptr type_singleton() { - return TypeTraits::type_singleton(); - } }; -class TestAnyKernel : public TestPrimitiveAnyKernel {}; - TEST_F(TestAnyKernel, Basics) { auto true_value = std::make_shared(true); auto false_value = std::make_shared(false); @@ -1277,26 +1313,27 @@ TEST_F(TestAnyKernel, Basics) { std::vector chunked_input3 = {"[false, null]", "[null, false]"}; std::vector chunked_input4 = {"[true, null]", "[null, false]"}; - this->AssertAnyIs("[]", false_value); - this->AssertAnyIs("[false]", false_value); - this->AssertAnyIs("[true, false]", true_value); - this->AssertAnyIs("[null, null, null]", false_value); - this->AssertAnyIs("[false, false, false]", false_value); - this->AssertAnyIs("[false, false, false, null]", false_value); - this->AssertAnyIs("[true, null, true, true]", true_value); - this->AssertAnyIs("[false, null, false, true]", true_value); - this->AssertAnyIs("[true, null, false, true]", true_value); - this->AssertAnyIs(chunked_input0, true_value); - this->AssertAnyIs(chunked_input1, true_value); - this->AssertAnyIs(chunked_input2, false_value); - this->AssertAnyIs(chunked_input3, false_value); - this->AssertAnyIs(chunked_input4, true_value); - - EXPECT_THAT(Any(Datum(true)), ResultWith(Datum(true))); - EXPECT_THAT(Any(Datum(false)), ResultWith(Datum(false))); - EXPECT_THAT(Any(MakeNullScalar(boolean())), ResultWith(Datum(false))); - - const ScalarAggregateOptions& keep_nulls = ScalarAggregateOptions(/*skip_nulls=*/false); + const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0); + this->AssertAnyIs("[]", false_value, options); + this->AssertAnyIs("[false]", false_value, options); + this->AssertAnyIs("[true, false]", true_value, options); + this->AssertAnyIs("[null, null, null]", false_value, options); + this->AssertAnyIs("[false, false, false]", false_value, options); + this->AssertAnyIs("[false, false, false, null]", false_value, options); + this->AssertAnyIs("[true, null, true, true]", true_value, options); + this->AssertAnyIs("[false, null, false, true]", true_value, options); + this->AssertAnyIs("[true, null, false, true]", true_value, options); + this->AssertAnyIs(chunked_input0, true_value, options); + this->AssertAnyIs(chunked_input1, true_value, options); + this->AssertAnyIs(chunked_input2, false_value, options); + this->AssertAnyIs(chunked_input3, false_value, options); + this->AssertAnyIs(chunked_input4, true_value, options); + + EXPECT_THAT(Any(Datum(true), options), ResultWith(Datum(true))); + EXPECT_THAT(Any(Datum(false), options), ResultWith(Datum(false))); + EXPECT_THAT(Any(Datum(null_value), options), ResultWith(Datum(false))); + + const ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0); this->AssertAnyIs("[]", false_value, keep_nulls); this->AssertAnyIs("[false]", false_value, keep_nulls); this->AssertAnyIs("[true, false]", true_value, keep_nulls); @@ -1314,27 +1351,48 @@ TEST_F(TestAnyKernel, Basics) { EXPECT_THAT(Any(Datum(true), keep_nulls), ResultWith(Datum(true))); EXPECT_THAT(Any(Datum(false), keep_nulls), ResultWith(Datum(false))); - EXPECT_THAT(Any(MakeNullScalar(boolean()), keep_nulls), - ResultWith(Datum(MakeNullScalar(boolean())))); + EXPECT_THAT(Any(Datum(null_value), keep_nulls), ResultWith(Datum(null_value))); + + const ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/2); + this->AssertAnyIs("[]", null_value, min_count); + this->AssertAnyIs("[false]", null_value, min_count); + this->AssertAnyIs("[true, false]", true_value, min_count); + this->AssertAnyIs("[null, null, null]", null_value, min_count); + this->AssertAnyIs("[false, false, false]", false_value, min_count); + this->AssertAnyIs("[false, false, false, null]", false_value, min_count); + this->AssertAnyIs("[true, null, true, true]", true_value, min_count); + this->AssertAnyIs("[false, null, false, true]", true_value, min_count); + this->AssertAnyIs("[true, null, false, true]", true_value, min_count); + this->AssertAnyIs(chunked_input0, null_value, min_count); + this->AssertAnyIs(chunked_input1, true_value, min_count); + this->AssertAnyIs(chunked_input2, false_value, min_count); + this->AssertAnyIs(chunked_input3, false_value, min_count); + this->AssertAnyIs(chunked_input4, true_value, min_count); + + EXPECT_THAT(Any(Datum(true), min_count), ResultWith(Datum(null_value))); + EXPECT_THAT(Any(Datum(false), min_count), ResultWith(Datum(null_value))); + EXPECT_THAT(Any(Datum(null_value), min_count), ResultWith(Datum(null_value))); } // // All // -class TestPrimitiveAllKernel : public ::testing::Test { +class TestAllKernel : public ::testing::Test { public: void AssertAllIs(const Datum& array, const std::shared_ptr& expected, const ScalarAggregateOptions& options) { + SCOPED_TRACE(options.ToString()); ASSERT_OK_AND_ASSIGN(Datum out, All(array, options, nullptr)); const BooleanScalar& out_all = out.scalar_as(); - ASSERT_EQ(out_all, *expected); + AssertScalarsEqual(*expected, out_all, /*verbose=*/true); } void AssertAllIs( const std::string& json, const std::shared_ptr& expected, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) { - auto array = ArrayFromJSON(type_singleton(), json); + SCOPED_TRACE(json); + auto array = ArrayFromJSON(boolean(), json); AssertAllIs(array, expected, options); } @@ -1342,17 +1400,11 @@ class TestPrimitiveAllKernel : public ::testing::Test { const std::vector& json, const std::shared_ptr& expected, const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) { - auto array = ChunkedArrayFromJSON(type_singleton(), json); + auto array = ChunkedArrayFromJSON(boolean(), json); AssertAllIs(array, expected, options); } - - std::shared_ptr type_singleton() { - return TypeTraits::type_singleton(); - } }; -class TestAllKernel : public TestPrimitiveAllKernel {}; - TEST_F(TestAllKernel, Basics) { auto true_value = std::make_shared(true); auto false_value = std::make_shared(false); @@ -1366,27 +1418,28 @@ TEST_F(TestAllKernel, Basics) { std::vector chunked_input4 = {"[true, null]", "[null, false]"}; std::vector chunked_input5 = {"[false, null]", "[null, true]"}; - this->AssertAllIs("[]", true_value); - this->AssertAllIs("[false]", false_value); - this->AssertAllIs("[true, false]", false_value); - this->AssertAllIs("[null, null, null]", true_value); - this->AssertAllIs("[false, false, false]", false_value); - this->AssertAllIs("[false, false, false, null]", false_value); - this->AssertAllIs("[true, null, true, true]", true_value); - this->AssertAllIs("[false, null, false, true]", false_value); - this->AssertAllIs("[true, null, false, true]", false_value); - this->AssertAllIs(chunked_input0, true_value); - this->AssertAllIs(chunked_input1, true_value); - this->AssertAllIs(chunked_input2, false_value); - this->AssertAllIs(chunked_input3, false_value); - this->AssertAllIs(chunked_input4, false_value); - this->AssertAllIs(chunked_input5, false_value); - - EXPECT_THAT(All(Datum(true)), ResultWith(Datum(true))); - EXPECT_THAT(All(Datum(false)), ResultWith(Datum(false))); - EXPECT_THAT(All(MakeNullScalar(boolean())), ResultWith(Datum(true))); - - const ScalarAggregateOptions keep_nulls = ScalarAggregateOptions(/*skip_nulls=*/false); + const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0); + this->AssertAllIs("[]", true_value, options); + this->AssertAllIs("[false]", false_value, options); + this->AssertAllIs("[true, false]", false_value, options); + this->AssertAllIs("[null, null, null]", true_value, options); + this->AssertAllIs("[false, false, false]", false_value, options); + this->AssertAllIs("[false, false, false, null]", false_value, options); + this->AssertAllIs("[true, null, true, true]", true_value, options); + this->AssertAllIs("[false, null, false, true]", false_value, options); + this->AssertAllIs("[true, null, false, true]", false_value, options); + this->AssertAllIs(chunked_input0, true_value, options); + this->AssertAllIs(chunked_input1, true_value, options); + this->AssertAllIs(chunked_input2, false_value, options); + this->AssertAllIs(chunked_input3, false_value, options); + this->AssertAllIs(chunked_input4, false_value, options); + this->AssertAllIs(chunked_input5, false_value, options); + + EXPECT_THAT(All(Datum(true), options), ResultWith(Datum(true))); + EXPECT_THAT(All(Datum(false), options), ResultWith(Datum(false))); + EXPECT_THAT(All(Datum(null_value), options), ResultWith(Datum(true))); + + const ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0); this->AssertAllIs("[]", true_value, keep_nulls); this->AssertAllIs("[false]", false_value, keep_nulls); this->AssertAllIs("[true, false]", false_value, keep_nulls); @@ -1405,8 +1458,28 @@ TEST_F(TestAllKernel, Basics) { EXPECT_THAT(All(Datum(true), keep_nulls), ResultWith(Datum(true))); EXPECT_THAT(All(Datum(false), keep_nulls), ResultWith(Datum(false))); - EXPECT_THAT(All(MakeNullScalar(boolean()), keep_nulls), - ResultWith(Datum(MakeNullScalar(boolean())))); + EXPECT_THAT(All(Datum(null_value), keep_nulls), ResultWith(Datum(null_value))); + + const ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/2); + this->AssertAllIs("[]", null_value, min_count); + this->AssertAllIs("[false]", null_value, min_count); + this->AssertAllIs("[true, false]", false_value, min_count); + this->AssertAllIs("[null, null, null]", null_value, min_count); + this->AssertAllIs("[false, false, false]", false_value, min_count); + this->AssertAllIs("[false, false, false, null]", false_value, min_count); + this->AssertAllIs("[true, null, true, true]", true_value, min_count); + this->AssertAllIs("[false, null, false, true]", false_value, min_count); + this->AssertAllIs("[true, null, false, true]", false_value, min_count); + this->AssertAllIs(chunked_input0, null_value, min_count); + this->AssertAllIs(chunked_input1, true_value, min_count); + this->AssertAllIs(chunked_input2, false_value, min_count); + this->AssertAllIs(chunked_input3, false_value, min_count); + this->AssertAllIs(chunked_input4, false_value, min_count); + this->AssertAllIs(chunked_input5, false_value, min_count); + + EXPECT_THAT(All(Datum(true), min_count), ResultWith(Datum(null_value))); + EXPECT_THAT(All(Datum(false), min_count), ResultWith(Datum(null_value))); + EXPECT_THAT(All(Datum(null_value), min_count), ResultWith(Datum(null_value))); } // diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 9222c5dd18f..e36ff749c89 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -906,18 +906,19 @@ struct GroupedCountImpl : public GroupedAggregator { }; // ---------------------------------------------------------------------- -// Sum implementation +// Sum/Mean/Product implementation -template -struct GroupedSumImpl : public GroupedAggregator { +template +struct GroupedReducingAggregator : public GroupedAggregator { using AccType = typename FindAccumulatorType::Type; - using SumType = typename TypeTraits::CType; + using c_type = typename TypeTraits::CType; Status Init(ExecContext* ctx, const FunctionOptions* options) override { pool_ = ctx->memory_pool(); options_ = checked_cast(*options); - sums_ = BufferBuilder(pool_); - counts_ = BufferBuilder(pool_); + reduced_ = TypedBufferBuilder(pool_); + counts_ = TypedBufferBuilder(pool_); + no_nulls_ = TypedBufferBuilder(pool_); out_type_ = TypeTraits::type_singleton(); return Status::OK(); } @@ -925,81 +926,136 @@ struct GroupedSumImpl : public GroupedAggregator { Status Resize(int64_t new_num_groups) override { auto added_groups = new_num_groups - num_groups_; num_groups_ = new_num_groups; - RETURN_NOT_OK(sums_.Append(added_groups * sizeof(AccType), 0)); - RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0)); + RETURN_NOT_OK(reduced_.Append(added_groups, Impl::NullValue())); + RETURN_NOT_OK(counts_.Append(added_groups, 0)); + RETURN_NOT_OK(no_nulls_.Append(added_groups, true)); return Status::OK(); } Status Consume(const ExecBatch& batch) override { - auto sums = reinterpret_cast(sums_.mutable_data()); - auto counts = reinterpret_cast(counts_.mutable_data()); + c_type* reduced = reduced_.mutable_data(); + int64_t* counts = counts_.mutable_data(); + uint8_t* no_nulls = no_nulls_.mutable_data(); - // XXX this uses naive summation; we should switch to pairwise summation as was - // done for the scalar aggregate kernel in ARROW-11758 auto g = batch[1].array()->GetValues(1); - VisitArrayDataInline( - *batch[0].array(), - [&](typename TypeTraits::CType value) { - sums[*g] += value; - counts[*g] += 1; - ++g; - }, - [&] { ++g; }); - return Status::OK(); + + return Impl::Consume(*batch[0].array(), reduced, counts, no_nulls, g); } Status Merge(GroupedAggregator&& raw_other, const ArrayData& group_id_mapping) override { - auto other = checked_cast(&raw_other); + auto other = checked_cast*>(&raw_other); - auto counts = reinterpret_cast(counts_.mutable_data()); - auto sums = reinterpret_cast(sums_.mutable_data()); + c_type* reduced = reduced_.mutable_data(); + int64_t* counts = counts_.mutable_data(); + uint8_t* no_nulls = no_nulls_.mutable_data(); - auto other_counts = reinterpret_cast(other->counts_.mutable_data()); - auto other_sums = reinterpret_cast(other->sums_.mutable_data()); + const c_type* other_reduced = other->reduced_.data(); + const int64_t* other_counts = other->counts_.data(); + const uint8_t* other_no_nulls = no_nulls_.mutable_data(); auto g = group_id_mapping.GetValues(1); for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { counts[*g] += other_counts[other_g]; - sums[*g] += other_sums[other_g]; + Impl::UpdateGroupWith(reduced, *g, other_reduced[other_g]); + BitUtil::SetBitTo( + no_nulls, *g, + BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g)); } return Status::OK(); } + // Generate the values/nulls buffers + static Result> Finish(MemoryPool* pool, + const ScalarAggregateOptions& options, + const int64_t* counts, + TypedBufferBuilder* reduced, + int64_t num_groups, int64_t* null_count, + std::shared_ptr* null_bitmap) { + for (int64_t i = 0; i < num_groups; ++i) { + if (counts[i] >= options.min_count) continue; + + if ((*null_bitmap) == nullptr) { + ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(num_groups, pool)); + BitUtil::SetBitsTo((*null_bitmap)->mutable_data(), 0, num_groups, true); + } + + (*null_count)++; + BitUtil::SetBitTo((*null_bitmap)->mutable_data(), i, false); + } + return reduced->Finish(); + } + Result Finalize() override { - std::shared_ptr null_bitmap; - const int64_t* counts = reinterpret_cast(counts_.data()); + std::shared_ptr null_bitmap = nullptr; + const int64_t* counts = counts_.data(); int64_t null_count = 0; - for (int64_t i = 0; i < num_groups_; ++i) { - if (counts[i] >= options_.min_count) continue; + ARROW_ASSIGN_OR_RAISE(auto values, + Impl::Finish(pool_, options_, counts, &reduced_, num_groups_, + &null_count, &null_bitmap)); - if (null_bitmap == nullptr) { - ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_)); - BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true); + if (!options_.skip_nulls) { + null_count = kUnknownNullCount; + if (null_bitmap) { + arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0, + no_nulls_.data(), /*right_offset=*/0, num_groups_, + /*out_offset=*/0, null_bitmap->mutable_data()); + } else { + ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish()); } - - null_count += 1; - BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false); } - ARROW_ASSIGN_OR_RAISE(auto sums, sums_.Finish()); - - return ArrayData::Make(std::move(out_type_), num_groups_, - {std::move(null_bitmap), std::move(sums)}, null_count); + return ArrayData::Make(out_type(), num_groups_, + {std::move(null_bitmap), std::move(values)}, null_count); } std::shared_ptr out_type() const override { return out_type_; } - // NB: counts are used here instead of a simple "has_values_" bitmap since - // we expect to reuse this kernel to handle Mean int64_t num_groups_ = 0; ScalarAggregateOptions options_; - BufferBuilder sums_, counts_; + TypedBufferBuilder reduced_; + TypedBufferBuilder counts_; + TypedBufferBuilder no_nulls_; std::shared_ptr out_type_; MemoryPool* pool_; }; +// ---------------------------------------------------------------------- +// Sum implementation + +template +struct GroupedSumImpl : public GroupedReducingAggregator> { + using Base = GroupedReducingAggregator>; + using c_type = typename Base::c_type; + + // Default value for a group + static c_type NullValue() { return c_type(0); } + + // Update all groups + static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts, + uint8_t* no_nulls, const uint32_t* g) { + // XXX this uses naive summation; we should switch to pairwise summation as was + // done for the scalar aggregate kernel in ARROW-11758 + VisitArrayDataInline( + values, + [&](typename TypeTraits::CType value) { + reduced[*g] = static_cast(to_unsigned(reduced[*g]) + + to_unsigned(static_cast(value))); + counts[*g++] += 1; + }, + [&] { BitUtil::SetBitTo(no_nulls, *g++, false); }); + return Status::OK(); + } + + // Update a single group during merge + static void UpdateGroupWith(c_type* reduced, uint32_t g, c_type value) { + reduced[g] += value; + } + + using Base::Finish; +}; + struct GroupedSumFactory { template ::Type> Status Visit(const T&) { @@ -1030,92 +1086,31 @@ struct GroupedSumFactory { // Product implementation template -struct GroupedProductImpl final : public GroupedAggregator { - using AccType = typename FindAccumulatorType::Type; - using ProductType = typename TypeTraits::CType; +struct GroupedProductImpl final + : public GroupedReducingAggregator> { + using Base = GroupedReducingAggregator>; + using c_type = typename Base::c_type; - Status Init(ExecContext* ctx, const FunctionOptions* options) override { - pool_ = ctx->memory_pool(); - options_ = checked_cast(*options); - products_ = TypedBufferBuilder(pool_); - counts_ = TypedBufferBuilder(pool_); - out_type_ = TypeTraits::type_singleton(); - return Status::OK(); - } + static c_type NullValue() { return c_type(1); } - Status Resize(int64_t new_num_groups) override { - auto added_groups = new_num_groups - num_groups_; - num_groups_ = new_num_groups; - RETURN_NOT_OK(products_.Append(added_groups * sizeof(AccType), 1)); - RETURN_NOT_OK(counts_.Append(added_groups, 0)); - return Status::OK(); - } - - Status Consume(const ExecBatch& batch) override { - ProductType* products = products_.mutable_data(); - int64_t* counts = counts_.mutable_data(); - auto g = batch[1].array()->GetValues(1); + static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts, + uint8_t* no_nulls, const uint32_t* g) { VisitArrayDataInline( - *batch[0].array(), + values, [&](typename TypeTraits::CType value) { - products[*g] = static_cast( - to_unsigned(products[*g]) * to_unsigned(static_cast(value))); + reduced[*g] = static_cast(to_unsigned(reduced[*g]) * + to_unsigned(static_cast(value))); counts[*g++] += 1; }, - [&] { ++g; }); + [&] { BitUtil::SetBitTo(no_nulls, *g++, false); }); return Status::OK(); } - Status Merge(GroupedAggregator&& raw_other, - const ArrayData& group_id_mapping) override { - auto other = checked_cast(&raw_other); - - int64_t* counts = counts_.mutable_data(); - ProductType* products = products_.mutable_data(); - - const int64_t* other_counts = other->counts_.mutable_data(); - const ProductType* other_products = other->products_.mutable_data(); - const uint32_t* g = group_id_mapping.GetValues(1); - - for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { - products[*g] = static_cast(to_unsigned(products[*g]) * - to_unsigned(other_products[other_g])); - counts[*g] += other_counts[other_g]; - } - return Status::OK(); + static void UpdateGroupWith(c_type* reduced, uint32_t g, c_type value) { + reduced[g] *= value; } - Result Finalize() override { - ARROW_ASSIGN_OR_RAISE(auto products, products_.Finish()); - const int64_t* counts = counts_.data(); - - std::shared_ptr null_bitmap; - int64_t null_count = 0; - - for (int64_t i = 0; i < num_groups_; ++i) { - if (counts[i] >= options_.min_count) continue; - - if (null_bitmap == nullptr) { - ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_)); - BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true); - } - - null_count += 1; - BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false); - } - - return ArrayData::Make(std::move(out_type_), num_groups_, - {std::move(null_bitmap), std::move(products)}, null_count); - } - - std::shared_ptr out_type() const override { return out_type_; } - - int64_t num_groups_ = 0; - ScalarAggregateOptions options_; - TypedBufferBuilder products_; - TypedBufferBuilder counts_; - std::shared_ptr out_type_; - MemoryPool* pool_; + using Base::Finish; }; struct GroupedProductFactory { @@ -1149,44 +1144,60 @@ struct GroupedProductFactory { // Mean implementation template -struct GroupedMeanImpl : public GroupedSumImpl { - Result Finalize() override { - using SumType = typename GroupedSumImpl::SumType; - std::shared_ptr null_bitmap; - ARROW_ASSIGN_OR_RAISE(std::shared_ptr values, - AllocateBuffer(num_groups_ * sizeof(double), pool_)); - int64_t null_count = 0; +struct GroupedMeanImpl : public GroupedReducingAggregator> { + using Base = GroupedReducingAggregator>; + using c_type = typename Base::c_type; - const int64_t* counts = reinterpret_cast(counts_.data()); - const auto* sums = reinterpret_cast(sums_.data()); + static c_type NullValue() { return c_type(0); } + + static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts, + uint8_t* no_nulls, const uint32_t* g) { + // XXX this uses naive summation; we should switch to pairwise summation as was + // done for the scalar aggregate kernel in ARROW-11758 + VisitArrayDataInline( + values, + [&](typename TypeTraits::CType value) { + reduced[*g] = static_cast(to_unsigned(reduced[*g]) + + to_unsigned(static_cast(value))); + counts[*g++] += 1; + }, + [&] { BitUtil::SetBitTo(no_nulls, *g++, false); }); + return Status::OK(); + } + + static void UpdateGroupWith(c_type* reduced, uint32_t g, c_type value) { + reduced[g] += value; + } + + static Result> Finish(MemoryPool* pool, + const ScalarAggregateOptions& options, + const int64_t* counts, + TypedBufferBuilder* reduced_, + int64_t num_groups, int64_t* null_count, + std::shared_ptr* null_bitmap) { + const c_type* reduced = reduced_->data(); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr values, + AllocateBuffer(num_groups * sizeof(double), pool)); double* means = reinterpret_cast(values->mutable_data()); - for (int64_t i = 0; i < num_groups_; ++i) { - if (counts[i] >= options_.min_count) { - means[i] = static_cast(sums[i]) / counts[i]; + for (int64_t i = 0; i < num_groups; ++i) { + if (counts[i] >= options.min_count) { + means[i] = static_cast(reduced[i]) / counts[i]; continue; } means[i] = 0; - if (null_bitmap == nullptr) { - ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_)); - BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true); + if ((*null_bitmap) == nullptr) { + ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(num_groups, pool)); + BitUtil::SetBitsTo((*null_bitmap)->mutable_data(), 0, num_groups, true); } - null_count += 1; - BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false); + (*null_count)++; + BitUtil::SetBitTo((*null_bitmap)->mutable_data(), i, false); } - - return ArrayData::Make(float64(), num_groups_, - {std::move(null_bitmap), std::move(values)}, null_count); + return std::move(values); } std::shared_ptr out_type() const override { return float64(); } - - using GroupedSumImpl::num_groups_; - using GroupedSumImpl::options_; - using GroupedSumImpl::pool_; - using GroupedSumImpl::counts_; - using GroupedSumImpl::sums_; }; struct GroupedMeanFactory { @@ -1739,167 +1750,158 @@ struct GroupedMinMaxFactory { // ---------------------------------------------------------------------- // Any/All implementation -struct GroupedAnyImpl : public GroupedAggregator { +template +struct GroupedBooleanAggregator : public GroupedAggregator { Status Init(ExecContext* ctx, const FunctionOptions* options) override { - options_ = *checked_cast(options); - seen_ = TypedBufferBuilder(ctx->memory_pool()); - has_nulls_ = TypedBufferBuilder(ctx->memory_pool()); + options_ = checked_cast(*options); + pool_ = ctx->memory_pool(); + seen_ = TypedBufferBuilder(pool_); + no_nulls_ = TypedBufferBuilder(pool_); + counts_ = TypedBufferBuilder(pool_); return Status::OK(); } Status Resize(int64_t new_num_groups) override { auto added_groups = new_num_groups - num_groups_; num_groups_ = new_num_groups; - RETURN_NOT_OK(seen_.Append(added_groups, false)); - return has_nulls_.Append(added_groups, false); + RETURN_NOT_OK(seen_.Append(added_groups, Impl::NullValue())); + RETURN_NOT_OK(no_nulls_.Append(added_groups, true)); + return counts_.Append(added_groups, 0); + } + + Status Consume(const ExecBatch& batch) override { + uint8_t* seen = seen_.mutable_data(); + uint8_t* no_nulls = no_nulls_.mutable_data(); + int64_t* counts = counts_.mutable_data(); + const auto& input = *batch[0].array(); + auto g = batch[1].array()->GetValues(1); + + if (input.MayHaveNulls()) { + const uint8_t* bitmap = input.buffers[1]->data(); + arrow::internal::VisitBitBlocksVoid( + input.buffers[0], input.offset, input.length, + [&](int64_t position) { + counts[*g]++; + Impl::UpdateGroupWith(seen, *g, BitUtil::GetBit(bitmap, position)); + g++; + }, + [&] { BitUtil::SetBitTo(no_nulls, *g++, false); }); + } else { + arrow::internal::VisitBitBlocksVoid( + input.buffers[1], input.offset, input.length, + [&](int64_t) { + Impl::UpdateGroupWith(seen, *g, true); + counts[*g++]++; + }, + [&]() { + Impl::UpdateGroupWith(seen, *g, false); + counts[*g++]++; + }); + } + return Status::OK(); } Status Merge(GroupedAggregator&& raw_other, const ArrayData& group_id_mapping) override { - auto other = checked_cast(&raw_other); + auto other = checked_cast*>(&raw_other); + + uint8_t* seen = seen_.mutable_data(); + uint8_t* no_nulls = no_nulls_.mutable_data(); + int64_t* counts = counts_.mutable_data(); - auto seen = seen_.mutable_data(); - auto other_seen = other->seen_.data(); - auto has_nulls = has_nulls_.mutable_data(); - auto other_has_nulls = other->has_nulls_.data(); + const uint8_t* other_seen = other->seen_.mutable_data(); + const uint8_t* other_no_nulls = other->no_nulls_.mutable_data(); + const int64_t* other_counts = other->counts_.mutable_data(); auto g = group_id_mapping.GetValues(1); for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { - if (BitUtil::GetBit(other_seen, other_g)) BitUtil::SetBitTo(seen, *g, true); - if (BitUtil::GetBit(other_has_nulls, other_g)) { - BitUtil::SetBitTo(has_nulls, *g, true); - } + counts[*g] += other_counts[other_g]; + Impl::UpdateGroupWith(seen, *g, BitUtil::GetBit(other_seen, other_g)); + BitUtil::SetBitTo( + no_nulls, *g, + BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g)); } return Status::OK(); } - Status Consume(const ExecBatch& batch) override { - auto seen = seen_.mutable_data(); - auto has_nulls = has_nulls_.mutable_data(); + Result Finalize() override { + std::shared_ptr null_bitmap; + const int64_t* counts = counts_.data(); + int64_t null_count = 0; - const auto& input = *batch[0].array(); + for (int64_t i = 0; i < num_groups_; ++i) { + if (counts[i] >= options_.min_count) continue; - auto g = batch[1].array()->GetValues(1); - auto values = input.buffers[1]->data(); - arrow::internal::VisitBitBlocksVoid( - input.buffers[0], input.offset, input.length, - [&](int64_t offset) { - BitUtil::SetBitTo(seen, *g, - BitUtil::GetBit(seen, *g) || - BitUtil::GetBit(values, input.offset + offset)); - g++; - }, - [&]() { BitUtil::SetBitTo(has_nulls, *g++, true); }); - return Status::OK(); - } + if (null_bitmap == nullptr) { + ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_)); + BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true); + } + + null_count += 1; + BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false); + } - Result Finalize() override { ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish()); - if (options_.skip_nulls) { - return std::make_shared(num_groups_, std::move(seen)); + if (!options_.skip_nulls) { + null_count = kUnknownNullCount; + ARROW_ASSIGN_OR_RAISE(auto no_nulls, no_nulls_.Finish()); + Impl::AdjustForMinCount(no_nulls->mutable_data(), seen->data(), num_groups_); + if (null_bitmap) { + arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0, + no_nulls->data(), /*right_offset=*/0, num_groups_, + /*out_offset=*/0, null_bitmap->mutable_data()); + } else { + null_bitmap = std::move(no_nulls); + } } - ARROW_ASSIGN_OR_RAISE(auto bitmap, has_nulls_.Finish()); - // null if (~seen & has_nulls) -> not null if (seen | ~has_nulls) - ::arrow::internal::BitmapOrNot(seen->data(), /*left_offset=*/0, bitmap->data(), - /*right_offset=*/0, num_groups_, /*out_offset=*/0, - bitmap->mutable_data()); - return std::make_shared(num_groups_, std::move(seen), - std::move(bitmap)); + + return ArrayData::Make(out_type(), num_groups_, + {std::move(null_bitmap), std::move(seen)}, null_count); } std::shared_ptr out_type() const override { return boolean(); } int64_t num_groups_ = 0; ScalarAggregateOptions options_; - TypedBufferBuilder seen_; - TypedBufferBuilder has_nulls_; + TypedBufferBuilder seen_, no_nulls_; + TypedBufferBuilder counts_; + MemoryPool* pool_; }; -struct GroupedAllImpl : public GroupedAggregator { - Status Init(ExecContext* ctx, const FunctionOptions* options) override { - options_ = *checked_cast(options); - seen_ = TypedBufferBuilder(ctx->memory_pool()); - has_nulls_ = TypedBufferBuilder(ctx->memory_pool()); - return Status::OK(); - } - - Status Resize(int64_t new_num_groups) override { - auto added_groups = new_num_groups - num_groups_; - num_groups_ = new_num_groups; - RETURN_NOT_OK(seen_.Append(added_groups, true)); - return has_nulls_.Append(added_groups, false); - } +struct GroupedAnyImpl : public GroupedBooleanAggregator { + // The default value for a group. + static bool NullValue() { return false; } - Status Merge(GroupedAggregator&& raw_other, - const ArrayData& group_id_mapping) override { - auto other = checked_cast(&raw_other); - - auto seen = seen_.mutable_data(); - auto other_seen = other->seen_.data(); - auto has_nulls = has_nulls_.mutable_data(); - auto other_has_nulls = other->has_nulls_.data(); - - auto g = group_id_mapping.GetValues(1); - for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { - BitUtil::SetBitTo( - seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g)); - if (BitUtil::GetBit(other_has_nulls, other_g)) { - BitUtil::SetBitTo(has_nulls, *g, true); - } + // Update the value for a group given an observation. + static void UpdateGroupWith(uint8_t* seen, uint32_t g, bool value) { + if (!BitUtil::GetBit(seen, g) && value) { + BitUtil::SetBit(seen, g); } - return Status::OK(); } - Status Consume(const ExecBatch& batch) override { - auto seen = seen_.mutable_data(); - auto has_nulls = has_nulls_.mutable_data(); + // Combine the array of observed nulls with the array of group values. + static void AdjustForMinCount(uint8_t* no_nulls, const uint8_t* seen, + int64_t num_groups) { + arrow::internal::BitmapOr(no_nulls, /*left_offset=*/0, seen, /*right_offset=*/0, + num_groups, /*out_offset=*/0, no_nulls); + } +}; - const auto& input = *batch[0].array(); +struct GroupedAllImpl : public GroupedBooleanAggregator { + static bool NullValue() { return true; } - auto g = batch[1].array()->GetValues(1); - if (input.MayHaveNulls()) { - const uint8_t* bitmap = input.buffers[1]->data(); - arrow::internal::VisitBitBlocksVoid( - input.buffers[0], input.offset, input.length, - [&](int64_t position) { - BitUtil::SetBitTo(seen, *g, - BitUtil::GetBit(seen, *g) && - BitUtil::GetBit(bitmap, input.offset + position)); - g++; - }, - [&]() { BitUtil::SetBitTo(has_nulls, *g++, true); }); - } else { - arrow::internal::VisitBitBlocksVoid( - input.buffers[1], input.offset, input.length, [&](int64_t) { g++; }, - [&]() { BitUtil::SetBitTo(seen, *g++, false); }); + static void UpdateGroupWith(uint8_t* seen, uint32_t g, bool value) { + if (!value) { + BitUtil::ClearBit(seen, g); } - return Status::OK(); } - Result Finalize() override { - ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish()); - if (options_.skip_nulls) { - return std::make_shared(num_groups_, std::move(seen)); - } - ARROW_ASSIGN_OR_RAISE(auto bitmap, has_nulls_.Finish()); - // null if (seen & has_nulls) - ::arrow::internal::BitmapAnd(seen->data(), /*left_offset=*/0, bitmap->data(), - /*right_offset=*/0, num_groups_, /*out_offset=*/0, - bitmap->mutable_data()); - ::arrow::internal::InvertBitmap(bitmap->data(), /*offset=*/0, num_groups_, - bitmap->mutable_data(), /*dest_offset=*/0); - return std::make_shared(num_groups_, std::move(seen), - std::move(bitmap)); + static void AdjustForMinCount(uint8_t* no_nulls, const uint8_t* seen, + int64_t num_groups) { + arrow::internal::BitmapOrNot(no_nulls, /*left_offset=*/0, seen, /*right_offset=*/0, + num_groups, /*out_offset=*/0, no_nulls); } - - std::shared_ptr out_type() const override { return boolean(); } - - int64_t num_groups_ = 0; - ScalarAggregateOptions options_; - TypedBufferBuilder seen_; - TypedBufferBuilder has_nulls_; }; - } // namespace Result> GetKernels( @@ -2176,7 +2178,8 @@ const FunctionDoc hash_product_doc{ "Compute product of values of a numeric array", ("Null values are ignored.\n" "Overflow will wrap around as if the calculation was done with unsigned integers."), - {"array", "group_id_array"}}; + {"array", "group_id_array"}, + "ScalarAggregateOptions"}; const FunctionDoc hash_mean_doc{"Average values of a numeric array", ("Null values are ignored."), @@ -2216,11 +2219,13 @@ const FunctionDoc hash_min_max_doc{ const FunctionDoc hash_any_doc{"Test whether any element evaluates to true", ("Null values are ignored."), - {"array", "group_id_array"}}; + {"array", "group_id_array"}, + "ScalarAggregateOptions"}; const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true", ("Null values are ignored."), - {"array", "group_id_array"}}; + {"array", "group_id_array"}, + "ScalarAggregateOptions"}; } // namespace void RegisterHashAggregateBasic(FunctionRegistry* registry) { diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index e96fdcd6084..21440248493 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -1108,6 +1108,10 @@ TEST(GroupBy, AnyAndAll) { [null, 3] ])"}); + ScalarAggregateOptions no_min(/*skip_nulls=*/true, /*min_count=*/0); + ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3); + ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0); + ScalarAggregateOptions keep_nulls_min_count(/*skip_nulls=*/false, /*min_count=*/3); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, internal::GroupBy( { @@ -1115,13 +1119,21 @@ TEST(GroupBy, AnyAndAll) { table->GetColumnByName("argument"), table->GetColumnByName("argument"), table->GetColumnByName("argument"), + table->GetColumnByName("argument"), + table->GetColumnByName("argument"), + table->GetColumnByName("argument"), + table->GetColumnByName("argument"), }, {table->GetColumnByName("key")}, { - {"hash_any", nullptr}, - {"hash_all", nullptr}, - {"hash_any", &options}, - {"hash_all", &options}, + {"hash_any", &no_min}, + {"hash_any", &min_count}, + {"hash_any", &keep_nulls}, + {"hash_any", &keep_nulls_min_count}, + {"hash_all", &no_min}, + {"hash_all", &min_count}, + {"hash_all", &keep_nulls}, + {"hash_all", &keep_nulls_min_count}, }, use_threads)); SortBy({"key_0"}, &aggregated_and_grouped); @@ -1134,18 +1146,22 @@ TEST(GroupBy, AnyAndAll) { // Group null: falses AssertDatumsEqual(ArrayFromJSON(struct_({ field("hash_any", boolean()), - field("hash_all", boolean()), field("hash_any", boolean()), + field("hash_any", boolean()), + field("hash_any", boolean()), + field("hash_all", boolean()), + field("hash_all", boolean()), + field("hash_all", boolean()), field("hash_all", boolean()), field("key_0", int64()), }), R"([ - [true, true, true, null, 1], - [true, false, true, false, 2], - [false, true, null, null, 3], - [false, false, null, false, 4], - [true, true, true, true, 5], - [false, false, false, false, null] + [true, null, true, null, true, null, null, null, 1], + [true, true, true, true, false, false, false, false, 2], + [false, null, null, null, true, null, null, null, 3], + [false, null, null, null, false, null, false, null, 4], + [true, null, true, null, true, null, true, null, 5], + [false, null, false, null, false, null, false, null, null] ])"), aggregated_and_grouped, /*verbose=*/true); @@ -1290,6 +1306,64 @@ TEST(GroupBy, Product) { /*verbose=*/true); } +TEST(GroupBy, SumMeanProductKeepNulls) { + auto batch = RecordBatchFromJSON( + schema({field("argument", float64()), field("key", int64())}), R"([ + [-1.0, 1], + [null, 1], + [0.0, 2], + [null, 3], + [4.0, null], + [3.25, 1], + [0.125, 2], + [-0.25, 2], + [0.75, null], + [null, 3] + ])"); + + ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false); + ScalarAggregateOptions min_count(/*skip_nulls=*/false, /*min_count=*/3); + ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, + internal::GroupBy( + { + batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), + }, + { + batch->GetColumnByName("key"), + }, + { + {"hash_sum", &keep_nulls}, + {"hash_sum", &min_count}, + {"hash_mean", &keep_nulls}, + {"hash_mean", &min_count}, + {"hash_product", &keep_nulls}, + {"hash_product", &min_count}, + })); + + AssertDatumsApproxEqual(ArrayFromJSON(struct_({ + field("hash_sum", float64()), + field("hash_sum", float64()), + field("hash_mean", float64()), + field("hash_mean", float64()), + field("hash_product", float64()), + field("hash_product", float64()), + field("key_0", int64()), + }), + R"([ + [null, null, null, null, null, null, 1], + [-0.125, -0.125, -0.0416667, -0.0416667, 0.0, 0.0, 2], + [null, null, null, null, null, null, 3], + [4.75, null, 2.375, null, 3.0, null, null] + ])"), + aggregated_and_grouped, + /*verbose=*/true); +} + TEST(GroupBy, SumOnlyStringAndDictKeys) { for (auto key_type : {utf8(), dictionary(int32(), utf8())}) { SCOPED_TRACE("key type: " + key_type->ToString()); diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 6c1684b87dc..cac7a3c95e5 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -551,7 +551,13 @@ def test_min_max(): def test_any(): # ARROW-1846 - options = pc.ScalarAggregateOptions(skip_nulls=False) + options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0) + + a = pa.array([], type='bool') + assert pc.any(a).as_py() is None + assert pc.any(a, min_count=0).as_py() is False + assert pc.any(a, options=options).as_py() is False + a = pa.array([False, None, True]) assert pc.any(a).as_py() is True assert pc.any(a, options=options).as_py() is True @@ -564,9 +570,11 @@ def test_any(): def test_all(): # ARROW-10301 - options = pc.ScalarAggregateOptions(skip_nulls=False) + options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0) + a = pa.array([], type='bool') - assert pc.all(a).as_py() is True + assert pc.all(a).as_py() is None + assert pc.all(a, min_count=0).as_py() is True assert pc.all(a, options=options).as_py() is True a = pa.array([False, True]) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index a04c0efcdae..4e1e4dcb859 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -300,8 +300,8 @@ ExecNode_Project <- function(input, exprs, names) { .Call(`_arrow_ExecNode_Project`, input, exprs, names) } -ExecNode_Aggregate <- function(input, options, target_names, out_field_names, key_names){ - .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names, out_field_names, key_names) +ExecNode_Aggregate <- function(input, options, target_names, out_field_names, key_names) { + .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names, out_field_names, key_names) } RecordBatch__cast <- function(batch, schema, options) { @@ -316,10 +316,6 @@ compute__CallFunction <- function(func_name, args, options) { .Call(`_arrow_compute__CallFunction`, func_name, args, options) } -compute__GroupBy <- function(arguments, keys, options) { - .Call(`_arrow_compute__GroupBy`, arguments, keys, options) -} - compute__GetFunctionNames <- function() { .Call(`_arrow_compute__GetFunctionNames`) } diff --git a/r/R/compute.R b/r/R/compute.R index 8cfaaf7b415..3953eaa5dfc 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -122,12 +122,6 @@ max.ArrowDatum <- function(..., na.rm = FALSE) { scalar_aggregate <- function(FUN, ..., na.rm = FALSE, na.min_count = 0) { a <- collect_arrays_from_dots(list(...)) - if (!na.rm) { - # When not removing null values, we require all values to be not null and - # return null otherwise. We do that by setting minimum count of non-null - # option values to the full array length. - na.min_count <- length(a) - } if (FUN == "min_max" && na.rm && a$null_count == length(a)) { Array$create(data.frame(min = Inf, max = -Inf)) # If na.rm == TRUE and all values in array are NA, R returns diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 9972d4796a8..efba9f287f9 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -791,28 +791,20 @@ agg_funcs$sum <- function(x, na.rm = FALSE) { list( fun = "sum", data = x, - options = arrow_na_rm(na.rm = na.rm) + options = list(na.rm = na.rm, na.min_count = 0L) ) } agg_funcs$any <- function(x, na.rm = FALSE) { list( fun = "any", data = x, - options = arrow_na_rm(na.rm) + options = list(na.rm = na.rm, na.min_count = 0L) ) } agg_funcs$all <- function(x, na.rm = FALSE) { list( fun = "all", data = x, - options = arrow_na_rm(na.rm) + options = list(na.rm = na.rm, na.min_count = 0L) ) } - -arrow_na_rm <- function(na.rm) { - if (!isTRUE(na.rm)) { - # TODO: ARROW-13497 - arrow_not_supported(paste("na.rm =", na.rm)) - } - list(na.rm = na.rm, na.min_count = 0L) -} diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 0645a64eac9..5ef39215c73 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1244,23 +1244,6 @@ extern "C" SEXP _arrow_compute__CallFunction(SEXP func_name_sexp, SEXP args_sexp } #endif -// compute.cpp -#if defined(ARROW_R_WITH_ARROW) -SEXP compute__GroupBy(cpp11::list arguments, cpp11::list keys, cpp11::list options); -extern "C" SEXP _arrow_compute__GroupBy(SEXP arguments_sexp, SEXP keys_sexp, SEXP options_sexp){ -BEGIN_CPP11 - arrow::r::Input::type arguments(arguments_sexp); - arrow::r::Input::type keys(keys_sexp); - arrow::r::Input::type options(options_sexp); - return cpp11::as_sexp(compute__GroupBy(arguments, keys, options)); -END_CPP11 -} -#else -extern "C" SEXP _arrow_compute__GroupBy(SEXP arguments_sexp, SEXP keys_sexp, SEXP options_sexp){ - Rf_error("Cannot call compute__GroupBy(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); -} -#endif - // compute.cpp #if defined(ARROW_R_WITH_ARROW) std::vector compute__GetFunctionNames(); @@ -7136,7 +7119,6 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_RecordBatch__cast", (DL_FUNC) &_arrow_RecordBatch__cast, 3}, { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, { "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3}, - { "_arrow_compute__GroupBy", (DL_FUNC) &_arrow_compute__GroupBy, 3}, { "_arrow_compute__GetFunctionNames", (DL_FUNC) &_arrow_compute__GetFunctionNames, 0}, { "_arrow_build_info", (DL_FUNC) &_arrow_build_info, 0}, { "_arrow_runtime_info", (DL_FUNC) &_arrow_runtime_info, 0}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 0695e2525f7..b697ecd96a0 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -172,7 +172,9 @@ std::shared_ptr make_compute_options( } if (func_name == "min_max" || func_name == "sum" || func_name == "mean" || - func_name == "any" || func_name == "all") { + func_name == "any" || func_name == "all" || func_name == "hash_min_max" || + func_name == "hash_sum" || func_name == "hash_mean" || func_name == "hash_any" || + func_name == "hash_all") { using Options = arrow::compute::ScalarAggregateOptions; auto out = std::make_shared(Options::Defaults()); out->min_count = cpp11::as_cpp(options["na.min_count"]); @@ -390,29 +392,6 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list return from_datum(std::move(out)); } -// [[arrow::export]] -SEXP compute__GroupBy(cpp11::list arguments, cpp11::list keys, cpp11::list options) { - // options is a list of pairs: string function name, list of options - - std::vector> keep_alives; - std::vector aggregates; - - for (cpp11::list name_opts : options) { - auto name = cpp11::as_cpp(name_opts[0]); - auto opts = make_compute_options(name, name_opts[1]); - - aggregates.push_back( - arrow::compute::internal::Aggregate{std::move(name), opts.get()}); - keep_alives.push_back(std::move(opts)); - } - - auto datum_arguments = arrow::r::from_r_list(arguments); - auto datum_keys = arrow::r::from_r_list(keys); - auto out = ValueOrStop(arrow::compute::internal::GroupBy(datum_arguments, datum_keys, - aggregates, gc_context())); - return from_datum(std::move(out)); -} - // [[arrow::export]] std::vector compute__GetFunctionNames() { return arrow::compute::GetFunctionRegistry()->GetFunctionNames(); diff --git a/r/tests/testthat/test-dplyr-aggregate.R b/r/tests/testthat/test-dplyr-aggregate.R index 8235ef29948..ef2929ee65b 100644 --- a/r/tests/testthat/test-dplyr-aggregate.R +++ b/r/tests/testthat/test-dplyr-aggregate.R @@ -59,9 +59,7 @@ test_that("Can aggregate in Arrow", { input %>% summarize(total = sum(int)) %>% collect(), - tbl, - # ARROW-13497: This is failing because the default is na.rm = FALSE - warning = TRUE + tbl ) }) @@ -90,9 +88,7 @@ test_that("Group by sum on dataset", { summarize(total = sum(int)) %>% arrange(some_grouping) %>% collect(), - tbl, - # ARROW-13497: This is failing because the default is na.rm = FALSE - warning = TRUE + tbl ) }) @@ -115,7 +111,22 @@ test_that("Group by any/all", { collect(), tbl ) - # ARROW-13497: na.rm option also is not being passed/received to any/all + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize(any(lgl, na.rm = FALSE)) %>% + arrange(some_grouping) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize(all(lgl, na.rm = FALSE)) %>% + arrange(some_grouping) %>% + collect(), + tbl + ) expect_dplyr_equal( input %>%