From b40629443433f387e89495bed1b537ca175a0c59 Mon Sep 17 00:00:00 2001
From: David Li
Date: Sat, 14 Aug 2021 16:00:13 -0400
Subject: [PATCH 1/6] ARROW-13627: [C++] Support skip_nulls in sum/mean/product
---
cpp/src/arrow/compute/api_aggregate.h | 5 +-
.../arrow/compute/kernels/aggregate_basic.cc | 31 +-
.../kernels/aggregate_basic_internal.h | 16 +-
.../arrow/compute/kernels/aggregate_test.cc | 284 +++++++-------
.../arrow/compute/kernels/hash_aggregate.cc | 346 ++++++++++++------
.../compute/kernels/hash_aggregate_test.cc | 96 ++++-
python/pyarrow/tests/test_compute.py | 14 +-
r/R/arrowExports.R | 8 +-
r/R/compute.R | 6 -
r/R/dplyr-functions.R | 14 +-
r/src/arrowExports.cpp | 18 -
r/src/compute.cpp | 27 +-
r/tests/testthat/test-dplyr-aggregate.R | 25 +-
13 files changed, 563 insertions(+), 327 deletions(-)
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index 880424e97f8..d8cda022de8 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -42,14 +42,17 @@ class ExecContext;
/// \brief Control general scalar aggregate kernel behavior
///
-/// By default, null values are ignored
+/// By default, null values are ignored (skip_nulls = true).
class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions {
public:
explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1);
constexpr static char const kTypeName[] = "ScalarAggregateOptions";
static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; }
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
uint32_t min_count;
};
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 548a008c5ce..a3fafaae75d 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -162,6 +162,13 @@ struct ProductImpl : public ScalarAggregator {
if (batch[0].is_array()) {
const auto& data = batch[0].array();
this->count += data->length - data->GetNullCount();
+ this->nulls_observed = this->nulls_observed || data->GetNullCount();
+
+ if (!options.skip_nulls && this->nulls_observed) {
+ // Short-circuit
+ return Status::OK();
+ }
+
VisitArrayDataInline(
*data,
[&](typename TypeTraits::CType value) {
@@ -172,6 +179,7 @@ struct ProductImpl : public ScalarAggregator {
} else {
const auto& data = *batch[0].scalar();
this->count += data.is_valid * batch.length;
+ this->nulls_observed = this->nulls_observed || !data.is_valid;
if (data.is_valid) {
for (int64_t i = 0; i < batch.length; i++) {
auto value = internal::UnboxScalar::Unbox(data);
@@ -188,11 +196,13 @@ struct ProductImpl : public ScalarAggregator {
this->count += other.count;
this->product =
static_cast(to_unsigned(this->product) * to_unsigned(other.product));
+ this->nulls_observed = this->nulls_observed || other.nulls_observed;
return Status::OK();
}
Status Finalize(KernelContext*, Datum* out) override {
- if (this->count < options.min_count) {
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count)) {
out->value = std::make_shared();
} else {
out->value = MakeScalar(this->product);
@@ -201,6 +211,7 @@ struct ProductImpl : public ScalarAggregator {
}
size_t count = 0;
+ bool nulls_observed = false;
typename AccType::c_type product = 1;
ScalarAggregateOptions options;
};
@@ -268,17 +279,19 @@ struct BooleanAnyImpl : public ScalarAggregator {
Status Consume(KernelContext*, const ExecBatch& batch) override {
// short-circuit if seen a True already
- if (this->any == true) {
+ if (this->any == true && this->count >= options.min_count) {
return Status::OK();
}
if (batch[0].is_scalar()) {
const auto& scalar = *batch[0].scalar();
this->has_nulls = !scalar.is_valid;
this->any = scalar.is_valid && checked_cast(scalar).value;
+ this->count += scalar.is_valid;
return Status::OK();
}
const auto& data = *batch[0].array();
this->has_nulls = data.GetNullCount() > 0;
+ this->count += data.length - data.GetNullCount();
arrow::internal::OptionalBinaryBitBlockCounter counter(
data.buffers[0], data.offset, data.buffers[1], data.offset, data.length);
int64_t position = 0;
@@ -297,11 +310,13 @@ struct BooleanAnyImpl : public ScalarAggregator {
const auto& other = checked_cast(src);
this->any |= other.any;
this->has_nulls |= other.has_nulls;
+ this->count += other.count;
return Status::OK();
}
Status Finalize(KernelContext* ctx, Datum* out) override {
- if (!options.skip_nulls && !this->any && this->has_nulls) {
+ if ((!options.skip_nulls && !this->any && this->has_nulls) ||
+ this->count < options.min_count) {
out->value = std::make_shared();
} else {
out->value = std::make_shared(this->any);
@@ -311,6 +326,7 @@ struct BooleanAnyImpl : public ScalarAggregator {
bool any = false;
bool has_nulls = false;
+ int64_t count = 0;
ScalarAggregateOptions options;
};
@@ -329,7 +345,7 @@ struct BooleanAllImpl : public ScalarAggregator {
Status Consume(KernelContext*, const ExecBatch& batch) override {
// short-circuit if seen a false already
- if (this->all == false) {
+ if (this->all == false && this->count >= options.min_count) {
return Status::OK();
}
// short-circuit if seen a null already
@@ -339,11 +355,13 @@ struct BooleanAllImpl : public ScalarAggregator {
if (batch[0].is_scalar()) {
const auto& scalar = *batch[0].scalar();
this->has_nulls = !scalar.is_valid;
+ this->count += scalar.is_valid;
this->all = !scalar.is_valid || checked_cast(scalar).value;
return Status::OK();
}
const auto& data = *batch[0].array();
this->has_nulls = data.GetNullCount() > 0;
+ this->count += data.length - data.GetNullCount();
arrow::internal::OptionalBinaryBitBlockCounter counter(
data.buffers[1], data.offset, data.buffers[0], data.offset, data.length);
int64_t position = 0;
@@ -363,11 +381,13 @@ struct BooleanAllImpl : public ScalarAggregator {
const auto& other = checked_cast(src);
this->all &= other.all;
this->has_nulls |= other.has_nulls;
+ this->count += other.count;
return Status::OK();
}
Status Finalize(KernelContext*, Datum* out) override {
- if (!options.skip_nulls && this->all && this->has_nulls) {
+ if ((!options.skip_nulls && this->all && this->has_nulls) ||
+ this->count < options.min_count) {
out->value = std::make_shared();
} else {
out->value = std::make_shared(this->all);
@@ -377,6 +397,7 @@ struct BooleanAllImpl : public ScalarAggregator {
bool all = true;
bool has_nulls = false;
+ int64_t count = 0;
ScalarAggregateOptions options;
};
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
index eb314011229..b355a2e1b75 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
@@ -68,6 +68,13 @@ struct SumImpl : public ScalarAggregator {
if (batch[0].is_array()) {
const auto& data = batch[0].array();
this->count += data->length - data->GetNullCount();
+ this->nulls_observed = this->nulls_observed || data->GetNullCount();
+
+ if (!options.skip_nulls && this->nulls_observed) {
+ // Short-circuit
+ return Status::OK();
+ }
+
if (is_boolean_type::value) {
this->sum +=
static_cast(BooleanArray(data).true_count());
@@ -79,6 +86,7 @@ struct SumImpl : public ScalarAggregator {
} else {
const auto& data = *batch[0].scalar();
this->count += data.is_valid * batch.length;
+ this->nulls_observed = this->nulls_observed || !data.is_valid;
if (data.is_valid) {
this->sum += internal::UnboxScalar::Unbox(data) * batch.length;
}
@@ -90,11 +98,13 @@ struct SumImpl : public ScalarAggregator {
const auto& other = checked_cast(src);
this->count += other.count;
this->sum += other.sum;
+ this->nulls_observed = this->nulls_observed || other.nulls_observed;
return Status::OK();
}
Status Finalize(KernelContext*, Datum* out) override {
- if (this->count < options.min_count) {
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count)) {
out->value = std::make_shared();
} else {
out->value = MakeScalar(this->sum);
@@ -103,6 +113,7 @@ struct SumImpl : public ScalarAggregator {
}
size_t count = 0;
+ bool nulls_observed = false;
typename SumType::c_type sum = 0;
ScalarAggregateOptions options;
};
@@ -110,7 +121,8 @@ struct SumImpl : public ScalarAggregator {
template
struct MeanImpl : public SumImpl {
Status Finalize(KernelContext*, Datum* out) override {
- if (this->count < options.min_count) {
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count)) {
out->value = std::make_shared();
} else {
const double mean = static_cast(this->sum) / this->count;
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index aca693601c3..34346d6ca5c 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -139,14 +139,12 @@ template
void ValidateBooleanAgg(const std::string& json,
const std::shared_ptr& expected,
const ScalarAggregateOptions& options) {
+ SCOPED_TRACE(json);
auto array = ArrayFromJSON(boolean(), json);
ASSERT_OK_AND_ASSIGN(Datum result, Op(array, options, nullptr));
- const auto& exp = Datum(expected);
- const auto& res = checked_pointer_cast(result.scalar());
- if (!(std::isnan((double)res->value) && std::isnan((double)expected->value))) {
- ASSERT_TRUE(result.Equals(exp));
- }
+ auto equal_options = EqualOptions::Defaults().nans_equal(true);
+ AssertScalarsEqual(*expected, *result.scalar(), /*verbose=*/true, equal_options);
}
TEST(TestBooleanAggregation, Sum) {
@@ -174,12 +172,10 @@ TEST(TestBooleanAggregation, Sum) {
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2));
ValidateBooleanAgg(json, std::make_shared(),
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
- ValidateBooleanAgg(json, std::make_shared(1),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
- ValidateBooleanAgg(json, std::make_shared(1),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2));
+ ValidateBooleanAgg("[]", std::make_shared(0),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
ValidateBooleanAgg(json, std::make_shared(),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
EXPECT_THAT(Sum(MakeScalar(true)),
ResultWith(Datum(std::make_shared(1))));
@@ -187,6 +183,12 @@ TEST(TestBooleanAggregation, Sum) {
ResultWith(Datum(std::make_shared(0))));
EXPECT_THAT(Sum(MakeNullScalar(boolean())),
ResultWith(Datum(MakeNullScalar(uint64()))));
+ EXPECT_THAT(Sum(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "0")));
+ EXPECT_THAT(Sum(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "null")));
}
TEST(TestBooleanAggregation, Product) {
@@ -219,14 +221,11 @@ TEST(TestBooleanAggregation, Product) {
json, std::make_shared(),
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
ValidateBooleanAgg(
- json, std::make_shared(1),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
- ValidateBooleanAgg(
- json, std::make_shared(1),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2));
+ "[]", std::make_shared(1),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
ValidateBooleanAgg(
json, std::make_shared(),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
EXPECT_THAT(Product(MakeScalar(true)),
ResultWith(Datum(std::make_shared(1))));
@@ -234,6 +233,12 @@ TEST(TestBooleanAggregation, Product) {
ResultWith(Datum(std::make_shared(0))));
EXPECT_THAT(Product(MakeNullScalar(boolean())),
ResultWith(Datum(MakeNullScalar(uint64()))));
+ EXPECT_THAT(Product(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "1")));
+ EXPECT_THAT(Product(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(uint64(), "null")));
}
TEST(TestBooleanAggregation, Mean) {
@@ -264,17 +269,23 @@ TEST(TestBooleanAggregation, Mean) {
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2));
ValidateBooleanAgg(json, std::make_shared(),
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3));
- ValidateBooleanAgg(json, std::make_shared(0.5),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
- ValidateBooleanAgg(json, std::make_shared(0.5),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/2));
+ ValidateBooleanAgg("[]", std::make_shared(NAN),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
ValidateBooleanAgg(json, std::make_shared(),
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
- EXPECT_THAT(Mean(MakeScalar(true)), ResultWith(Datum(MakeScalar(1.0))));
- EXPECT_THAT(Mean(MakeScalar(false)), ResultWith(Datum(MakeScalar(0.0))));
+ EXPECT_THAT(Mean(MakeScalar(true)), ResultWith(ScalarFromJSON(float64(), "1.0")));
+ EXPECT_THAT(Mean(MakeScalar(false)), ResultWith(ScalarFromJSON(float64(), "0.0")));
EXPECT_THAT(Mean(MakeNullScalar(boolean())),
- ResultWith(Datum(MakeNullScalar(float64()))));
+ ResultWith(ScalarFromJSON(float64(), "null")));
+ ASSERT_OK_AND_ASSIGN(
+ auto result, Mean(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)));
+ AssertDatumsApproxEqual(result, ScalarFromJSON(float64(), "NaN"), /*detailed=*/true,
+ EqualOptions::Defaults().nans_equal(true));
+ EXPECT_THAT(Mean(MakeNullScalar(boolean()),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(ScalarFromJSON(float64(), "null")));
}
template
@@ -306,22 +317,23 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) {
ValidateSum(chunks,
Datum(std::make_shared(static_cast(5 * 6 / 2))));
- const ScalarAggregateOptions& options =
- ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
-
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
ValidateSum("[]", Datum(std::make_shared(static_cast(0))),
options);
-
ValidateSum("[null]", Datum(std::make_shared(static_cast(0))),
options);
-
chunks = {};
ValidateSum(chunks, Datum(std::make_shared(static_cast(0))),
options);
- const T expected_result = static_cast(14);
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0);
+ ValidateSum("[]", Datum(std::make_shared(static_cast(0))),
+ options);
+ ValidateSum("[null]", Datum(std::make_shared()), options);
ValidateSum("[1, null, 3, null, 3, null, 7]",
- Datum(std::make_shared(expected_result)), options);
+ Datum(std::make_shared()), options);
+ ValidateSum("[1, null, 3, null, 3, null, 7]",
+ Datum(std::make_shared(14)));
EXPECT_THAT(Sum(Datum(std::make_shared(static_cast(5)))),
ResultWith(Datum(std::make_shared(static_cast(5)))));
@@ -355,12 +367,10 @@ TYPED_TEST(TestNumericSumKernel, ScalarAggregateOptions) {
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
ValidateSum("[null]", null_result,
ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1));
- ValidateSum(json, result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
- ValidateSum(json, result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4));
+ ValidateSum("[]", zero_result,
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
ValidateSum(json, null_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5));
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
EXPECT_THAT(Sum(Datum(std::make_shared(static_cast(5))),
ScalarAggregateOptions(/*skip_nulls=*/false)),
@@ -490,16 +500,20 @@ TYPED_TEST(TestNumericProductKernel, SimpleProduct) {
chunks = ChunkedArrayFromJSON(ty, {"[1, 2]", "[]", "[3, 4, 5]"});
EXPECT_THAT(Product(chunks), ResultWith(Datum(static_cast(120))));
- const ScalarAggregateOptions& options =
- ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0);
-
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
EXPECT_THAT(Product(ArrayFromJSON(ty, "[]"), options), Datum(static_cast(1)));
EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]"), options),
Datum(static_cast(1)));
chunks = ChunkedArrayFromJSON(ty, {});
EXPECT_THAT(Product(chunks, options), Datum(static_cast(1)));
+ options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0);
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[]"), options),
+ ResultWith(Datum(static_cast(1))));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[null]"), options), ResultWith(null_result));
EXPECT_THAT(Product(ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]"), options),
+ ResultWith(null_result));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, "[1, null, 3, null, 3, null, 7]")),
Datum(static_cast(63)));
EXPECT_THAT(Product(Datum(static_cast(5))),
@@ -538,11 +552,10 @@ TYPED_TEST(TestNumericProductKernel, ScalarAggregateOptions) {
ResultWith(null_result));
EXPECT_THAT(Product(null, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
ResultWith(null_result));
- EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
- ResultWith(result));
- EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
- ResultWith(result));
- EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ ResultWith(one_result));
+ EXPECT_THAT(Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
ResultWith(null_result));
EXPECT_THAT(
@@ -680,10 +693,10 @@ void ValidateMean(const Array& array, const ScalarAggregateOptions& options =
}
template
-class TestMeanKernelNumeric : public ::testing::Test {};
+class TestNumericMeanKernel : public ::testing::Test {};
-TYPED_TEST_SUITE(TestMeanKernelNumeric, NumericArrowTypes);
-TYPED_TEST(TestMeanKernelNumeric, SimpleMean) {
+TYPED_TEST_SUITE(TestNumericMeanKernel, NumericArrowTypes);
+TYPED_TEST(TestNumericMeanKernel, SimpleMean) {
using ScalarType = typename TypeTraits::ScalarType;
using InputScalarType = typename TypeTraits::ScalarType;
using T = typename TypeParam::c_type;
@@ -716,7 +729,7 @@ TYPED_TEST(TestMeanKernelNumeric, SimpleMean) {
ResultWith(Datum(MakeNullScalar(float64()))));
}
-TYPED_TEST(TestMeanKernelNumeric, ScalarAggregateOptions) {
+TYPED_TEST(TestNumericMeanKernel, ScalarAggregateOptions) {
using ScalarType = typename TypeTraits::ScalarType;
using InputScalarType = typename TypeTraits::ScalarType;
using T = typename TypeParam::c_type;
@@ -744,20 +757,10 @@ TYPED_TEST(TestMeanKernelNumeric, ScalarAggregateOptions) {
ValidateMean("[]", nan_result,
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
- ValidateMean("[null]", nan_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
- ValidateMean("[]", null_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
ValidateMean("[null]", null_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
- ValidateMean(json, expected_result,
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
- ValidateMean(json, expected_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
- ValidateMean(json, expected_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4));
ValidateMean(json, null_result,
- ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/15));
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
EXPECT_THAT(Mean(Datum(std::make_shared(static_cast(5))),
ScalarAggregateOptions(/*skip_nulls=*/false)),
@@ -1234,19 +1237,21 @@ TYPED_TEST(TestRandomNumericMinMaxKernel, RandomArrayMinMax) {
// Any
//
-class TestPrimitiveAnyKernel : public ::testing::Test {
+class TestAnyKernel : public ::testing::Test {
public:
void AssertAnyIs(const Datum& array, const std::shared_ptr& expected,
const ScalarAggregateOptions& options) {
+ SCOPED_TRACE(options.ToString());
ASSERT_OK_AND_ASSIGN(Datum out, Any(array, options, nullptr));
const BooleanScalar& out_any = out.scalar_as();
- ASSERT_EQ(out_any, *expected);
+ AssertScalarsEqual(*expected, out_any, /*verbose=*/true);
}
void AssertAnyIs(
const std::string& json, const std::shared_ptr& expected,
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
- auto array = ArrayFromJSON(type_singleton(), json);
+ SCOPED_TRACE(json);
+ auto array = ArrayFromJSON(boolean(), json);
AssertAnyIs(array, expected, options);
}
@@ -1254,17 +1259,11 @@ class TestPrimitiveAnyKernel : public ::testing::Test {
const std::vector& json,
const std::shared_ptr& expected,
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
- auto array = ChunkedArrayFromJSON(type_singleton(), json);
+ auto array = ChunkedArrayFromJSON(boolean(), json);
AssertAnyIs(array, expected, options);
}
-
- std::shared_ptr type_singleton() {
- return TypeTraits::type_singleton();
- }
};
-class TestAnyKernel : public TestPrimitiveAnyKernel {};
-
TEST_F(TestAnyKernel, Basics) {
auto true_value = std::make_shared(true);
auto false_value = std::make_shared(false);
@@ -1277,26 +1276,27 @@ TEST_F(TestAnyKernel, Basics) {
std::vector chunked_input3 = {"[false, null]", "[null, false]"};
std::vector chunked_input4 = {"[true, null]", "[null, false]"};
- this->AssertAnyIs("[]", false_value);
- this->AssertAnyIs("[false]", false_value);
- this->AssertAnyIs("[true, false]", true_value);
- this->AssertAnyIs("[null, null, null]", false_value);
- this->AssertAnyIs("[false, false, false]", false_value);
- this->AssertAnyIs("[false, false, false, null]", false_value);
- this->AssertAnyIs("[true, null, true, true]", true_value);
- this->AssertAnyIs("[false, null, false, true]", true_value);
- this->AssertAnyIs("[true, null, false, true]", true_value);
- this->AssertAnyIs(chunked_input0, true_value);
- this->AssertAnyIs(chunked_input1, true_value);
- this->AssertAnyIs(chunked_input2, false_value);
- this->AssertAnyIs(chunked_input3, false_value);
- this->AssertAnyIs(chunked_input4, true_value);
-
- EXPECT_THAT(Any(Datum(true)), ResultWith(Datum(true)));
- EXPECT_THAT(Any(Datum(false)), ResultWith(Datum(false)));
- EXPECT_THAT(Any(MakeNullScalar(boolean())), ResultWith(Datum(false)));
-
- const ScalarAggregateOptions& keep_nulls = ScalarAggregateOptions(/*skip_nulls=*/false);
+ const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ this->AssertAnyIs("[]", false_value, options);
+ this->AssertAnyIs("[false]", false_value, options);
+ this->AssertAnyIs("[true, false]", true_value, options);
+ this->AssertAnyIs("[null, null, null]", false_value, options);
+ this->AssertAnyIs("[false, false, false]", false_value, options);
+ this->AssertAnyIs("[false, false, false, null]", false_value, options);
+ this->AssertAnyIs("[true, null, true, true]", true_value, options);
+ this->AssertAnyIs("[false, null, false, true]", true_value, options);
+ this->AssertAnyIs("[true, null, false, true]", true_value, options);
+ this->AssertAnyIs(chunked_input0, true_value, options);
+ this->AssertAnyIs(chunked_input1, true_value, options);
+ this->AssertAnyIs(chunked_input2, false_value, options);
+ this->AssertAnyIs(chunked_input3, false_value, options);
+ this->AssertAnyIs(chunked_input4, true_value, options);
+
+ EXPECT_THAT(Any(Datum(true), options), ResultWith(Datum(true)));
+ EXPECT_THAT(Any(Datum(false), options), ResultWith(Datum(false)));
+ EXPECT_THAT(Any(Datum(null_value), options), ResultWith(Datum(false)));
+
+ const ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
this->AssertAnyIs("[]", false_value, keep_nulls);
this->AssertAnyIs("[false]", false_value, keep_nulls);
this->AssertAnyIs("[true, false]", true_value, keep_nulls);
@@ -1314,27 +1314,48 @@ TEST_F(TestAnyKernel, Basics) {
EXPECT_THAT(Any(Datum(true), keep_nulls), ResultWith(Datum(true)));
EXPECT_THAT(Any(Datum(false), keep_nulls), ResultWith(Datum(false)));
- EXPECT_THAT(Any(MakeNullScalar(boolean()), keep_nulls),
- ResultWith(Datum(MakeNullScalar(boolean()))));
+ EXPECT_THAT(Any(Datum(null_value), keep_nulls), ResultWith(Datum(null_value)));
+
+ const ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/2);
+ this->AssertAnyIs("[]", null_value, min_count);
+ this->AssertAnyIs("[false]", null_value, min_count);
+ this->AssertAnyIs("[true, false]", true_value, min_count);
+ this->AssertAnyIs("[null, null, null]", null_value, min_count);
+ this->AssertAnyIs("[false, false, false]", false_value, min_count);
+ this->AssertAnyIs("[false, false, false, null]", false_value, min_count);
+ this->AssertAnyIs("[true, null, true, true]", true_value, min_count);
+ this->AssertAnyIs("[false, null, false, true]", true_value, min_count);
+ this->AssertAnyIs("[true, null, false, true]", true_value, min_count);
+ this->AssertAnyIs(chunked_input0, null_value, min_count);
+ this->AssertAnyIs(chunked_input1, true_value, min_count);
+ this->AssertAnyIs(chunked_input2, false_value, min_count);
+ this->AssertAnyIs(chunked_input3, false_value, min_count);
+ this->AssertAnyIs(chunked_input4, true_value, min_count);
+
+ EXPECT_THAT(Any(Datum(true), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(Any(Datum(false), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(Any(Datum(null_value), min_count), ResultWith(Datum(null_value)));
}
//
// All
//
-class TestPrimitiveAllKernel : public ::testing::Test {
+class TestAllKernel : public ::testing::Test {
public:
void AssertAllIs(const Datum& array, const std::shared_ptr& expected,
const ScalarAggregateOptions& options) {
+ SCOPED_TRACE(options.ToString());
ASSERT_OK_AND_ASSIGN(Datum out, All(array, options, nullptr));
const BooleanScalar& out_all = out.scalar_as();
- ASSERT_EQ(out_all, *expected);
+ AssertScalarsEqual(*expected, out_all, /*verbose=*/true);
}
void AssertAllIs(
const std::string& json, const std::shared_ptr& expected,
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
- auto array = ArrayFromJSON(type_singleton(), json);
+ SCOPED_TRACE(json);
+ auto array = ArrayFromJSON(boolean(), json);
AssertAllIs(array, expected, options);
}
@@ -1342,17 +1363,11 @@ class TestPrimitiveAllKernel : public ::testing::Test {
const std::vector& json,
const std::shared_ptr& expected,
const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults()) {
- auto array = ChunkedArrayFromJSON(type_singleton(), json);
+ auto array = ChunkedArrayFromJSON(boolean(), json);
AssertAllIs(array, expected, options);
}
-
- std::shared_ptr type_singleton() {
- return TypeTraits::type_singleton();
- }
};
-class TestAllKernel : public TestPrimitiveAllKernel {};
-
TEST_F(TestAllKernel, Basics) {
auto true_value = std::make_shared(true);
auto false_value = std::make_shared(false);
@@ -1366,27 +1381,28 @@ TEST_F(TestAllKernel, Basics) {
std::vector chunked_input4 = {"[true, null]", "[null, false]"};
std::vector chunked_input5 = {"[false, null]", "[null, true]"};
- this->AssertAllIs("[]", true_value);
- this->AssertAllIs("[false]", false_value);
- this->AssertAllIs("[true, false]", false_value);
- this->AssertAllIs("[null, null, null]", true_value);
- this->AssertAllIs("[false, false, false]", false_value);
- this->AssertAllIs("[false, false, false, null]", false_value);
- this->AssertAllIs("[true, null, true, true]", true_value);
- this->AssertAllIs("[false, null, false, true]", false_value);
- this->AssertAllIs("[true, null, false, true]", false_value);
- this->AssertAllIs(chunked_input0, true_value);
- this->AssertAllIs(chunked_input1, true_value);
- this->AssertAllIs(chunked_input2, false_value);
- this->AssertAllIs(chunked_input3, false_value);
- this->AssertAllIs(chunked_input4, false_value);
- this->AssertAllIs(chunked_input5, false_value);
-
- EXPECT_THAT(All(Datum(true)), ResultWith(Datum(true)));
- EXPECT_THAT(All(Datum(false)), ResultWith(Datum(false)));
- EXPECT_THAT(All(MakeNullScalar(boolean())), ResultWith(Datum(true)));
-
- const ScalarAggregateOptions keep_nulls = ScalarAggregateOptions(/*skip_nulls=*/false);
+ const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ this->AssertAllIs("[]", true_value, options);
+ this->AssertAllIs("[false]", false_value, options);
+ this->AssertAllIs("[true, false]", false_value, options);
+ this->AssertAllIs("[null, null, null]", true_value, options);
+ this->AssertAllIs("[false, false, false]", false_value, options);
+ this->AssertAllIs("[false, false, false, null]", false_value, options);
+ this->AssertAllIs("[true, null, true, true]", true_value, options);
+ this->AssertAllIs("[false, null, false, true]", false_value, options);
+ this->AssertAllIs("[true, null, false, true]", false_value, options);
+ this->AssertAllIs(chunked_input0, true_value, options);
+ this->AssertAllIs(chunked_input1, true_value, options);
+ this->AssertAllIs(chunked_input2, false_value, options);
+ this->AssertAllIs(chunked_input3, false_value, options);
+ this->AssertAllIs(chunked_input4, false_value, options);
+ this->AssertAllIs(chunked_input5, false_value, options);
+
+ EXPECT_THAT(All(Datum(true), options), ResultWith(Datum(true)));
+ EXPECT_THAT(All(Datum(false), options), ResultWith(Datum(false)));
+ EXPECT_THAT(All(Datum(null_value), options), ResultWith(Datum(true)));
+
+ const ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
this->AssertAllIs("[]", true_value, keep_nulls);
this->AssertAllIs("[false]", false_value, keep_nulls);
this->AssertAllIs("[true, false]", false_value, keep_nulls);
@@ -1405,8 +1421,28 @@ TEST_F(TestAllKernel, Basics) {
EXPECT_THAT(All(Datum(true), keep_nulls), ResultWith(Datum(true)));
EXPECT_THAT(All(Datum(false), keep_nulls), ResultWith(Datum(false)));
- EXPECT_THAT(All(MakeNullScalar(boolean()), keep_nulls),
- ResultWith(Datum(MakeNullScalar(boolean()))));
+ EXPECT_THAT(All(Datum(null_value), keep_nulls), ResultWith(Datum(null_value)));
+
+ const ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/2);
+ this->AssertAllIs("[]", null_value, min_count);
+ this->AssertAllIs("[false]", null_value, min_count);
+ this->AssertAllIs("[true, false]", false_value, min_count);
+ this->AssertAllIs("[null, null, null]", null_value, min_count);
+ this->AssertAllIs("[false, false, false]", false_value, min_count);
+ this->AssertAllIs("[false, false, false, null]", false_value, min_count);
+ this->AssertAllIs("[true, null, true, true]", true_value, min_count);
+ this->AssertAllIs("[false, null, false, true]", false_value, min_count);
+ this->AssertAllIs("[true, null, false, true]", false_value, min_count);
+ this->AssertAllIs(chunked_input0, null_value, min_count);
+ this->AssertAllIs(chunked_input1, true_value, min_count);
+ this->AssertAllIs(chunked_input2, false_value, min_count);
+ this->AssertAllIs(chunked_input3, false_value, min_count);
+ this->AssertAllIs(chunked_input4, false_value, min_count);
+ this->AssertAllIs(chunked_input5, false_value, min_count);
+
+ EXPECT_THAT(All(Datum(true), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(All(Datum(false), min_count), ResultWith(Datum(null_value)));
+ EXPECT_THAT(All(Datum(null_value), min_count), ResultWith(Datum(null_value)));
}
//
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 9222c5dd18f..bfcf17f4aa2 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -916,8 +916,9 @@ struct GroupedSumImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
pool_ = ctx->memory_pool();
options_ = checked_cast(*options);
- sums_ = BufferBuilder(pool_);
- counts_ = BufferBuilder(pool_);
+ sums_ = TypedBufferBuilder(pool_);
+ counts_ = TypedBufferBuilder(pool_);
+ no_nulls_ = TypedBufferBuilder(pool_);
out_type_ = TypeTraits::type_singleton();
return Status::OK();
}
@@ -925,14 +926,16 @@ struct GroupedSumImpl : public GroupedAggregator {
Status Resize(int64_t new_num_groups) override {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
- RETURN_NOT_OK(sums_.Append(added_groups * sizeof(AccType), 0));
- RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0));
+ RETURN_NOT_OK(sums_.Append(added_groups, SumType()));
+ RETURN_NOT_OK(counts_.Append(added_groups, 0));
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
return Status::OK();
}
Status Consume(const ExecBatch& batch) override {
- auto sums = reinterpret_cast(sums_.mutable_data());
- auto counts = reinterpret_cast(counts_.mutable_data());
+ SumType* sums = sums_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
// XXX this uses naive summation; we should switch to pairwise summation as was
// done for the scalar aggregate kernel in ARROW-11758
@@ -944,7 +947,7 @@ struct GroupedSumImpl : public GroupedAggregator {
counts[*g] += 1;
++g;
},
- [&] { ++g; });
+ [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
return Status::OK();
}
@@ -952,23 +955,28 @@ struct GroupedSumImpl : public GroupedAggregator {
const ArrayData& group_id_mapping) override {
auto other = checked_cast(&raw_other);
- auto counts = reinterpret_cast(counts_.mutable_data());
- auto sums = reinterpret_cast(sums_.mutable_data());
+ SumType* sums = sums_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
- auto other_counts = reinterpret_cast(other->counts_.mutable_data());
- auto other_sums = reinterpret_cast(other->sums_.mutable_data());
+ const SumType* other_sums = other->sums_.data();
+ const int64_t* other_counts = other->counts_.data();
+ const uint8_t* other_no_nulls = no_nulls_.mutable_data();
auto g = group_id_mapping.GetValues(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
counts[*g] += other_counts[other_g];
sums[*g] += other_sums[other_g];
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
}
return Status::OK();
}
Result Finalize() override {
std::shared_ptr null_bitmap;
- const int64_t* counts = reinterpret_cast(counts_.data());
+ const int64_t* counts = counts_.data();
int64_t null_count = 0;
for (int64_t i = 0; i < num_groups_; ++i) {
@@ -983,19 +991,30 @@ struct GroupedSumImpl : public GroupedAggregator {
BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
}
- ARROW_ASSIGN_OR_RAISE(auto sums, sums_.Finish());
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls_.data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+ }
+ }
+ ARROW_ASSIGN_OR_RAISE(auto sums, sums_.Finish());
return ArrayData::Make(std::move(out_type_), num_groups_,
{std::move(null_bitmap), std::move(sums)}, null_count);
}
std::shared_ptr out_type() const override { return out_type_; }
- // NB: counts are used here instead of a simple "has_values_" bitmap since
- // we expect to reuse this kernel to handle Mean
+ // NB: counts are used here to support Mean below
int64_t num_groups_ = 0;
ScalarAggregateOptions options_;
- BufferBuilder sums_, counts_;
+ TypedBufferBuilder sums_;
+ TypedBufferBuilder counts_;
+ TypedBufferBuilder no_nulls_;
std::shared_ptr out_type_;
MemoryPool* pool_;
};
@@ -1048,12 +1067,14 @@ struct GroupedProductImpl final : public GroupedAggregator {
num_groups_ = new_num_groups;
RETURN_NOT_OK(products_.Append(added_groups * sizeof(AccType), 1));
RETURN_NOT_OK(counts_.Append(added_groups, 0));
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
return Status::OK();
}
Status Consume(const ExecBatch& batch) override {
ProductType* products = products_.mutable_data();
int64_t* counts = counts_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
auto g = batch[1].array()->GetValues(1);
VisitArrayDataInline(
*batch[0].array(),
@@ -1062,7 +1083,7 @@ struct GroupedProductImpl final : public GroupedAggregator {
to_unsigned(products[*g]) * to_unsigned(static_cast(value)));
counts[*g++] += 1;
},
- [&] { ++g; });
+ [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
return Status::OK();
}
@@ -1072,15 +1093,20 @@ struct GroupedProductImpl final : public GroupedAggregator {
int64_t* counts = counts_.mutable_data();
ProductType* products = products_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
- const int64_t* other_counts = other->counts_.mutable_data();
- const ProductType* other_products = other->products_.mutable_data();
+ const int64_t* other_counts = other->counts_.data();
+ const ProductType* other_products = other->products_.data();
+ const uint8_t* other_no_nulls = other->no_nulls_.data();
const uint32_t* g = group_id_mapping.GetValues(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
products[*g] = static_cast(to_unsigned(products[*g]) *
to_unsigned(other_products[other_g]));
counts[*g] += other_counts[other_g];
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
}
return Status::OK();
}
@@ -1104,6 +1130,17 @@ struct GroupedProductImpl final : public GroupedAggregator {
BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
}
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls_.data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+ }
+ }
+
return ArrayData::Make(std::move(out_type_), num_groups_,
{std::move(null_bitmap), std::move(products)}, null_count);
}
@@ -1114,6 +1151,7 @@ struct GroupedProductImpl final : public GroupedAggregator {
ScalarAggregateOptions options_;
TypedBufferBuilder products_;
TypedBufferBuilder counts_;
+ TypedBufferBuilder no_nulls_;
std::shared_ptr out_type_;
MemoryPool* pool_;
};
@@ -1157,8 +1195,8 @@ struct GroupedMeanImpl : public GroupedSumImpl {
AllocateBuffer(num_groups_ * sizeof(double), pool_));
int64_t null_count = 0;
- const int64_t* counts = reinterpret_cast(counts_.data());
- const auto* sums = reinterpret_cast(sums_.data());
+ const int64_t* counts = counts_.data();
+ const SumType* sums = sums_.data();
double* means = reinterpret_cast(values->mutable_data());
for (int64_t i = 0; i < num_groups_; ++i) {
if (counts[i] >= options_.min_count) {
@@ -1176,6 +1214,17 @@ struct GroupedMeanImpl : public GroupedSumImpl {
BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
}
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls_.data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+ }
+ }
+
return ArrayData::Make(float64(), num_groups_,
{std::move(null_bitmap), std::move(values)}, null_count);
}
@@ -1187,6 +1236,7 @@ struct GroupedMeanImpl : public GroupedSumImpl {
using GroupedSumImpl::pool_;
using GroupedSumImpl::counts_;
using GroupedSumImpl::sums_;
+ using GroupedSumImpl::no_nulls_;
};
struct GroupedMeanFactory {
@@ -1741,9 +1791,11 @@ struct GroupedMinMaxFactory {
struct GroupedAnyImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
- options_ = *checked_cast(options);
- seen_ = TypedBufferBuilder(ctx->memory_pool());
- has_nulls_ = TypedBufferBuilder(ctx->memory_pool());
+ options_ = checked_cast(*options);
+ pool_ = ctx->memory_pool();
+ seen_ = TypedBufferBuilder(pool_);
+ no_nulls_ = TypedBufferBuilder(pool_);
+ counts_ = TypedBufferBuilder(pool_);
return Status::OK();
}
@@ -1751,75 +1803,117 @@ struct GroupedAnyImpl : public GroupedAggregator {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
RETURN_NOT_OK(seen_.Append(added_groups, false));
- return has_nulls_.Append(added_groups, false);
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
+ return counts_.Append(added_groups, 0);
+ }
+
+ Status Consume(const ExecBatch& batch) override {
+ uint8_t* seen = seen_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+ const auto& input = *batch[0].array();
+ auto g = batch[1].array()->GetValues(1);
+
+ if (input.MayHaveNulls()) {
+ const uint8_t* bitmap = input.buffers[1]->data();
+ arrow::internal::VisitBitBlocksVoid(
+ input.buffers[0], input.offset, input.length,
+ [&](int64_t position) {
+ counts[*g]++;
+ BitUtil::SetBitTo(
+ seen, *g, BitUtil::GetBit(seen, *g) || BitUtil::GetBit(bitmap, position));
+ g++;
+ },
+ [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
+ } else {
+ arrow::internal::VisitBitBlocksVoid(
+ input.buffers[1], input.offset, input.length,
+ [&](int64_t) {
+ counts[*g++]++;
+ BitUtil::SetBitTo(seen, *g++, true);
+ },
+ [&]() { counts[*g]++; });
+ }
+ return Status::OK();
}
Status Merge(GroupedAggregator&& raw_other,
const ArrayData& group_id_mapping) override {
auto other = checked_cast(&raw_other);
- auto seen = seen_.mutable_data();
- auto other_seen = other->seen_.data();
- auto has_nulls = has_nulls_.mutable_data();
- auto other_has_nulls = other->has_nulls_.data();
+ uint8_t* seen = seen_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+
+ const uint8_t* other_seen = other->seen_.mutable_data();
+ const uint8_t* other_no_nulls = other->no_nulls_.mutable_data();
+ const int64_t* other_counts = other->counts_.mutable_data();
auto g = group_id_mapping.GetValues(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
- if (BitUtil::GetBit(other_seen, other_g)) BitUtil::SetBitTo(seen, *g, true);
- if (BitUtil::GetBit(other_has_nulls, other_g)) {
- BitUtil::SetBitTo(has_nulls, *g, true);
- }
+ counts[*g] += other_counts[other_g];
+ BitUtil::SetBitTo(
+ seen, *g, BitUtil::GetBit(seen, *g) || BitUtil::GetBit(other_seen, other_g));
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
}
return Status::OK();
}
- Status Consume(const ExecBatch& batch) override {
- auto seen = seen_.mutable_data();
- auto has_nulls = has_nulls_.mutable_data();
+ Result Finalize() override {
+ std::shared_ptr null_bitmap;
+ const int64_t* counts = counts_.data();
+ int64_t null_count = 0;
- const auto& input = *batch[0].array();
+ for (int64_t i = 0; i < num_groups_; ++i) {
+ if (counts[i] >= options_.min_count) continue;
- auto g = batch[1].array()->GetValues(1);
- auto values = input.buffers[1]->data();
- arrow::internal::VisitBitBlocksVoid(
- input.buffers[0], input.offset, input.length,
- [&](int64_t offset) {
- BitUtil::SetBitTo(seen, *g,
- BitUtil::GetBit(seen, *g) ||
- BitUtil::GetBit(values, input.offset + offset));
- g++;
- },
- [&]() { BitUtil::SetBitTo(has_nulls, *g++, true); });
- return Status::OK();
- }
+ if (null_bitmap == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_));
+ BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true);
+ }
+
+ null_count += 1;
+ BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+ }
- Result Finalize() override {
ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
- if (options_.skip_nulls) {
- return std::make_shared(num_groups_, std::move(seen));
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ ARROW_ASSIGN_OR_RAISE(auto no_nulls, no_nulls_.Finish());
+ arrow::internal::BitmapOr(no_nulls->data(), /*left_offset=*/0, seen->data(),
+ /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, no_nulls->mutable_data());
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls->data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ null_bitmap = std::move(no_nulls);
+ }
}
- ARROW_ASSIGN_OR_RAISE(auto bitmap, has_nulls_.Finish());
- // null if (~seen & has_nulls) -> not null if (seen | ~has_nulls)
- ::arrow::internal::BitmapOrNot(seen->data(), /*left_offset=*/0, bitmap->data(),
- /*right_offset=*/0, num_groups_, /*out_offset=*/0,
- bitmap->mutable_data());
- return std::make_shared(num_groups_, std::move(seen),
- std::move(bitmap));
+
+ return ArrayData::Make(out_type(), num_groups_,
+ {std::move(null_bitmap), std::move(seen)}, null_count);
}
std::shared_ptr out_type() const override { return boolean(); }
int64_t num_groups_ = 0;
ScalarAggregateOptions options_;
- TypedBufferBuilder seen_;
- TypedBufferBuilder has_nulls_;
+ TypedBufferBuilder seen_, no_nulls_;
+ TypedBufferBuilder counts_;
+ MemoryPool* pool_;
};
struct GroupedAllImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
- options_ = *checked_cast(options);
- seen_ = TypedBufferBuilder(ctx->memory_pool());
- has_nulls_ = TypedBufferBuilder(ctx->memory_pool());
+ options_ = checked_cast(*options);
+ pool_ = ctx->memory_pool();
+ seen_ = TypedBufferBuilder(pool_);
+ no_nulls_ = TypedBufferBuilder(pool_);
+ counts_ = TypedBufferBuilder(pool_);
return Status::OK();
}
@@ -1827,77 +1921,108 @@ struct GroupedAllImpl : public GroupedAggregator {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
RETURN_NOT_OK(seen_.Append(added_groups, true));
- return has_nulls_.Append(added_groups, false);
- }
-
- Status Merge(GroupedAggregator&& raw_other,
- const ArrayData& group_id_mapping) override {
- auto other = checked_cast(&raw_other);
-
- auto seen = seen_.mutable_data();
- auto other_seen = other->seen_.data();
- auto has_nulls = has_nulls_.mutable_data();
- auto other_has_nulls = other->has_nulls_.data();
-
- auto g = group_id_mapping.GetValues(1);
- for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
- BitUtil::SetBitTo(
- seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g));
- if (BitUtil::GetBit(other_has_nulls, other_g)) {
- BitUtil::SetBitTo(has_nulls, *g, true);
- }
- }
- return Status::OK();
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
+ return counts_.Append(added_groups, 0);
}
Status Consume(const ExecBatch& batch) override {
- auto seen = seen_.mutable_data();
- auto has_nulls = has_nulls_.mutable_data();
-
+ uint8_t* seen = seen_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
const auto& input = *batch[0].array();
-
auto g = batch[1].array()->GetValues(1);
+
if (input.MayHaveNulls()) {
const uint8_t* bitmap = input.buffers[1]->data();
arrow::internal::VisitBitBlocksVoid(
input.buffers[0], input.offset, input.length,
[&](int64_t position) {
+ counts[*g]++;
BitUtil::SetBitTo(seen, *g,
BitUtil::GetBit(seen, *g) &&
BitUtil::GetBit(bitmap, input.offset + position));
g++;
},
- [&]() { BitUtil::SetBitTo(has_nulls, *g++, true); });
+ [&]() { BitUtil::SetBitTo(no_nulls, *g++, false); });
} else {
arrow::internal::VisitBitBlocksVoid(
- input.buffers[1], input.offset, input.length, [&](int64_t) { g++; },
- [&]() { BitUtil::SetBitTo(seen, *g++, false); });
+ input.buffers[1], input.offset, input.length, [&](int64_t) { counts[*g++]++; },
+ [&]() {
+ counts[*g]++;
+ BitUtil::SetBitTo(seen, *g++, false);
+ });
+ }
+ return Status::OK();
+ }
+
+ Status Merge(GroupedAggregator&& raw_other,
+ const ArrayData& group_id_mapping) override {
+ auto other = checked_cast(&raw_other);
+
+ uint8_t* seen = seen_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ int64_t* counts = counts_.mutable_data();
+
+ const uint8_t* other_seen = other->seen_.mutable_data();
+ const uint8_t* other_no_nulls = other->no_nulls_.mutable_data();
+ const int64_t* other_counts = other->counts_.mutable_data();
+
+ auto g = group_id_mapping.GetValues(1);
+ for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ counts[*g] += other_counts[other_g];
+ BitUtil::SetBitTo(
+ seen, *g, BitUtil::GetBit(seen, *g) && BitUtil::GetBit(other_seen, other_g));
+ BitUtil::SetBitTo(
+ no_nulls, *g,
+ BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
}
return Status::OK();
}
Result Finalize() override {
+ std::shared_ptr null_bitmap;
+ const int64_t* counts = counts_.data();
+ int64_t null_count = 0;
+
+ for (int64_t i = 0; i < num_groups_; ++i) {
+ if (counts[i] >= options_.min_count) continue;
+
+ if (null_bitmap == nullptr) {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_groups_, pool_));
+ BitUtil::SetBitsTo(null_bitmap->mutable_data(), 0, num_groups_, true);
+ }
+
+ null_count += 1;
+ BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
+ }
+
ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
- if (options_.skip_nulls) {
- return std::make_shared(num_groups_, std::move(seen));
+ if (!options_.skip_nulls) {
+ null_count = kUnknownNullCount;
+ ARROW_ASSIGN_OR_RAISE(auto no_nulls, no_nulls_.Finish());
+ arrow::internal::BitmapOrNot(no_nulls->data(), /*left_offset=*/0, seen->data(),
+ /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, no_nulls->mutable_data());
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
+ no_nulls->data(), /*right_offset=*/0, num_groups_,
+ /*out_offset=*/0, null_bitmap->mutable_data());
+ } else {
+ null_bitmap = std::move(no_nulls);
+ }
}
- ARROW_ASSIGN_OR_RAISE(auto bitmap, has_nulls_.Finish());
- // null if (seen & has_nulls)
- ::arrow::internal::BitmapAnd(seen->data(), /*left_offset=*/0, bitmap->data(),
- /*right_offset=*/0, num_groups_, /*out_offset=*/0,
- bitmap->mutable_data());
- ::arrow::internal::InvertBitmap(bitmap->data(), /*offset=*/0, num_groups_,
- bitmap->mutable_data(), /*dest_offset=*/0);
- return std::make_shared(num_groups_, std::move(seen),
- std::move(bitmap));
+
+ return ArrayData::Make(out_type(), num_groups_,
+ {std::move(null_bitmap), std::move(seen)}, null_count);
}
std::shared_ptr out_type() const override { return boolean(); }
int64_t num_groups_ = 0;
ScalarAggregateOptions options_;
- TypedBufferBuilder seen_;
- TypedBufferBuilder has_nulls_;
+ TypedBufferBuilder seen_, no_nulls_;
+ TypedBufferBuilder counts_;
+ MemoryPool* pool_;
};
} // namespace
@@ -2176,7 +2301,8 @@ const FunctionDoc hash_product_doc{
"Compute product of values of a numeric array",
("Null values are ignored.\n"
"Overflow will wrap around as if the calculation was done with unsigned integers."),
- {"array", "group_id_array"}};
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
const FunctionDoc hash_mean_doc{"Average values of a numeric array",
("Null values are ignored."),
@@ -2216,11 +2342,13 @@ const FunctionDoc hash_min_max_doc{
const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
("Null values are ignored."),
- {"array", "group_id_array"}};
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",
("Null values are ignored."),
- {"array", "group_id_array"}};
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
} // namespace
void RegisterHashAggregateBasic(FunctionRegistry* registry) {
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index e96fdcd6084..21440248493 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1108,6 +1108,10 @@ TEST(GroupBy, AnyAndAll) {
[null, 3]
])"});
+ ScalarAggregateOptions no_min(/*skip_nulls=*/true, /*min_count=*/0);
+ ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
+ ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0);
+ ScalarAggregateOptions keep_nulls_min_count(/*skip_nulls=*/false, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy(
{
@@ -1115,13 +1119,21 @@ TEST(GroupBy, AnyAndAll) {
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
+ table->GetColumnByName("argument"),
},
{table->GetColumnByName("key")},
{
- {"hash_any", nullptr},
- {"hash_all", nullptr},
- {"hash_any", &options},
- {"hash_all", &options},
+ {"hash_any", &no_min},
+ {"hash_any", &min_count},
+ {"hash_any", &keep_nulls},
+ {"hash_any", &keep_nulls_min_count},
+ {"hash_all", &no_min},
+ {"hash_all", &min_count},
+ {"hash_all", &keep_nulls},
+ {"hash_all", &keep_nulls_min_count},
},
use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);
@@ -1134,18 +1146,22 @@ TEST(GroupBy, AnyAndAll) {
// Group null: falses
AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_any", boolean()),
- field("hash_all", boolean()),
field("hash_any", boolean()),
+ field("hash_any", boolean()),
+ field("hash_any", boolean()),
+ field("hash_all", boolean()),
+ field("hash_all", boolean()),
+ field("hash_all", boolean()),
field("hash_all", boolean()),
field("key_0", int64()),
}),
R"([
- [true, true, true, null, 1],
- [true, false, true, false, 2],
- [false, true, null, null, 3],
- [false, false, null, false, 4],
- [true, true, true, true, 5],
- [false, false, false, false, null]
+ [true, null, true, null, true, null, null, null, 1],
+ [true, true, true, true, false, false, false, false, 2],
+ [false, null, null, null, true, null, null, null, 3],
+ [false, null, null, null, false, null, false, null, 4],
+ [true, null, true, null, true, null, true, null, 5],
+ [false, null, false, null, false, null, false, null, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
@@ -1290,6 +1306,64 @@ TEST(GroupBy, Product) {
/*verbose=*/true);
}
+TEST(GroupBy, SumMeanProductKeepNulls) {
+ auto batch = RecordBatchFromJSON(
+ schema({field("argument", float64()), field("key", int64())}), R"([
+ [-1.0, 1],
+ [null, 1],
+ [0.0, 2],
+ [null, 3],
+ [4.0, null],
+ [3.25, 1],
+ [0.125, 2],
+ [-0.25, 2],
+ [0.75, null],
+ [null, 3]
+ ])");
+
+ ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false);
+ ScalarAggregateOptions min_count(/*skip_nulls=*/false, /*min_count=*/3);
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ internal::GroupBy(
+ {
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ },
+ {
+ batch->GetColumnByName("key"),
+ },
+ {
+ {"hash_sum", &keep_nulls},
+ {"hash_sum", &min_count},
+ {"hash_mean", &keep_nulls},
+ {"hash_mean", &min_count},
+ {"hash_product", &keep_nulls},
+ {"hash_product", &min_count},
+ }));
+
+ AssertDatumsApproxEqual(ArrayFromJSON(struct_({
+ field("hash_sum", float64()),
+ field("hash_sum", float64()),
+ field("hash_mean", float64()),
+ field("hash_mean", float64()),
+ field("hash_product", float64()),
+ field("hash_product", float64()),
+ field("key_0", int64()),
+ }),
+ R"([
+ [null, null, null, null, null, null, 1],
+ [-0.125, -0.125, -0.0416667, -0.0416667, 0.0, 0.0, 2],
+ [null, null, null, null, null, null, 3],
+ [4.75, null, 2.375, null, 3.0, null, null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+}
+
TEST(GroupBy, SumOnlyStringAndDictKeys) {
for (auto key_type : {utf8(), dictionary(int32(), utf8())}) {
SCOPED_TRACE("key type: " + key_type->ToString());
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 6c1684b87dc..cac7a3c95e5 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -551,7 +551,13 @@ def test_min_max():
def test_any():
# ARROW-1846
- options = pc.ScalarAggregateOptions(skip_nulls=False)
+ options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0)
+
+ a = pa.array([], type='bool')
+ assert pc.any(a).as_py() is None
+ assert pc.any(a, min_count=0).as_py() is False
+ assert pc.any(a, options=options).as_py() is False
+
a = pa.array([False, None, True])
assert pc.any(a).as_py() is True
assert pc.any(a, options=options).as_py() is True
@@ -564,9 +570,11 @@ def test_any():
def test_all():
# ARROW-10301
- options = pc.ScalarAggregateOptions(skip_nulls=False)
+ options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0)
+
a = pa.array([], type='bool')
- assert pc.all(a).as_py() is True
+ assert pc.all(a).as_py() is None
+ assert pc.all(a, min_count=0).as_py() is True
assert pc.all(a, options=options).as_py() is True
a = pa.array([False, True])
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index a04c0efcdae..4e1e4dcb859 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -300,8 +300,8 @@ ExecNode_Project <- function(input, exprs, names) {
.Call(`_arrow_ExecNode_Project`, input, exprs, names)
}
-ExecNode_Aggregate <- function(input, options, target_names, out_field_names, key_names){
- .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names, out_field_names, key_names)
+ExecNode_Aggregate <- function(input, options, target_names, out_field_names, key_names) {
+ .Call(`_arrow_ExecNode_Aggregate`, input, options, target_names, out_field_names, key_names)
}
RecordBatch__cast <- function(batch, schema, options) {
@@ -316,10 +316,6 @@ compute__CallFunction <- function(func_name, args, options) {
.Call(`_arrow_compute__CallFunction`, func_name, args, options)
}
-compute__GroupBy <- function(arguments, keys, options) {
- .Call(`_arrow_compute__GroupBy`, arguments, keys, options)
-}
-
compute__GetFunctionNames <- function() {
.Call(`_arrow_compute__GetFunctionNames`)
}
diff --git a/r/R/compute.R b/r/R/compute.R
index 8cfaaf7b415..3953eaa5dfc 100644
--- a/r/R/compute.R
+++ b/r/R/compute.R
@@ -122,12 +122,6 @@ max.ArrowDatum <- function(..., na.rm = FALSE) {
scalar_aggregate <- function(FUN, ..., na.rm = FALSE, na.min_count = 0) {
a <- collect_arrays_from_dots(list(...))
- if (!na.rm) {
- # When not removing null values, we require all values to be not null and
- # return null otherwise. We do that by setting minimum count of non-null
- # option values to the full array length.
- na.min_count <- length(a)
- }
if (FUN == "min_max" && na.rm && a$null_count == length(a)) {
Array$create(data.frame(min = Inf, max = -Inf))
# If na.rm == TRUE and all values in array are NA, R returns
diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index 9972d4796a8..efba9f287f9 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -791,28 +791,20 @@ agg_funcs$sum <- function(x, na.rm = FALSE) {
list(
fun = "sum",
data = x,
- options = arrow_na_rm(na.rm = na.rm)
+ options = list(na.rm = na.rm, na.min_count = 0L)
)
}
agg_funcs$any <- function(x, na.rm = FALSE) {
list(
fun = "any",
data = x,
- options = arrow_na_rm(na.rm)
+ options = list(na.rm = na.rm, na.min_count = 0L)
)
}
agg_funcs$all <- function(x, na.rm = FALSE) {
list(
fun = "all",
data = x,
- options = arrow_na_rm(na.rm)
+ options = list(na.rm = na.rm, na.min_count = 0L)
)
}
-
-arrow_na_rm <- function(na.rm) {
- if (!isTRUE(na.rm)) {
- # TODO: ARROW-13497
- arrow_not_supported(paste("na.rm =", na.rm))
- }
- list(na.rm = na.rm, na.min_count = 0L)
-}
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 0645a64eac9..5ef39215c73 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -1244,23 +1244,6 @@ extern "C" SEXP _arrow_compute__CallFunction(SEXP func_name_sexp, SEXP args_sexp
}
#endif
-// compute.cpp
-#if defined(ARROW_R_WITH_ARROW)
-SEXP compute__GroupBy(cpp11::list arguments, cpp11::list keys, cpp11::list options);
-extern "C" SEXP _arrow_compute__GroupBy(SEXP arguments_sexp, SEXP keys_sexp, SEXP options_sexp){
-BEGIN_CPP11
- arrow::r::Input::type arguments(arguments_sexp);
- arrow::r::Input::type keys(keys_sexp);
- arrow::r::Input::type options(options_sexp);
- return cpp11::as_sexp(compute__GroupBy(arguments, keys, options));
-END_CPP11
-}
-#else
-extern "C" SEXP _arrow_compute__GroupBy(SEXP arguments_sexp, SEXP keys_sexp, SEXP options_sexp){
- Rf_error("Cannot call compute__GroupBy(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. ");
-}
-#endif
-
// compute.cpp
#if defined(ARROW_R_WITH_ARROW)
std::vector compute__GetFunctionNames();
@@ -7136,7 +7119,6 @@ static const R_CallMethodDef CallEntries[] = {
{ "_arrow_RecordBatch__cast", (DL_FUNC) &_arrow_RecordBatch__cast, 3},
{ "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3},
{ "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3},
- { "_arrow_compute__GroupBy", (DL_FUNC) &_arrow_compute__GroupBy, 3},
{ "_arrow_compute__GetFunctionNames", (DL_FUNC) &_arrow_compute__GetFunctionNames, 0},
{ "_arrow_build_info", (DL_FUNC) &_arrow_build_info, 0},
{ "_arrow_runtime_info", (DL_FUNC) &_arrow_runtime_info, 0},
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 0695e2525f7..b697ecd96a0 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -172,7 +172,9 @@ std::shared_ptr make_compute_options(
}
if (func_name == "min_max" || func_name == "sum" || func_name == "mean" ||
- func_name == "any" || func_name == "all") {
+ func_name == "any" || func_name == "all" || func_name == "hash_min_max" ||
+ func_name == "hash_sum" || func_name == "hash_mean" || func_name == "hash_any" ||
+ func_name == "hash_all") {
using Options = arrow::compute::ScalarAggregateOptions;
auto out = std::make_shared(Options::Defaults());
out->min_count = cpp11::as_cpp(options["na.min_count"]);
@@ -390,29 +392,6 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list
return from_datum(std::move(out));
}
-// [[arrow::export]]
-SEXP compute__GroupBy(cpp11::list arguments, cpp11::list keys, cpp11::list options) {
- // options is a list of pairs: string function name, list of options
-
- std::vector> keep_alives;
- std::vector aggregates;
-
- for (cpp11::list name_opts : options) {
- auto name = cpp11::as_cpp(name_opts[0]);
- auto opts = make_compute_options(name, name_opts[1]);
-
- aggregates.push_back(
- arrow::compute::internal::Aggregate{std::move(name), opts.get()});
- keep_alives.push_back(std::move(opts));
- }
-
- auto datum_arguments = arrow::r::from_r_list(arguments);
- auto datum_keys = arrow::r::from_r_list(keys);
- auto out = ValueOrStop(arrow::compute::internal::GroupBy(datum_arguments, datum_keys,
- aggregates, gc_context()));
- return from_datum(std::move(out));
-}
-
// [[arrow::export]]
std::vector compute__GetFunctionNames() {
return arrow::compute::GetFunctionRegistry()->GetFunctionNames();
diff --git a/r/tests/testthat/test-dplyr-aggregate.R b/r/tests/testthat/test-dplyr-aggregate.R
index 8235ef29948..ef2929ee65b 100644
--- a/r/tests/testthat/test-dplyr-aggregate.R
+++ b/r/tests/testthat/test-dplyr-aggregate.R
@@ -59,9 +59,7 @@ test_that("Can aggregate in Arrow", {
input %>%
summarize(total = sum(int)) %>%
collect(),
- tbl,
- # ARROW-13497: This is failing because the default is na.rm = FALSE
- warning = TRUE
+ tbl
)
})
@@ -90,9 +88,7 @@ test_that("Group by sum on dataset", {
summarize(total = sum(int)) %>%
arrange(some_grouping) %>%
collect(),
- tbl,
- # ARROW-13497: This is failing because the default is na.rm = FALSE
- warning = TRUE
+ tbl
)
})
@@ -115,7 +111,22 @@ test_that("Group by any/all", {
collect(),
tbl
)
- # ARROW-13497: na.rm option also is not being passed/received to any/all
+ expect_dplyr_equal(
+ input %>%
+ group_by(some_grouping) %>%
+ summarize(any(lgl, na.rm = FALSE)) %>%
+ arrange(some_grouping) %>%
+ collect(),
+ tbl
+ )
+ expect_dplyr_equal(
+ input %>%
+ group_by(some_grouping) %>%
+ summarize(all(lgl, na.rm = FALSE)) %>%
+ arrange(some_grouping) %>%
+ collect(),
+ tbl
+ )
expect_dplyr_equal(
input %>%
From 68bd7d75e0fb68d9637b3a8fd37d5e9f8c1b5898 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 19 Aug 2021 14:13:46 -0400
Subject: [PATCH 2/6] ARROW-13627: [C++] Short-circuit
---
cpp/src/arrow/compute/kernels/hash_aggregate.cc | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index bfcf17f4aa2..ccd3ca2955b 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -1820,8 +1820,9 @@ struct GroupedAnyImpl : public GroupedAggregator {
input.buffers[0], input.offset, input.length,
[&](int64_t position) {
counts[*g]++;
- BitUtil::SetBitTo(
- seen, *g, BitUtil::GetBit(seen, *g) || BitUtil::GetBit(bitmap, position));
+ if (!BitUtil::GetBit(seen, *g) && BitUtil::GetBit(bitmap, position)) {
+ BitUtil::SetBit(seen, *g);
+ }
g++;
},
[&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
From 0f8857f637c89bd9c93bd9eb0ef185b4f781a770 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 19 Aug 2021 14:15:58 -0400
Subject: [PATCH 3/6] ARROW-13627: [C++] Restore tests
---
cpp/src/arrow/compute/kernels/aggregate_test.cc | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index 34346d6ca5c..fe023cad1b6 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -176,6 +176,10 @@ TEST(TestBooleanAggregation, Sum) {
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
ValidateBooleanAgg(json, std::make_shared(),
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg("[]", std::make_shared(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ValidateBooleanAgg(json, std::make_shared(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
EXPECT_THAT(Sum(MakeScalar(true)),
ResultWith(Datum(std::make_shared(1))));
@@ -226,6 +230,12 @@ TEST(TestBooleanAggregation, Product) {
ValidateBooleanAgg(
json, std::make_shared(),
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg(
+ "[]", std::make_shared(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ValidateBooleanAgg(
+ json, std::make_shared(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
EXPECT_THAT(Product(MakeScalar(true)),
ResultWith(Datum(std::make_shared(1))));
@@ -273,6 +283,10 @@ TEST(TestBooleanAggregation, Mean) {
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
ValidateBooleanAgg(json, std::make_shared(),
ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0));
+ ValidateBooleanAgg("[]", std::make_shared(),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3));
+ ValidateBooleanAgg(json, std::make_shared