Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/api_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,17 @@ class ExecContext;

/// \brief Control general scalar aggregate kernel behavior
///
/// By default, null values are ignored
/// By default, null values are ignored (skip_nulls = true).
class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions {
public:
/// Both parameters are optional: skip_nulls defaults to true and min_count to 1.
explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1);
/// Name string identifying this options type.
constexpr static char const kTypeName[] = "ScalarAggregateOptions";
/// Return an instance built with the constructor's default arguments.
static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; }

/// If true (the default), null values are ignored. Otherwise, if any value is null,
/// emit null.
bool skip_nulls;
/// If fewer than this many non-null values are observed, emit null.
uint32_t min_count;
};

Expand Down
31 changes: 26 additions & 5 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ struct ProductImpl : public ScalarAggregator {
if (batch[0].is_array()) {
const auto& data = batch[0].array();
this->count += data->length - data->GetNullCount();
this->nulls_observed = this->nulls_observed || data->GetNullCount();

if (!options.skip_nulls && this->nulls_observed) {
// Short-circuit
return Status::OK();
}

VisitArrayDataInline<ArrowType>(
*data,
[&](typename TypeTraits<ArrowType>::CType value) {
Expand All @@ -172,6 +179,7 @@ struct ProductImpl : public ScalarAggregator {
} else {
const auto& data = *batch[0].scalar();
this->count += data.is_valid * batch.length;
this->nulls_observed = this->nulls_observed || !data.is_valid;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a possible optimization here: if !options.skip_nulls, either check the validity bitmap up front for nulls and exit early if there are any, or exit as soon as the first null is found? As it stands, it looks like we always go through and count/sum/etc. all of the non-null values.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, when skip_nulls is false we can short-circuit as soon as nulls_observed becomes true. Updated, thanks for pointing this out.

if (data.is_valid) {
for (int64_t i = 0; i < batch.length; i++) {
auto value = internal::UnboxScalar<ArrowType>::Unbox(data);
Expand All @@ -188,11 +196,13 @@ struct ProductImpl : public ScalarAggregator {
this->count += other.count;
this->product =
static_cast<ProductType>(to_unsigned(this->product) * to_unsigned(other.product));
this->nulls_observed = this->nulls_observed || other.nulls_observed;
return Status::OK();
}

Status Finalize(KernelContext*, Datum* out) override {
if (this->count < options.min_count) {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared<OutputType>();
} else {
out->value = MakeScalar(this->product);
Expand All @@ -201,6 +211,7 @@ struct ProductImpl : public ScalarAggregator {
}

size_t count = 0;
bool nulls_observed = false;
typename AccType::c_type product = 1;
ScalarAggregateOptions options;
};
Expand Down Expand Up @@ -268,17 +279,19 @@ struct BooleanAnyImpl : public ScalarAggregator {

Status Consume(KernelContext*, const ExecBatch& batch) override {
// short-circuit if seen a True already
if (this->any == true) {
if (this->any == true && this->count >= options.min_count) {
return Status::OK();
}
if (batch[0].is_scalar()) {
const auto& scalar = *batch[0].scalar();
this->has_nulls = !scalar.is_valid;
this->any = scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value;
this->count += scalar.is_valid;
return Status::OK();
}
const auto& data = *batch[0].array();
this->has_nulls = data.GetNullCount() > 0;
this->count += data.length - data.GetNullCount();
arrow::internal::OptionalBinaryBitBlockCounter counter(
data.buffers[0], data.offset, data.buffers[1], data.offset, data.length);
int64_t position = 0;
Expand All @@ -297,11 +310,13 @@ struct BooleanAnyImpl : public ScalarAggregator {
const auto& other = checked_cast<const BooleanAnyImpl&>(src);
this->any |= other.any;
this->has_nulls |= other.has_nulls;
this->count += other.count;
return Status::OK();
}

Status Finalize(KernelContext* ctx, Datum* out) override {
if (!options.skip_nulls && !this->any && this->has_nulls) {
if ((!options.skip_nulls && !this->any && this->has_nulls) ||
this->count < options.min_count) {
Copy link
Member

@pitrou pitrou Aug 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious about this condition: if there are nulls and options.skip_nulls is false, this kernel can still return true (when this->any is true)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it looked weird to me as well, but that is how Kleene logic works, and you can observe this in R:

> any(c(NA), na.rm = FALSE)
[1] NA
> any(c(NA, TRUE), na.rm = FALSE)
[1] TRUE
> any(c(NA, FALSE), na.rm = FALSE)
[1] NA
> any(c(), na.rm = FALSE)
[1] FALSE

out->value = std::make_shared<BooleanScalar>();
} else {
out->value = std::make_shared<BooleanScalar>(this->any);
Expand All @@ -311,6 +326,7 @@ struct BooleanAnyImpl : public ScalarAggregator {

bool any = false;
bool has_nulls = false;
int64_t count = 0;
ScalarAggregateOptions options;
};

Expand All @@ -329,7 +345,7 @@ struct BooleanAllImpl : public ScalarAggregator {

Status Consume(KernelContext*, const ExecBatch& batch) override {
// short-circuit if seen a false already
if (this->all == false) {
if (this->all == false && this->count >= options.min_count) {
return Status::OK();
}
// short-circuit if seen a null already
Expand All @@ -339,11 +355,13 @@ struct BooleanAllImpl : public ScalarAggregator {
if (batch[0].is_scalar()) {
const auto& scalar = *batch[0].scalar();
this->has_nulls = !scalar.is_valid;
this->count += scalar.is_valid;
this->all = !scalar.is_valid || checked_cast<const BooleanScalar&>(scalar).value;
return Status::OK();
}
const auto& data = *batch[0].array();
this->has_nulls = data.GetNullCount() > 0;
this->count += data.length - data.GetNullCount();
arrow::internal::OptionalBinaryBitBlockCounter counter(
data.buffers[1], data.offset, data.buffers[0], data.offset, data.length);
int64_t position = 0;
Expand All @@ -363,11 +381,13 @@ struct BooleanAllImpl : public ScalarAggregator {
const auto& other = checked_cast<const BooleanAllImpl&>(src);
this->all &= other.all;
this->has_nulls |= other.has_nulls;
this->count += other.count;
return Status::OK();
}

Status Finalize(KernelContext*, Datum* out) override {
if (!options.skip_nulls && this->all && this->has_nulls) {
if ((!options.skip_nulls && this->all && this->has_nulls) ||
this->count < options.min_count) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question here: if the input is [false, true, null] and skip_nulls is false, then the result is false rather than null?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and this matches base R/dplyr's behavior:

> all(c(FALSE, TRUE, NA), na.rm = FALSE)
[1] FALSE

out->value = std::make_shared<BooleanScalar>();
} else {
out->value = std::make_shared<BooleanScalar>(this->all);
Expand All @@ -377,6 +397,7 @@ struct BooleanAllImpl : public ScalarAggregator {

bool all = true;
bool has_nulls = false;
int64_t count = 0;
ScalarAggregateOptions options;
};

Expand Down
16 changes: 14 additions & 2 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ struct SumImpl : public ScalarAggregator {
if (batch[0].is_array()) {
const auto& data = batch[0].array();
this->count += data->length - data->GetNullCount();
this->nulls_observed = this->nulls_observed || data->GetNullCount();

if (!options.skip_nulls && this->nulls_observed) {
// Short-circuit
return Status::OK();
}

if (is_boolean_type<ArrowType>::value) {
this->sum +=
static_cast<typename SumType::c_type>(BooleanArray(data).true_count());
Expand All @@ -79,6 +86,7 @@ struct SumImpl : public ScalarAggregator {
} else {
const auto& data = *batch[0].scalar();
this->count += data.is_valid * batch.length;
this->nulls_observed = this->nulls_observed || !data.is_valid;
if (data.is_valid) {
this->sum += internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
}
Expand All @@ -90,11 +98,13 @@ struct SumImpl : public ScalarAggregator {
const auto& other = checked_cast<const ThisType&>(src);
this->count += other.count;
this->sum += other.sum;
this->nulls_observed = this->nulls_observed || other.nulls_observed;
return Status::OK();
}

Status Finalize(KernelContext*, Datum* out) override {
if (this->count < options.min_count) {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared<OutputType>();
} else {
out->value = MakeScalar(this->sum);
Expand All @@ -103,14 +113,16 @@ struct SumImpl : public ScalarAggregator {
}

size_t count = 0;
bool nulls_observed = false;
typename SumType::c_type sum = 0;
ScalarAggregateOptions options;
};

template <typename ArrowType, SimdLevel::type SimdLevel>
struct MeanImpl : public SumImpl<ArrowType, SimdLevel> {
Status Finalize(KernelContext*, Datum* out) override {
if (this->count < options.min_count) {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared<DoubleScalar>();
} else {
const double mean = static_cast<double>(this->sum) / this->count;
Expand Down
Loading