Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions cpp/src/arrow/compute/api_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ static auto kCountOptionsType =
GetFunctionOptionsType<CountOptions>(DataMember("mode", &CountOptions::mode));
static auto kModeOptionsType =
GetFunctionOptionsType<ModeOptions>(DataMember("n", &ModeOptions::n));
static auto kVarianceOptionsType =
GetFunctionOptionsType<VarianceOptions>(DataMember("ddof", &VarianceOptions::ddof));
static auto kVarianceOptionsType = GetFunctionOptionsType<VarianceOptions>(
DataMember("ddof", &VarianceOptions::ddof),
DataMember("skip_nulls", &VarianceOptions::skip_nulls),
DataMember("min_count", &VarianceOptions::min_count));
static auto kQuantileOptionsType = GetFunctionOptionsType<QuantileOptions>(
DataMember("q", &QuantileOptions::q),
DataMember("interpolation", &QuantileOptions::interpolation));
Expand All @@ -113,8 +115,11 @@ constexpr char CountOptions::kTypeName[];
ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {}
constexpr char ModeOptions::kTypeName[];

VarianceOptions::VarianceOptions(int ddof)
: FunctionOptions(internal::kVarianceOptionsType), ddof(ddof) {}
VarianceOptions::VarianceOptions(int ddof, bool skip_nulls, uint32_t min_count)
: FunctionOptions(internal::kVarianceOptionsType),
ddof(ddof),
skip_nulls(skip_nulls),
min_count(min_count) {}
constexpr char VarianceOptions::kTypeName[];

QuantileOptions::QuantileOptions(double q, enum Interpolation interpolation)
Expand Down
7 changes: 6 additions & 1 deletion cpp/src/arrow/compute/api_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,16 @@ class ARROW_EXPORT ModeOptions : public FunctionOptions {
/// By default, ddof is zero, and population variance or stddev is returned.
class ARROW_EXPORT VarianceOptions : public FunctionOptions {
public:
explicit VarianceOptions(int ddof = 0);
explicit VarianceOptions(int ddof = 0, bool skip_nulls = true, uint32_t min_count = 0);
constexpr static char const kTypeName[] = "VarianceOptions";
static VarianceOptions Defaults() { return VarianceOptions{}; }

int ddof = 0;
/// If true (the default), null values are ignored. Otherwise, if any value is null,
/// emit null.
bool skip_nulls;
/// If less than this many non-null values are observed, emit null.
uint32_t min_count;
};

/// \brief Control Quantile kernel behavior
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2336,6 +2336,22 @@ TYPED_TEST(TestNumericVarStdKernel, Basics) {
ResultWith(Datum(MakeNullScalar(float64()))));
EXPECT_THAT(Stddev(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64()))));
EXPECT_THAT(Variance(MakeNullScalar(ty)), ResultWith(Datum(MakeNullScalar(float64()))));

// skip_nulls and min_count
options.ddof = 0;
options.min_count = 3;
this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
this->AssertVarStdIsInvalid("[1, 2, null]", options);

options.min_count = 0;
options.skip_nulls = false;
this->AssertVarStdIs("[1, 2, 3]", options, 0.6666666666666666);
this->AssertVarStdIsInvalid("[1, 2, 3, null]", options);

options.min_count = 4;
options.skip_nulls = false;
this->AssertVarStdIsInvalid("[1, 2, 3]", options);
this->AssertVarStdIsInvalid("[1, 2, 3, null]", options);
}

