From 375a26a6606ea8825435fca5d8e86c5156fef7e6 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 25 Aug 2021 15:11:15 -0400 Subject: [PATCH 1/3] ARROW-13691: [C++] Support min_count/skip_nulls in VarianceOptions --- cpp/src/arrow/compute/api_aggregate.cc | 13 +- cpp/src/arrow/compute/api_aggregate.h | 7 +- .../arrow/compute/kernels/hash_aggregate.cc | 119 ++++++++++++------ .../compute/kernels/hash_aggregate_test.cc | 101 +++++++++++++++ 4 files changed, 196 insertions(+), 44 deletions(-) diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc index af7aec865fc..6d7bdfa6cf9 100644 --- a/cpp/src/arrow/compute/api_aggregate.cc +++ b/cpp/src/arrow/compute/api_aggregate.cc @@ -87,8 +87,10 @@ static auto kCountOptionsType = GetFunctionOptionsType(DataMember("mode", &CountOptions::mode)); static auto kModeOptionsType = GetFunctionOptionsType(DataMember("n", &ModeOptions::n)); -static auto kVarianceOptionsType = - GetFunctionOptionsType(DataMember("ddof", &VarianceOptions::ddof)); +static auto kVarianceOptionsType = GetFunctionOptionsType( + DataMember("ddof", &VarianceOptions::ddof), + DataMember("skip_nulls", &VarianceOptions::skip_nulls), + DataMember("min_count", &VarianceOptions::min_count)); static auto kQuantileOptionsType = GetFunctionOptionsType( DataMember("q", &QuantileOptions::q), DataMember("interpolation", &QuantileOptions::interpolation)); @@ -113,8 +115,11 @@ constexpr char CountOptions::kTypeName[]; ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {} constexpr char ModeOptions::kTypeName[]; -VarianceOptions::VarianceOptions(int ddof) - : FunctionOptions(internal::kVarianceOptionsType), ddof(ddof) {} +VarianceOptions::VarianceOptions(int ddof, bool skip_nulls, uint32_t min_count) + : FunctionOptions(internal::kVarianceOptionsType), + ddof(ddof), + skip_nulls(skip_nulls), + min_count(min_count) {} constexpr char VarianceOptions::kTypeName[]; QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation) diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index d8cda022de8..8c27da49765 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -95,11 +95,16 @@ class ARROW_EXPORT ModeOptions : public FunctionOptions { /// By default, ddof is zero, and population variance or stddev is returned. class ARROW_EXPORT VarianceOptions : public FunctionOptions { public: - explicit VarianceOptions(int ddof = 0); + explicit VarianceOptions(int ddof = 0, bool skip_nulls = true, uint32_t min_count = 0); constexpr static char const kTypeName[] = "VarianceOptions"; static VarianceOptions Defaults() { return VarianceOptions{}; } int ddof = 0; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; }; /// \brief Control Quantile kernel behavior diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index f75009f0077..3ea692857cf 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -1304,18 +1304,20 @@ struct GroupedVarStdImpl : public GroupedAggregator { options_ = *checked_cast(options); ctx_ = ctx; pool_ = ctx->memory_pool(); - counts_ = BufferBuilder(pool_); - means_ = BufferBuilder(pool_); - m2s_ = BufferBuilder(pool_); + counts_ = TypedBufferBuilder(pool_); + means_ = TypedBufferBuilder(pool_); + m2s_ = TypedBufferBuilder(pool_); + no_nulls_ = TypedBufferBuilder(pool_); return Status::OK(); } Status Resize(int64_t new_num_groups) override { auto added_groups = new_num_groups - num_groups_; num_groups_ = new_num_groups; - RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0)); - RETURN_NOT_OK(means_.Append(added_groups * sizeof(double), 0)); - RETURN_NOT_OK(m2s_.Append(added_groups * sizeof(double), 0)); + RETURN_NOT_OK(counts_.Append(added_groups, 0)); + RETURN_NOT_OK(means_.Append(added_groups, 0)); + RETURN_NOT_OK(m2s_.Append(added_groups, 0)); + RETURN_NOT_OK(no_nulls_.Append(added_groups, true)); return Status::OK(); } @@ -1332,18 +1334,21 @@ struct GroupedVarStdImpl : public GroupedAggregator { GroupedVarStdImpl state; RETURN_NOT_OK(state.Init(ctx_, &options_)); RETURN_NOT_OK(state.Resize(num_groups_)); - int64_t* counts = reinterpret_cast(state.counts_.mutable_data()); - double* means = reinterpret_cast(state.means_.mutable_data()); - double* m2s = reinterpret_cast(state.m2s_.mutable_data()); + int64_t* counts = state.counts_.mutable_data(); + double* means = state.means_.mutable_data(); + double* m2s = state.m2s_.mutable_data(); + uint8_t* no_nulls = state.no_nulls_.mutable_data(); // XXX this uses naive summation; we should switch to pairwise summation as was // done for the scalar aggregate kernel in ARROW-11567 std::vector sums(num_groups_); - VisitGroupedValuesNonNull( - batch, [&](uint32_t g, typename TypeTraits::CType value) { + VisitGroupedValues( + batch, + [&](uint32_t g, typename TypeTraits::CType value) { sums[g] += value; counts[g]++; - }); + }, + [&](uint32_t g) { BitUtil::ClearBit(no_nulls, g); }); for (int64_t i = 0; i < num_groups_; i++) { means[i] = static_cast(sums[i]) / counts[i]; @@ -1362,9 +1367,7 @@ struct GroupedVarStdImpl : public GroupedAggregator { } ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)}, /*null_count=*/0); - RETURN_NOT_OK(this->Merge(std::move(state), group_id_mapping)); - - return Status::OK(); + return this->Merge(std::move(state), group_id_mapping); } // int32/16/8: textbook one pass algorithm with integer arithmetic (see @@ -1377,12 +1380,15 @@ struct GroupedVarStdImpl : public GroupedAggregator { // for int32: -2^62 <= sum < 2^62 constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8); + const auto g = batch[1].array()->GetValues(1); if (batch[0].is_scalar() && !batch[0].scalar()->is_valid) { + uint8_t* no_nulls = no_nulls_.mutable_data(); + for (int64_t i = 0; i < batch.length; i++) { + BitUtil::ClearBit(no_nulls, g[i]); + } return Status::OK(); } - const auto g = batch[1].array()->GetValues(1); - std::vector> var_std(num_groups_); ARROW_ASSIGN_OR_RAISE(auto mapping, @@ -1402,23 +1408,42 @@ struct GroupedVarStdImpl : public GroupedAggregator { GroupedVarStdImpl state; RETURN_NOT_OK(state.Init(ctx_, &options_)); RETURN_NOT_OK(state.Resize(num_groups_)); - int64_t* other_counts = reinterpret_cast(state.counts_.mutable_data()); - double* other_means = reinterpret_cast(state.means_.mutable_data()); - double* other_m2s = reinterpret_cast(state.m2s_.mutable_data()); + int64_t* other_counts = state.counts_.mutable_data(); + double* other_means = state.means_.mutable_data(); + double* other_m2s = state.m2s_.mutable_data(); + uint8_t* other_no_nulls = state.no_nulls_.mutable_data(); if (batch[0].is_array()) { const auto& array = *batch[0].array(); const CType* values = array.GetValues(1); - arrow::internal::VisitSetBitRunsVoid( - array.buffers[0], array.offset + start_index, - std::min(max_length, batch.length - start_index), - [&](int64_t pos, int64_t len) { - for (int64_t i = 0; i < len; ++i) { - const int64_t index = start_index + pos + i; - const auto value = values[index]; - var_std[g[index]].ConsumeOne(value); + auto visit_values = [&](int64_t pos, int64_t len) { + for (int64_t i = 0; i < len; ++i) { + const int64_t index = start_index + pos + i; + const auto value = values[index]; + var_std[g[index]].ConsumeOne(value); + } + }; + + if (array.MayHaveNulls()) { + arrow::internal::BitRunReader reader( + array.buffers[0]->data(), array.offset + start_index, + std::min(max_length, batch.length - start_index)); + int64_t position = 0; + while (true) { + auto run = reader.NextRun(); + if (run.length == 0) break; + if (run.set) { + visit_values(position, run.length); + } else { + for (int64_t i = 0; i < run.length; ++i) { + BitUtil::ClearBit(other_no_nulls, g[start_index + position + i]); } - }); + } + position += run.length; + } + } else { + visit_values(0, array.length); + } } else { const auto value = UnboxScalar::Unbox(*batch[0].scalar()); for (int64_t i = 0; i < std::min(max_length, batch.length - start_index); ++i) { @@ -1444,16 +1469,21 @@ struct GroupedVarStdImpl : public GroupedAggregator { // Combine m2 from two chunks (see aggregate_var_std.cc) auto other = checked_cast(&raw_other); - auto counts = reinterpret_cast(counts_.mutable_data()); - auto means = reinterpret_cast(means_.mutable_data()); - auto m2s = reinterpret_cast(m2s_.mutable_data()); + int64_t* counts = counts_.mutable_data(); + double* means = means_.mutable_data(); + double* m2s = m2s_.mutable_data(); + uint8_t* no_nulls = no_nulls_.mutable_data(); - const auto* other_counts = reinterpret_cast(other->counts_.data()); - const auto* other_means = reinterpret_cast(other->means_.data()); - const auto* other_m2s = reinterpret_cast(other->m2s_.data()); + const int64_t* other_counts = other->counts_.data(); + const double* other_means = other->means_.data(); + const double* other_m2s = other->m2s_.data(); + const uint8_t* other_no_nulls = other->no_nulls_.data(); auto g = group_id_mapping.GetValues(1); for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) { + if (!BitUtil::GetBit(other_no_nulls, other_g)) { + BitUtil::ClearBit(no_nulls, *g); + } if (other_counts[other_g] == 0) continue; MergeVarStd(counts[*g], means[*g], other_counts[other_g], other_means[other_g], other_m2s[other_g], &counts[*g], &means[*g], &m2s[*g]); @@ -1468,10 +1498,10 @@ struct GroupedVarStdImpl : public GroupedAggregator { int64_t null_count = 0; double* results = reinterpret_cast(values->mutable_data()); - const int64_t* counts = reinterpret_cast(counts_.data()); - const double* m2s = reinterpret_cast(m2s_.data()); + const int64_t* counts = counts_.data(); + const double* m2s = m2s_.data(); for (int64_t i = 0; i < num_groups_; ++i) { - if (counts[i] > options_.ddof) { + if (counts[i] > options_.ddof && counts[i] >= options_.min_count) { const double variance = m2s[i] / (counts[i] - options_.ddof); results[i] = result_type_ == VarOrStd::Var ? variance : std::sqrt(variance); continue; @@ -1486,6 +1516,15 @@ struct GroupedVarStdImpl : public GroupedAggregator { null_count += 1; BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false); } + if (!options_.skip_nulls) { + if (null_bitmap) { + arrow::internal::BitmapAnd(null_bitmap->data(), 0, no_nulls_.data(), 0, + num_groups_, 0, null_bitmap->mutable_data()); + } else { + ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish()); + } + null_count = kUnknownNullCount; + } return ArrayData::Make(float64(), num_groups_, {std::move(null_bitmap), std::move(values)}, null_count); @@ -1497,7 +1536,9 @@ struct GroupedVarStdImpl : public GroupedAggregator { VarianceOptions options_; int64_t num_groups_ = 0; // m2 = count * s2 = sum((X-mean)^2) - BufferBuilder counts_, means_, m2s_; + TypedBufferBuilder counts_; + TypedBufferBuilder means_, m2s_; + TypedBufferBuilder no_nulls_; ExecContext* ctx_; MemoryPool* pool_; }; diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 8fe027490be..32e8efa0ab8 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -1183,6 +1183,107 @@ TEST(GroupBy, StddevVarianceTDigestScalar) { } } +TEST(GroupBy, VarianceOptions) { + BatchesWithSchema input; + input.batches = { + ExecBatchFromJSON( + {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()}, + "[[1, 1.0, 1], [1, 1.0, 1], [1, 1.0, 2], [1, 1.0, 2], [1, 1.0, 3]]"), + ExecBatchFromJSON( + {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()}, + "[[1, 1.0, 4], [1, 1.0, 4]]"), + ExecBatchFromJSON( + {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()}, + "[[null, null, 1]]"), + ExecBatchFromJSON({int32(), float32(), int64()}, "[[2, 2.0, 1], [3, 3.0, 2]]"), + ExecBatchFromJSON({int32(), float32(), int64()}, "[[4, 4.0, 2], [2, 2.0, 4]]"), + ExecBatchFromJSON({int32(), float32(), int64()}, "[[null, null, 4]]"), + }; + input.schema = schema( + {field("argument", int32()), field("argument1", float32()), field("key", int64())}); + + VarianceOptions keep_nulls(/*ddof=*/0, /*skip_nulls=*/false, /*min_count=*/0); + VarianceOptions min_count(/*ddof=*/0, /*skip_nulls=*/true, /*min_count=*/3); + VarianceOptions keep_nulls_min_count(/*ddof=*/0, /*skip_nulls=*/false, /*min_count=*/3); + + for (bool use_threads : {false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN( + Datum actual, GroupByUsingExecPlan(input, {"key"}, + { + "argument", + "argument", + "argument", + "argument", + "argument", + "argument", + }, + { + {"hash_stddev", &keep_nulls}, + {"hash_stddev", &min_count}, + {"hash_stddev", &keep_nulls_min_count}, + {"hash_variance", &keep_nulls}, + {"hash_variance", &min_count}, + {"hash_variance", &keep_nulls_min_count}, + }, + use_threads, default_exec_context())); + Datum expected = ArrayFromJSON(struct_({ + field("hash_stddev", float64()), + field("hash_stddev", float64()), + field("hash_stddev", float64()), + field("hash_variance", float64()), + field("hash_variance", float64()), + field("hash_variance", float64()), + field("key", int64()), + }), + R"([ + [null, 0.471405, null, null, 0.222222, null, 1], + [1.29904, 1.29904, 1.29904, 1.6875, 1.6875, 1.6875, 2], + [0.0, null, null, 0.0, null, null, 3], + [null, 0.471405, null, null, 0.222222, null, 4] + ])"); + ValidateOutput(expected); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + + ASSERT_OK_AND_ASSIGN( + actual, GroupByUsingExecPlan(input, {"key"}, + { + "argument1", + "argument1", + "argument1", + "argument1", + "argument1", + "argument1", + }, + { + {"hash_stddev", &keep_nulls}, + {"hash_stddev", &min_count}, + {"hash_stddev", &keep_nulls_min_count}, + {"hash_variance", &keep_nulls}, + {"hash_variance", &min_count}, + {"hash_variance", &keep_nulls_min_count}, + }, + use_threads, default_exec_context())); + expected = ArrayFromJSON(struct_({ + field("hash_stddev", float64()), + field("hash_stddev", float64()), + field("hash_stddev", float64()), + field("hash_variance", float64()), + field("hash_variance", float64()), + field("hash_variance", float64()), + field("key", int64()), + }), + R"([ + [null, 0.471405, null, null, 0.222222, null, 1], + [1.29904, 1.29904, 1.29904, 1.6875, 1.6875, 1.6875, 2], + [0.0, null, null, 0.0, null, null, 3], + [null, 0.471405, null, null, 0.222222, null, 4] + ])"); + ValidateOutput(expected); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST(GroupBy, MinMaxOnly) { for (bool use_exec_plan : {false, true}) { for (bool use_threads : {true, false}) { From 774c31bdef28b33a26166b8b1e27c73532e5d6b1 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 25 Aug 2021 15:26:38 -0400 Subject: [PATCH 2/3] ARROW-13691: [Python][R] Update VarianceOptions bindings --- .../arrow/compute/kernels/aggregate_test.cc | 26 +++++++++++++++++++ .../compute/kernels/aggregate_var_std.cc | 21 ++++++++++----- python/pyarrow/_compute.pyx | 8 +++--- python/pyarrow/includes/libarrow.pxd | 4 ++- r/src/compute.cpp | 10 ++++++- 5 files changed, 57 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index b93e33b05e1..bf6aaf8fd13 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -2264,6 +2264,16 @@ class TestPrimitiveVarStdKernel : public ::testing::Test { AssertVarStdIsInvalid(array, options); } + void AssertVarStdIsNull(const std::string& json, const VarianceOptions& options) { + auto array = ArrayFromJSON(type_singleton(), json); + ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options)); + ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options)); + auto var = checked_cast(out_var.scalar().get()); + auto std = checked_cast(out_std.scalar().get()); + ASSERT_FALSE(var->is_valid); + ASSERT_FALSE(std->is_valid); + } + std::shared_ptr type_singleton() { return Traits::type_singleton(); } private: @@ -2336,6 +2346,22 @@ TYPED_TEST(TestNumericVarStdKernel, Basics) { ResultWith(Datum(MakeNullScalar(float64())))); EXPECT_THAT(Stddev(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64())))); EXPECT_THAT(Variance(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64())))); + + // skip_nulls and min_count + options.ddof = 0; + options.min_count = 3; + this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666); + this->AssertVarStdIsNull("[1, 2, null]", options); + + options.min_count = 0; + options.skip_nulls = false; + this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666); + this->AssertVarStdIsNull("[1, 2, 3, null]", options); + + options.min_count = 4; + options.skip_nulls = false; + this->AssertVarStdIsNull("[1, 2, 3]", options); + this->AssertVarStdIsNull("[1, 2, 3, null]", options); } // Test numerical stability diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index 4fcea3e8e3a..353763035dc 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -39,13 +39,16 @@ struct VarStdState { using CType = typename ArrowType::c_type; using ThisType = VarStdState; + explicit VarStdState(VarianceOptions options) : options(options) {} + // float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm` // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm template enable_if_t::value || (sizeof(CType) > 4)> Consume( const ArrayType& array) { + this->valid = array.null_count() == 0; int64_t count = array.length() - array.null_count(); - if (count == 0) { + if (count == 0 || (!this->valid && !options.skip_nulls)) { return; } @@ -75,6 +78,8 @@ struct VarStdState { // for int32: -2^62 <= sum < 2^62 constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8); + this->valid = array.null_count() == 0; + if (!this->valid && !options.skip_nulls) return; int64_t start_index = 0; int64_t valid_count = array.length() - array.null_count(); @@ -98,7 +103,7 @@ struct VarStdState { }); // merge variance - ThisType state; + ThisType state(options); state.count = var_std.count; state.mean = var_std.mean(); state.m2 = var_std.m2(); @@ -116,12 +121,14 @@ struct VarStdState { } else { this->count = 0; this->mean = 0; + this->valid = false; } } // Combine `m2` from two chunks (m2 = n*s2) // https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html void MergeFrom(const ThisType& state) { + this->valid = this->valid && state.valid; if (state.count == 0) { return; } @@ -135,9 +142,11 @@ struct VarStdState { &this->mean, &this->m2); } + VarianceOptions options; int64_t count = 0; double mean = 0; double m2 = 0; // m2 = count*s2 = sum((X-mean)^2) + bool valid = true; }; template @@ -147,7 +156,7 @@ struct VarStdImpl : public ScalarAggregator { explicit VarStdImpl(const std::shared_ptr& out_type, const VarianceOptions& options, VarOrStd return_type) - : out_type(out_type), options(options), return_type(return_type) {} + : out_type(out_type), state(options), return_type(return_type) {} Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { @@ -166,10 +175,11 @@ struct VarStdImpl : public ScalarAggregator { } Status Finalize(KernelContext*, Datum* out) override { - if (this->state.count <= options.ddof) { + if (state.count <= state.options.ddof || state.count < state.options.min_count || + (!state.valid && !state.options.skip_nulls)) { out->value = std::make_shared(); } else { - double var = this->state.m2 / (this->state.count - options.ddof); + double var = state.m2 / (state.count - state.options.ddof); out->value = std::make_shared(return_type == VarOrStd::Var ? var : sqrt(var)); } @@ -178,7 +188,6 @@ struct VarStdImpl : public ScalarAggregator { std::shared_ptr out_type; VarStdState state; - VarianceOptions options; VarOrStd return_type; }; diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 99ad14496ca..39bb5315f7a 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1018,13 +1018,13 @@ class NullOptions(_NullOptions): cdef class _VarianceOptions(FunctionOptions): - def _set_options(self, ddof): - self.wrapped.reset(new CVarianceOptions(ddof)) + def _set_options(self, ddof, skip_nulls, min_count): + self.wrapped.reset(new CVarianceOptions(ddof, skip_nulls, min_count)) class VarianceOptions(_VarianceOptions): - def __init__(self, *, ddof=0): - self._set_options(ddof) + def __init__(self, *, ddof=0, skip_nulls=True, min_count=0): + self._set_options(ddof, skip_nulls, min_count) cdef class _SplitOptions(FunctionOptions): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0271770214b..4f9f4184b2d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1970,8 +1970,10 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CVarianceOptions \ "arrow::compute::VarianceOptions"(CFunctionOptions): - CVarianceOptions(int ddof) + CVarianceOptions(int ddof, c_bool skip_nulls, uint32_t min_count) int ddof + c_bool skip_nulls + uint32_t min_count cdef cppclass CScalarAggregateOptions \ "arrow::compute::ScalarAggregateOptions"(CFunctionOptions): diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 5468fa83113..48c54187e70 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -366,7 +366,15 @@ std::shared_ptr make_compute_options( if (func_name == "variance" || func_name == "stddev" || func_name == "hash_variance" || func_name == "hash_stddev") { using Options = arrow::compute::VarianceOptions; - return std::make_shared(cpp11::as_cpp(options["ddof"])); + auto out = std::make_shared(); + out->ddof = cpp11::as_cpp(options["ddof"]); + if (!Rf_isNull(options["na.min_count"])) { + out->min_count = cpp11::as_cpp(options["na.min_count"]); + } + if (!Rf_isNull(options["na.rm"])) { + out->skip_nulls = cpp11::as_cpp(options["na.rm"]); + } + return out; } return nullptr; From de982fa0757db90b25413873277a5677355f0bf8 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 31 Aug 2021 12:12:16 -0400 Subject: [PATCH 3/3] ARROW-13691: [C++] Address review feedback --- .../arrow/compute/kernels/aggregate_test.cc | 18 ++++-------------- .../arrow/compute/kernels/aggregate_var_std.cc | 18 +++++++++--------- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index bf6aaf8fd13..eb73e703b6e 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -2264,16 +2264,6 @@ class TestPrimitiveVarStdKernel : public ::testing::Test { AssertVarStdIsInvalid(array, options); } - void AssertVarStdIsNull(const std::string& json, const VarianceOptions& options) { - auto array = ArrayFromJSON(type_singleton(), json); - ASSERT_OK_AND_ASSIGN(Datum out_var, Variance(array, options)); - ASSERT_OK_AND_ASSIGN(Datum out_std, Stddev(array, options)); - auto var = checked_cast(out_var.scalar().get()); - auto std = checked_cast(out_std.scalar().get()); - ASSERT_FALSE(var->is_valid); - ASSERT_FALSE(std->is_valid); - } - std::shared_ptr type_singleton() { return Traits::type_singleton(); } private: @@ -2351,17 +2341,17 @@ TYPED_TEST(TestNumericVarStdKernel, Basics) { options.ddof = 0; options.min_count = 3; this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666); - this->AssertVarStdIsNull("[1, 2, null]", options); + this->AssertVarStdIsInvalid("[1, 2, null]", options); options.min_count = 0; options.skip_nulls = false; this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666); - this->AssertVarStdIsNull("[1, 2, 3, null]", options); + this->AssertVarStdIsInvalid("[1, 2, 3, null]", options); options.min_count = 4; options.skip_nulls = false; - this->AssertVarStdIsNull("[1, 2, 3]", options); - this->AssertVarStdIsNull("[1, 2, 3, null]", options); + this->AssertVarStdIsInvalid("[1, 2, 3]", options); + this->AssertVarStdIsInvalid("[1, 2, 3, null]", options); } // Test numerical stability diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc index 353763035dc..42ac655877c 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc @@ -46,9 +46,9 @@ struct VarStdState { template enable_if_t::value || (sizeof(CType) > 4)> Consume( const ArrayType& array) { - this->valid = array.null_count() == 0; + this->all_valid = array.null_count() == 0; int64_t count = array.length() - array.null_count(); - if (count == 0 || (!this->valid && !options.skip_nulls)) { + if (count == 0 || (!this->all_valid && !options.skip_nulls)) { return; } @@ -78,8 +78,8 @@ struct VarStdState { // for int32: -2^62 <= sum < 2^62 constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8); - this->valid = array.null_count() == 0; - if (!this->valid && !options.skip_nulls) return; + this->all_valid = array.null_count() == 0; + if (!this->all_valid && !options.skip_nulls) return; int64_t start_index = 0; int64_t valid_count = array.length() - array.null_count(); @@ -121,14 +121,14 @@ struct VarStdState { } else { this->count = 0; this->mean = 0; - this->valid = false; + this->all_valid = false; } } // Combine `m2` from two chunks (m2 = n*s2) // https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html void MergeFrom(const ThisType& state) { - this->valid = this->valid && state.valid; + this->all_valid = this->all_valid && state.all_valid; if (state.count == 0) { return; } @@ -142,11 +142,11 @@ struct VarStdState { &this->mean, &this->m2); } - VarianceOptions options; + const VarianceOptions options; int64_t count = 0; double mean = 0; double m2 = 0; // m2 = count*s2 = sum((X-mean)^2) - bool valid = true; + bool all_valid = true; }; template @@ -176,7 +176,7 @@ struct VarStdImpl : public ScalarAggregator { Status Finalize(KernelContext*, Datum* out) override { if (state.count <= state.options.ddof || state.count < state.options.min_count || - (!state.valid && !state.options.skip_nulls)) { + (!state.all_valid && !state.options.skip_nulls)) { out->value = std::make_shared(); } else { double var = state.m2 / (state.count - state.options.ddof);