From 375a26a6606ea8825435fca5d8e86c5156fef7e6 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 25 Aug 2021 15:11:15 -0400
Subject: [PATCH 1/3] ARROW-13691: [C++] Support min_count/skip_nulls in
VarianceOptions
---
cpp/src/arrow/compute/api_aggregate.cc | 13 +-
cpp/src/arrow/compute/api_aggregate.h | 7 +-
.../arrow/compute/kernels/hash_aggregate.cc | 119 ++++++++++++------
.../compute/kernels/hash_aggregate_test.cc | 101 +++++++++++++++
4 files changed, 196 insertions(+), 44 deletions(-)
diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc
index af7aec865fc..6d7bdfa6cf9 100644
--- a/cpp/src/arrow/compute/api_aggregate.cc
+++ b/cpp/src/arrow/compute/api_aggregate.cc
@@ -87,8 +87,10 @@ static auto kCountOptionsType =
GetFunctionOptionsType(DataMember("mode", &CountOptions::mode));
static auto kModeOptionsType =
GetFunctionOptionsType(DataMember("n", &ModeOptions::n));
-static auto kVarianceOptionsType =
- GetFunctionOptionsType(DataMember("ddof", &VarianceOptions::ddof));
+static auto kVarianceOptionsType = GetFunctionOptionsType(
+ DataMember("ddof", &VarianceOptions::ddof),
+ DataMember("skip_nulls", &VarianceOptions::skip_nulls),
+ DataMember("min_count", &VarianceOptions::min_count));
static auto kQuantileOptionsType = GetFunctionOptionsType(
DataMember("q", &QuantileOptions::q),
DataMember("interpolation", &QuantileOptions::interpolation));
@@ -113,8 +115,11 @@ constexpr char CountOptions::kTypeName[];
ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {}
constexpr char ModeOptions::kTypeName[];
-VarianceOptions::VarianceOptions(int ddof)
- : FunctionOptions(internal::kVarianceOptionsType), ddof(ddof) {}
+VarianceOptions::VarianceOptions(int ddof, bool skip_nulls, uint32_t min_count)
+ : FunctionOptions(internal::kVarianceOptionsType),
+ ddof(ddof),
+ skip_nulls(skip_nulls),
+ min_count(min_count) {}
constexpr char VarianceOptions::kTypeName[];
QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation)
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index d8cda022de8..8c27da49765 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -95,11 +95,16 @@ class ARROW_EXPORT ModeOptions : public FunctionOptions {
/// By default, ddof is zero, and population variance or stddev is returned.
class ARROW_EXPORT VarianceOptions : public FunctionOptions {
public:
- explicit VarianceOptions(int ddof = 0);
+ explicit VarianceOptions(int ddof = 0, bool skip_nulls = true, uint32_t min_count = 0);
constexpr static char const kTypeName[] = "VarianceOptions";
static VarianceOptions Defaults() { return VarianceOptions{}; }
int ddof = 0;
+ /// If true (the default), null values are ignored. Otherwise, if any value is null,
+ /// emit null.
+ bool skip_nulls;
+ /// If less than this many non-null values are observed, emit null.
+ uint32_t min_count;
};
/// \brief Control Quantile kernel behavior
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index f75009f0077..3ea692857cf 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -1304,18 +1304,20 @@ struct GroupedVarStdImpl : public GroupedAggregator {
options_ = *checked_cast(options);
ctx_ = ctx;
pool_ = ctx->memory_pool();
- counts_ = BufferBuilder(pool_);
- means_ = BufferBuilder(pool_);
- m2s_ = BufferBuilder(pool_);
+ counts_ = TypedBufferBuilder(pool_);
+ means_ = TypedBufferBuilder(pool_);
+ m2s_ = TypedBufferBuilder(pool_);
+ no_nulls_ = TypedBufferBuilder(pool_);
return Status::OK();
}
Status Resize(int64_t new_num_groups) override {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
- RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0));
- RETURN_NOT_OK(means_.Append(added_groups * sizeof(double), 0));
- RETURN_NOT_OK(m2s_.Append(added_groups * sizeof(double), 0));
+ RETURN_NOT_OK(counts_.Append(added_groups, 0));
+ RETURN_NOT_OK(means_.Append(added_groups, 0));
+ RETURN_NOT_OK(m2s_.Append(added_groups, 0));
+ RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
return Status::OK();
}
@@ -1332,18 +1334,21 @@ struct GroupedVarStdImpl : public GroupedAggregator {
GroupedVarStdImpl state;
RETURN_NOT_OK(state.Init(ctx_, &options_));
RETURN_NOT_OK(state.Resize(num_groups_));
- int64_t* counts = reinterpret_cast(state.counts_.mutable_data());
- double* means = reinterpret_cast(state.means_.mutable_data());
- double* m2s = reinterpret_cast(state.m2s_.mutable_data());
+ int64_t* counts = state.counts_.mutable_data();
+ double* means = state.means_.mutable_data();
+ double* m2s = state.m2s_.mutable_data();
+ uint8_t* no_nulls = state.no_nulls_.mutable_data();
// XXX this uses naive summation; we should switch to pairwise summation as was
// done for the scalar aggregate kernel in ARROW-11567
std::vector sums(num_groups_);
- VisitGroupedValuesNonNull(
- batch, [&](uint32_t g, typename TypeTraits::CType value) {
+ VisitGroupedValues(
+ batch,
+ [&](uint32_t g, typename TypeTraits::CType value) {
sums[g] += value;
counts[g]++;
- });
+ },
+ [&](uint32_t g) { BitUtil::ClearBit(no_nulls, g); });
for (int64_t i = 0; i < num_groups_; i++) {
means[i] = static_cast(sums[i]) / counts[i];
@@ -1362,9 +1367,7 @@ struct GroupedVarStdImpl : public GroupedAggregator {
}
ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
/*null_count=*/0);
- RETURN_NOT_OK(this->Merge(std::move(state), group_id_mapping));
-
- return Status::OK();
+ return this->Merge(std::move(state), group_id_mapping);
}
// int32/16/8: textbook one pass algorithm with integer arithmetic (see
@@ -1377,12 +1380,15 @@ struct GroupedVarStdImpl : public GroupedAggregator {
// for int32: -2^62 <= sum < 2^62
constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
+ const auto g = batch[1].array()->GetValues(1);
if (batch[0].is_scalar() && !batch[0].scalar()->is_valid) {
+ uint8_t* no_nulls = no_nulls_.mutable_data();
+ for (int64_t i = 0; i < batch.length; i++) {
+ BitUtil::ClearBit(no_nulls, g[i]);
+ }
return Status::OK();
}
- const auto g = batch[1].array()->GetValues(1);
-
std::vector> var_std(num_groups_);
ARROW_ASSIGN_OR_RAISE(auto mapping,
@@ -1402,23 +1408,42 @@ struct GroupedVarStdImpl : public GroupedAggregator {
GroupedVarStdImpl state;
RETURN_NOT_OK(state.Init(ctx_, &options_));
RETURN_NOT_OK(state.Resize(num_groups_));
- int64_t* other_counts = reinterpret_cast(state.counts_.mutable_data());
- double* other_means = reinterpret_cast(state.means_.mutable_data());
- double* other_m2s = reinterpret_cast(state.m2s_.mutable_data());
+ int64_t* other_counts = state.counts_.mutable_data();
+ double* other_means = state.means_.mutable_data();
+ double* other_m2s = state.m2s_.mutable_data();
+ uint8_t* other_no_nulls = state.no_nulls_.mutable_data();
if (batch[0].is_array()) {
const auto& array = *batch[0].array();
const CType* values = array.GetValues(1);
- arrow::internal::VisitSetBitRunsVoid(
- array.buffers[0], array.offset + start_index,
- std::min(max_length, batch.length - start_index),
- [&](int64_t pos, int64_t len) {
- for (int64_t i = 0; i < len; ++i) {
- const int64_t index = start_index + pos + i;
- const auto value = values[index];
- var_std[g[index]].ConsumeOne(value);
+ auto visit_values = [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ const int64_t index = start_index + pos + i;
+ const auto value = values[index];
+ var_std[g[index]].ConsumeOne(value);
+ }
+ };
+
+ if (array.MayHaveNulls()) {
+ arrow::internal::BitRunReader reader(
+ array.buffers[0]->data(), array.offset + start_index,
+ std::min(max_length, batch.length - start_index));
+ int64_t position = 0;
+ while (true) {
+ auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ visit_values(position, run.length);
+ } else {
+ for (int64_t i = 0; i < run.length; ++i) {
+ BitUtil::ClearBit(other_no_nulls, g[start_index + position + i]);
}
- });
+ }
+ position += run.length;
+ }
+ } else {
+ visit_values(0, array.length);
+ }
} else {
const auto value = UnboxScalar::Unbox(*batch[0].scalar());
for (int64_t i = 0; i < std::min(max_length, batch.length - start_index); ++i) {
@@ -1444,16 +1469,21 @@ struct GroupedVarStdImpl : public GroupedAggregator {
// Combine m2 from two chunks (see aggregate_var_std.cc)
auto other = checked_cast(&raw_other);
- auto counts = reinterpret_cast(counts_.mutable_data());
- auto means = reinterpret_cast(means_.mutable_data());
- auto m2s = reinterpret_cast(m2s_.mutable_data());
+ int64_t* counts = counts_.mutable_data();
+ double* means = means_.mutable_data();
+ double* m2s = m2s_.mutable_data();
+ uint8_t* no_nulls = no_nulls_.mutable_data();
- const auto* other_counts = reinterpret_cast(other->counts_.data());
- const auto* other_means = reinterpret_cast(other->means_.data());
- const auto* other_m2s = reinterpret_cast(other->m2s_.data());
+ const int64_t* other_counts = other->counts_.data();
+ const double* other_means = other->means_.data();
+ const double* other_m2s = other->m2s_.data();
+ const uint8_t* other_no_nulls = other->no_nulls_.data();
auto g = group_id_mapping.GetValues(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
+ if (!BitUtil::GetBit(other_no_nulls, other_g)) {
+ BitUtil::ClearBit(no_nulls, *g);
+ }
if (other_counts[other_g] == 0) continue;
MergeVarStd(counts[*g], means[*g], other_counts[other_g], other_means[other_g],
other_m2s[other_g], &counts[*g], &means[*g], &m2s[*g]);
@@ -1468,10 +1498,10 @@ struct GroupedVarStdImpl : public GroupedAggregator {
int64_t null_count = 0;
double* results = reinterpret_cast(values->mutable_data());
- const int64_t* counts = reinterpret_cast(counts_.data());
- const double* m2s = reinterpret_cast(m2s_.data());
+ const int64_t* counts = counts_.data();
+ const double* m2s = m2s_.data();
for (int64_t i = 0; i < num_groups_; ++i) {
- if (counts[i] > options_.ddof) {
+ if (counts[i] > options_.ddof && counts[i] >= options_.min_count) {
const double variance = m2s[i] / (counts[i] - options_.ddof);
results[i] = result_type_ == VarOrStd::Var ? variance : std::sqrt(variance);
continue;
@@ -1486,6 +1516,15 @@ struct GroupedVarStdImpl : public GroupedAggregator {
null_count += 1;
BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
}
+ if (!options_.skip_nulls) {
+ if (null_bitmap) {
+ arrow::internal::BitmapAnd(null_bitmap->data(), 0, no_nulls_.data(), 0,
+ num_groups_, 0, null_bitmap->mutable_data());
+ } else {
+ ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
+ }
+ null_count = kUnknownNullCount;
+ }
return ArrayData::Make(float64(), num_groups_,
{std::move(null_bitmap), std::move(values)}, null_count);
@@ -1497,7 +1536,9 @@ struct GroupedVarStdImpl : public GroupedAggregator {
VarianceOptions options_;
int64_t num_groups_ = 0;
// m2 = count * s2 = sum((X-mean)^2)
- BufferBuilder counts_, means_, m2s_;
+ TypedBufferBuilder counts_;
+ TypedBufferBuilder means_, m2s_;
+ TypedBufferBuilder no_nulls_;
ExecContext* ctx_;
MemoryPool* pool_;
};
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 8fe027490be..32e8efa0ab8 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1183,6 +1183,107 @@ TEST(GroupBy, StddevVarianceTDigestScalar) {
}
}
+TEST(GroupBy, VarianceOptions) {
+ BatchesWithSchema input;
+ input.batches = {
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[1, 1.0, 1], [1, 1.0, 1], [1, 1.0, 2], [1, 1.0, 2], [1, 1.0, 3]]"),
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[1, 1.0, 4], [1, 1.0, 4]]"),
+ ExecBatchFromJSON(
+ {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()},
+ "[[null, null, 1]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()}, "[[2, 2.0, 1], [3, 3.0, 2]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()}, "[[4, 4.0, 2], [2, 2.0, 4]]"),
+ ExecBatchFromJSON({int32(), float32(), int64()}, "[[null, null, 4]]"),
+ };
+ input.schema = schema(
+ {field("argument", int32()), field("argument1", float32()), field("key", int64())});
+
+ VarianceOptions keep_nulls(/*ddof=*/0, /*skip_nulls=*/false, /*min_count=*/0);
+ VarianceOptions min_count(/*ddof=*/0, /*skip_nulls=*/true, /*min_count=*/3);
+ VarianceOptions keep_nulls_min_count(/*ddof=*/0, /*skip_nulls=*/false, /*min_count=*/3);
+
+ for (bool use_threads : {false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+ ASSERT_OK_AND_ASSIGN(
+ Datum actual, GroupByUsingExecPlan(input, {"key"},
+ {
+ "argument",
+ "argument",
+ "argument",
+ "argument",
+ "argument",
+ "argument",
+ },
+ {
+ {"hash_stddev", &keep_nulls},
+ {"hash_stddev", &min_count},
+ {"hash_stddev", &keep_nulls_min_count},
+ {"hash_variance", &keep_nulls},
+ {"hash_variance", &min_count},
+ {"hash_variance", &keep_nulls_min_count},
+ },
+ use_threads, default_exec_context()));
+ Datum expected = ArrayFromJSON(struct_({
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("key", int64()),
+ }),
+ R"([
+ [null, 0.471405, null, null, 0.222222, null, 1],
+ [1.29904, 1.29904, 1.29904, 1.6875, 1.6875, 1.6875, 2],
+ [0.0, null, null, 0.0, null, null, 3],
+ [null, 0.471405, null, null, 0.222222, null, 4]
+ ])");
+ ValidateOutput(expected);
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(
+ actual, GroupByUsingExecPlan(input, {"key"},
+ {
+ "argument1",
+ "argument1",
+ "argument1",
+ "argument1",
+ "argument1",
+ "argument1",
+ },
+ {
+ {"hash_stddev", &keep_nulls},
+ {"hash_stddev", &min_count},
+ {"hash_stddev", &keep_nulls_min_count},
+ {"hash_variance", &keep_nulls},
+ {"hash_variance", &min_count},
+ {"hash_variance", &keep_nulls_min_count},
+ },
+ use_threads, default_exec_context()));
+ expected = ArrayFromJSON(struct_({
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_stddev", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("hash_variance", float64()),
+ field("key", int64()),
+ }),
+ R"([
+ [null, 0.471405, null, null, 0.222222, null, 1],
+ [1.29904, 1.29904, 1.29904, 1.6875, 1.6875, 1.6875, 2],
+ [0.0, null, null, 0.0, null, null, 3],
+ [null, 0.471405, null, null, 0.222222, null, 4]
+ ])");
+ ValidateOutput(expected);
+ AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+ }
+}
+
TEST(GroupBy, MinMaxOnly) {
for (bool use_exec_plan : {false, true}) {
for (bool use_threads : {true, false}) {
From 774c31bdef28b33a26166b8b1e27c73532e5d6b1 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 25 Aug 2021 15:26:38 -0400
Subject: [PATCH 2/3] ARROW-13691: [Python][R] Update VarianceOptions bindings
---
.../arrow/compute/kernels/aggregate_test.cc | 26 +++++++++++++++++++
.../compute/kernels/aggregate_var_std.cc | 21 ++++++++++-----
python/pyarrow/_compute.pyx | 8 +++---
python/pyarrow/includes/libarrow.pxd | 4 ++-
r/src/compute.cpp | 10 ++++++-
5 files changed, 57 insertions(+), 12 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index b93e33b05e1..bf6aaf8fd13 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -2264,6 +2264,16 @@ class TestPrimitiveVarStdKernel : public ::testing::Test {
AssertVarStdIsInvalid(array, options);
}
+ void AssertVarStdIsNull(const std::string& json, const VarianceOptions& options) {
+ auto array = ArrayFromJSON(type_singleton(), json);
+ ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options));
+ ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options));
+ auto var = checked_cast(out_var.scalar().get());
+ auto std = checked_cast(out_std.scalar().get());
+ ASSERT_FALSE(var->is_valid);
+ ASSERT_FALSE(std->is_valid);
+ }
+
std::shared_ptr type_singleton() { return Traits::type_singleton(); }
private:
@@ -2336,6 +2346,22 @@ TYPED_TEST(TestNumericVarStdKernel, Basics) {
ResultWith(Datum(MakeNullScalar(float64()))));
EXPECT_THAT(Stddev(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64()))));
EXPECT_THAT(Variance(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64()))));
+
+ // skip_nulls and min_count
+ options.ddof = 0;
+ options.min_count = 3;
+ this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
+ this->AssertVarStdIsNull("[1, 2, null]", options);
+
+ options.min_count = 0;
+ options.skip_nulls = false;
+ this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
+ this->AssertVarStdIsNull("[1, 2, 3, null]", options);
+
+ options.min_count = 4;
+ options.skip_nulls = false;
+ this->AssertVarStdIsNull("[1, 2, 3]", options);
+ this->AssertVarStdIsNull("[1, 2, 3, null]", options);
}
// Test numerical stability
diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
index 4fcea3e8e3a..353763035dc 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
@@ -39,13 +39,16 @@ struct VarStdState {
using CType = typename ArrowType::c_type;
using ThisType = VarStdState;
+ explicit VarStdState(VarianceOptions options) : options(options) {}
+
// float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm`
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
template
enable_if_t::value || (sizeof(CType) > 4)> Consume(
const ArrayType& array) {
+ this->valid = array.null_count() == 0;
int64_t count = array.length() - array.null_count();
- if (count == 0) {
+ if (count == 0 || (!this->valid && !options.skip_nulls)) {
return;
}
@@ -75,6 +78,8 @@ struct VarStdState {
// for int32: -2^62 <= sum < 2^62
constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
+ this->valid = array.null_count() == 0;
+ if (!this->valid && !options.skip_nulls) return;
int64_t start_index = 0;
int64_t valid_count = array.length() - array.null_count();
@@ -98,7 +103,7 @@ struct VarStdState {
});
// merge variance
- ThisType state;
+ ThisType state(options);
state.count = var_std.count;
state.mean = var_std.mean();
state.m2 = var_std.m2();
@@ -116,12 +121,14 @@ struct VarStdState {
} else {
this->count = 0;
this->mean = 0;
+ this->valid = false;
}
}
// Combine `m2` from two chunks (m2 = n*s2)
// https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html
void MergeFrom(const ThisType& state) {
+ this->valid = this->valid && state.valid;
if (state.count == 0) {
return;
}
@@ -135,9 +142,11 @@ struct VarStdState {
&this->mean, &this->m2);
}
+ VarianceOptions options;
int64_t count = 0;
double mean = 0;
double m2 = 0; // m2 = count*s2 = sum((X-mean)^2)
+ bool valid = true;
};
template
@@ -147,7 +156,7 @@ struct VarStdImpl : public ScalarAggregator {
explicit VarStdImpl(const std::shared_ptr& out_type,
const VarianceOptions& options, VarOrStd return_type)
- : out_type(out_type), options(options), return_type(return_type) {}
+ : out_type(out_type), state(options), return_type(return_type) {}
Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
@@ -166,10 +175,11 @@ struct VarStdImpl : public ScalarAggregator {
}
Status Finalize(KernelContext*, Datum* out) override {
- if (this->state.count <= options.ddof) {
+ if (state.count <= state.options.ddof || state.count < state.options.min_count ||
+ (!state.valid && !state.options.skip_nulls)) {
out->value = std::make_shared();
} else {
- double var = this->state.m2 / (this->state.count - options.ddof);
+ double var = state.m2 / (state.count - state.options.ddof);
out->value =
std::make_shared(return_type == VarOrStd::Var ? var : sqrt(var));
}
@@ -178,7 +188,6 @@ struct VarStdImpl : public ScalarAggregator {
std::shared_ptr out_type;
VarStdState state;
- VarianceOptions options;
VarOrStd return_type;
};
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 99ad14496ca..39bb5315f7a 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1018,13 +1018,13 @@ class NullOptions(_NullOptions):
cdef class _VarianceOptions(FunctionOptions):
- def _set_options(self, ddof):
- self.wrapped.reset(new CVarianceOptions(ddof))
+ def _set_options(self, ddof, skip_nulls, min_count):
+ self.wrapped.reset(new CVarianceOptions(ddof, skip_nulls, min_count))
class VarianceOptions(_VarianceOptions):
- def __init__(self, *, ddof=0):
- self._set_options(ddof)
+ def __init__(self, *, ddof=0, skip_nulls=True, min_count=0):
+ self._set_options(ddof, skip_nulls, min_count)
cdef class _SplitOptions(FunctionOptions):
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 0271770214b..4f9f4184b2d 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1970,8 +1970,10 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
cdef cppclass CVarianceOptions \
"arrow::compute::VarianceOptions"(CFunctionOptions):
- CVarianceOptions(int ddof)
+ CVarianceOptions(int ddof, c_bool skip_nulls, uint32_t min_count)
int ddof
+ c_bool skip_nulls
+ uint32_t min_count
cdef cppclass CScalarAggregateOptions \
"arrow::compute::ScalarAggregateOptions"(CFunctionOptions):
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 5468fa83113..48c54187e70 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -366,7 +366,15 @@ std::shared_ptr make_compute_options(
if (func_name == "variance" || func_name == "stddev" || func_name == "hash_variance" ||
func_name == "hash_stddev") {
using Options = arrow::compute::VarianceOptions;
- return std::make_shared(cpp11::as_cpp(options["ddof"]));
+ auto out = std::make_shared();
+ out->ddof = cpp11::as_cpp(options["ddof"]);
+ if (!Rf_isNull(options["na.min_count"])) {
+ out->min_count = cpp11::as_cpp(options["na.min_count"]);
+ }
+ if (!Rf_isNull(options["na.rm"])) {
+ out->skip_nulls = cpp11::as_cpp(options["na.rm"]);
+ }
+ return out;
}
return nullptr;
From de982fa0757db90b25413873277a5677355f0bf8 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 31 Aug 2021 12:12:16 -0400
Subject: [PATCH 3/3] ARROW-13691: [C++] Address review feedback
---
.../arrow/compute/kernels/aggregate_test.cc | 18 ++++--------------
.../arrow/compute/kernels/aggregate_var_std.cc | 18 +++++++++---------
2 files changed, 13 insertions(+), 23 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index bf6aaf8fd13..eb73e703b6e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -2264,16 +2264,6 @@ class TestPrimitiveVarStdKernel : public ::testing::Test {
AssertVarStdIsInvalid(array, options);
}
- void AssertVarStdIsNull(const std::string& json, const VarianceOptions& options) {
- auto array = ArrayFromJSON(type_singleton(), json);
- ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options));
- ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options));
- auto var = checked_cast(out_var.scalar().get());
- auto std = checked_cast(out_std.scalar().get());
- ASSERT_FALSE(var->is_valid);
- ASSERT_FALSE(std->is_valid);
- }
-
std::shared_ptr type_singleton() { return Traits::type_singleton(); }
private:
@@ -2351,17 +2341,17 @@ TYPED_TEST(TestNumericVarStdKernel, Basics) {
options.ddof = 0;
options.min_count = 3;
this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
- this->AssertVarStdIsNull("[1, 2, null]", options);
+ this->AssertVarStdIsInvalid("[1, 2, null]", options);
options.min_count = 0;
options.skip_nulls = false;
this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
- this->AssertVarStdIsNull("[1, 2, 3, null]", options);
+ this->AssertVarStdIsInvalid("[1, 2, 3, null]", options);
options.min_count = 4;
options.skip_nulls = false;
- this->AssertVarStdIsNull("[1, 2, 3]", options);
- this->AssertVarStdIsNull("[1, 2, 3, null]", options);
+ this->AssertVarStdIsInvalid("[1, 2, 3]", options);
+ this->AssertVarStdIsInvalid("[1, 2, 3, null]", options);
}
// Test numerical stability
diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
index 353763035dc..42ac655877c 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
@@ -46,9 +46,9 @@ struct VarStdState {
template
enable_if_t::value || (sizeof(CType) > 4)> Consume(
const ArrayType& array) {
- this->valid = array.null_count() == 0;
+ this->all_valid = array.null_count() == 0;
int64_t count = array.length() - array.null_count();
- if (count == 0 || (!this->valid && !options.skip_nulls)) {
+ if (count == 0 || (!this->all_valid && !options.skip_nulls)) {
return;
}
@@ -78,8 +78,8 @@ struct VarStdState {
// for int32: -2^62 <= sum < 2^62
constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
- this->valid = array.null_count() == 0;
- if (!this->valid && !options.skip_nulls) return;
+ this->all_valid = array.null_count() == 0;
+ if (!this->all_valid && !options.skip_nulls) return;
int64_t start_index = 0;
int64_t valid_count = array.length() - array.null_count();
@@ -121,14 +121,14 @@ struct VarStdState {
} else {
this->count = 0;
this->mean = 0;
- this->valid = false;
+ this->all_valid = false;
}
}
// Combine `m2` from two chunks (m2 = n*s2)
// https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html
void MergeFrom(const ThisType& state) {
- this->valid = this->valid && state.valid;
+ this->all_valid = this->all_valid && state.all_valid;
if (state.count == 0) {
return;
}
@@ -142,11 +142,11 @@ struct VarStdState {
&this->mean, &this->m2);
}
- VarianceOptions options;
+ const VarianceOptions options;
int64_t count = 0;
double mean = 0;
double m2 = 0; // m2 = count*s2 = sum((X-mean)^2)
- bool valid = true;
+ bool all_valid = true;
};
template
@@ -176,7 +176,7 @@ struct VarStdImpl : public ScalarAggregator {
Status Finalize(KernelContext*, Datum* out) override {
if (state.count <= state.options.ddof || state.count < state.options.min_count ||
- (!state.valid && !state.options.skip_nulls)) {
+ (!state.all_valid && !state.options.skip_nulls)) {
out->value = std::make_shared();
} else {
double var = state.m2 / (state.count - state.options.ddof);