// Test numerical stability
Expand Down
21 changes: 15 additions & 6 deletions cpp/src/arrow/compute/kernels/aggregate_var_std.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ struct VarStdState {
using CType = typename ArrowType::c_type;
using ThisType = VarStdState<ArrowType>;

explicit VarStdState(VarianceOptions options) : options(options) {}

// float/double/int64: calculate `m2` (sum((X-mean)^2)) with `two pass algorithm`
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
template <typename T = ArrowType>
enable_if_t<is_floating_type<T>::value || (sizeof(CType) > 4)> Consume(
const ArrayType& array) {
this->all_valid = array.null_count() == 0;
int64_t count = array.length() - array.null_count();
if (count == 0) {
if (count == 0 || (!this->all_valid && !options.skip_nulls)) {
return;
}

Expand Down Expand Up @@ -75,6 +78,8 @@ struct VarStdState {
// for int32: -2^62 <= sum < 2^62
constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);

this->all_valid = array.null_count() == 0;
if (!this->all_valid && !options.skip_nulls) return;
int64_t start_index = 0;
int64_t valid_count = array.length() - array.null_count();

Expand All @@ -98,7 +103,7 @@ struct VarStdState {
});

// merge variance
ThisType state;
ThisType state(options);
state.count = var_std.count;
state.mean = var_std.mean();
state.m2 = var_std.m2();
Expand All @@ -116,12 +121,14 @@ struct VarStdState {
} else {
this->count = 0;
this->mean = 0;
this->all_valid = false;
}
}

// Combine `m2` from two chunks (m2 = n*s2)
// https://www.emathzone.com/tutorials/basic-statistics/combined-variance.html
void MergeFrom(const ThisType& state) {
this->all_valid = this->all_valid && state.all_valid;
if (state.count == 0) {
return;
}
Expand All @@ -135,9 +142,11 @@ struct VarStdState {
&this->mean, &this->m2);
}

const VarianceOptions options;
int64_t count = 0;
double mean = 0;
double m2 = 0; // m2 = count*s2 = sum((X-mean)^2)
bool all_valid = true;
};

template <typename ArrowType>
Expand All @@ -147,7 +156,7 @@ struct VarStdImpl : public ScalarAggregator {

explicit VarStdImpl(const std::shared_ptr<DataType>& out_type,
const VarianceOptions& options, VarOrStd return_type)
: out_type(out_type), options(options), return_type(return_type) {}
: out_type(out_type), state(options), return_type(return_type) {}

Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
Expand All @@ -166,10 +175,11 @@ struct VarStdImpl : public ScalarAggregator {
}

Status Finalize(KernelContext*, Datum* out) override {
if (this->state.count <= options.ddof) {
if (state.count <= state.options.ddof || state.count < state.options.min_count ||
(!state.all_valid && !state.options.skip_nulls)) {
out->value = std::make_shared<DoubleScalar>();
} else {
double var = this->state.m2 / (this->state.count - options.ddof);
double var = state.m2 / (state.count - state.options.ddof);
out->value =
std::make_shared<DoubleScalar>(return_type == VarOrStd::Var ? var : sqrt(var));
}
Expand All @@ -178,7 +188,6 @@ struct VarStdImpl : public ScalarAggregator {

std::shared_ptr<DataType> out_type;
VarStdState<ArrowType> state;
VarianceOptions options;
VarOrStd return_type;
};

Expand Down
119 changes: 80 additions & 39 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1304,18 +1304,20 @@ struct GroupedVarStdImpl : public GroupedAggregator {
options_ = *checked_cast<const VarianceOptions*>(options);
ctx_ = ctx;
pool_ = ctx->memory_pool();
counts_ = BufferBuilder(pool_);
means_ = BufferBuilder(pool_);
m2s_ = BufferBuilder(pool_);
counts_ = TypedBufferBuilder<int64_t>(pool_);
means_ = TypedBufferBuilder<double>(pool_);
m2s_ = TypedBufferBuilder<double>(pool_);
no_nulls_ = TypedBufferBuilder<bool>(pool_);
return Status::OK();
}

Status Resize(int64_t new_num_groups) override {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
RETURN_NOT_OK(counts_.Append(added_groups * sizeof(int64_t), 0));
RETURN_NOT_OK(means_.Append(added_groups * sizeof(double), 0));
RETURN_NOT_OK(m2s_.Append(added_groups * sizeof(double), 0));
RETURN_NOT_OK(counts_.Append(added_groups, 0));
RETURN_NOT_OK(means_.Append(added_groups, 0));
RETURN_NOT_OK(m2s_.Append(added_groups, 0));
RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
return Status::OK();
}

Expand All @@ -1332,18 +1334,21 @@ struct GroupedVarStdImpl : public GroupedAggregator {
GroupedVarStdImpl<Type> state;
RETURN_NOT_OK(state.Init(ctx_, &options_));
RETURN_NOT_OK(state.Resize(num_groups_));
int64_t* counts = reinterpret_cast<int64_t*>(state.counts_.mutable_data());
double* means = reinterpret_cast<double*>(state.means_.mutable_data());
double* m2s = reinterpret_cast<double*>(state.m2s_.mutable_data());
int64_t* counts = state.counts_.mutable_data();
double* means = state.means_.mutable_data();
double* m2s = state.m2s_.mutable_data();
uint8_t* no_nulls = state.no_nulls_.mutable_data();

// XXX this uses naive summation; we should switch to pairwise summation as was
// done for the scalar aggregate kernel in ARROW-11567
std::vector<SumType> sums(num_groups_);
VisitGroupedValuesNonNull<Type>(
batch, [&](uint32_t g, typename TypeTraits<Type>::CType value) {
VisitGroupedValues<Type>(
batch,
[&](uint32_t g, typename TypeTraits<Type>::CType value) {
sums[g] += value;
counts[g]++;
});
},
[&](uint32_t g) { BitUtil::ClearBit(no_nulls, g); });

for (int64_t i = 0; i < num_groups_; i++) {
means[i] = static_cast<double>(sums[i]) / counts[i];
Expand All @@ -1362,9 +1367,7 @@ struct GroupedVarStdImpl : public GroupedAggregator {
}
ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
/*null_count=*/0);
RETURN_NOT_OK(this->Merge(std::move(state), group_id_mapping));

return Status::OK();
return this->Merge(std::move(state), group_id_mapping);
}

// int32/16/8: textbook one pass algorithm with integer arithmetic (see
Expand All @@ -1377,12 +1380,15 @@ struct GroupedVarStdImpl : public GroupedAggregator {
// for int32: -2^62 <= sum < 2^62
constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);

const auto g = batch[1].array()->GetValues<uint32_t>(1);
if (batch[0].is_scalar() && !batch[0].scalar()->is_valid) {
uint8_t* no_nulls = no_nulls_.mutable_data();
for (int64_t i = 0; i < batch.length; i++) {
BitUtil::ClearBit(no_nulls, g[i]);
}
return Status::OK();
}

const auto g = batch[1].array()->GetValues<uint32_t>(1);

std::vector<IntegerVarStd<Type>> var_std(num_groups_);

ARROW_ASSIGN_OR_RAISE(auto mapping,
Expand All @@ -1402,23 +1408,42 @@ struct GroupedVarStdImpl : public GroupedAggregator {
GroupedVarStdImpl<Type> state;
RETURN_NOT_OK(state.Init(ctx_, &options_));
RETURN_NOT_OK(state.Resize(num_groups_));
int64_t* other_counts = reinterpret_cast<int64_t*>(state.counts_.mutable_data());
double* other_means = reinterpret_cast<double*>(state.means_.mutable_data());
double* other_m2s = reinterpret_cast<double*>(state.m2s_.mutable_data());
int64_t* other_counts = state.counts_.mutable_data();
double* other_means = state.means_.mutable_data();
double* other_m2s = state.m2s_.mutable_data();
uint8_t* other_no_nulls = state.no_nulls_.mutable_data();

if (batch[0].is_array()) {
const auto& array = *batch[0].array();
const CType* values = array.GetValues<CType>(1);
arrow::internal::VisitSetBitRunsVoid(
array.buffers[0], array.offset + start_index,
std::min(max_length, batch.length - start_index),
[&](int64_t pos, int64_t len) {
for (int64_t i = 0; i < len; ++i) {
const int64_t index = start_index + pos + i;
const auto value = values[index];
var_std[g[index]].ConsumeOne(value);
auto visit_values = [&](int64_t pos, int64_t len) {
for (int64_t i = 0; i < len; ++i) {
const int64_t index = start_index + pos + i;
const auto value = values[index];
var_std[g[index]].ConsumeOne(value);
}
};

if (array.MayHaveNulls()) {
arrow::internal::BitRunReader reader(
array.buffers[0]->data(), array.offset + start_index,
std::min(max_length, batch.length - start_index));
int64_t position = 0;
while (true) {
auto run = reader.NextRun();
if (run.length == 0) break;
if (run.set) {
visit_values(position, run.length);
} else {
for (int64_t i = 0; i < run.length; ++i) {
BitUtil::ClearBit(other_no_nulls, g[start_index + position + i]);
}
});
}
position += run.length;
}
} else {
visit_values(0, array.length);
}
} else {
const auto value = UnboxScalar<Type>::Unbox(*batch[0].scalar());
for (int64_t i = 0; i < std::min(max_length, batch.length - start_index); ++i) {
Expand All @@ -1444,16 +1469,21 @@ struct GroupedVarStdImpl : public GroupedAggregator {
// Combine m2 from two chunks (see aggregate_var_std.cc)
auto other = checked_cast<GroupedVarStdImpl*>(&raw_other);

auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
auto means = reinterpret_cast<double*>(means_.mutable_data());
auto m2s = reinterpret_cast<double*>(m2s_.mutable_data());
int64_t* counts = counts_.mutable_data();
double* means = means_.mutable_data();
double* m2s = m2s_.mutable_data();
uint8_t* no_nulls = no_nulls_.mutable_data();

const auto* other_counts = reinterpret_cast<const int64_t*>(other->counts_.data());
const auto* other_means = reinterpret_cast<const double*>(other->means_.data());
const auto* other_m2s = reinterpret_cast<const double*>(other->m2s_.data());
const int64_t* other_counts = other->counts_.data();
const double* other_means = other->means_.data();
const double* other_m2s = other->m2s_.data();
const uint8_t* other_no_nulls = other->no_nulls_.data();

auto g = group_id_mapping.GetValues<uint32_t>(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
if (!BitUtil::GetBit(other_no_nulls, other_g)) {
BitUtil::ClearBit(no_nulls, *g);
}
if (other_counts[other_g] == 0) continue;
MergeVarStd(counts[*g], means[*g], other_counts[other_g], other_means[other_g],
other_m2s[other_g], &counts[*g], &means[*g], &m2s[*g]);
Expand All @@ -1468,10 +1498,10 @@ struct GroupedVarStdImpl : public GroupedAggregator {
int64_t null_count = 0;

double* results = reinterpret_cast<double*>(values->mutable_data());
const int64_t* counts = reinterpret_cast<const int64_t*>(counts_.data());
const double* m2s = reinterpret_cast<const double*>(m2s_.data());
const int64_t* counts = counts_.data();
const double* m2s = m2s_.data();
for (int64_t i = 0; i < num_groups_; ++i) {
if (counts[i] > options_.ddof) {
if (counts[i] > options_.ddof && counts[i] >= options_.min_count) {
const double variance = m2s[i] / (counts[i] - options_.ddof);
results[i] = result_type_ == VarOrStd::Var ? variance : std::sqrt(variance);
continue;
Expand All @@ -1486,6 +1516,15 @@ struct GroupedVarStdImpl : public GroupedAggregator {
null_count += 1;
BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
}
if (!options_.skip_nulls) {
if (null_bitmap) {
arrow::internal::BitmapAnd(null_bitmap->data(), 0, no_nulls_.data(), 0,
num_groups_, 0, null_bitmap->mutable_data());
} else {
ARROW_ASSIGN_OR_RAISE(null_bitmap, no_nulls_.Finish());
}
null_count = kUnknownNullCount;
}

return ArrayData::Make(float64(), num_groups_,
{std::move(null_bitmap), std::move(values)}, null_count);
Expand All @@ -1497,7 +1536,9 @@ struct GroupedVarStdImpl : public GroupedAggregator {
VarianceOptions options_;
int64_t num_groups_ = 0;
// m2 = count * s2 = sum((X-mean)^2)
BufferBuilder counts_, means_, m2s_;
TypedBufferBuilder<int64_t> counts_;
TypedBufferBuilder<double> means_, m2s_;
TypedBufferBuilder<bool> no_nulls_;
ExecContext* ctx_;
MemoryPool* pool_;
};
Expand Down
Loading