-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-13627: [C++] Fully support ScalarAggregateOptions in (hash) any/all/sum/product/mean #10942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b406294
68bd7d7
0f8857f
a2d8def
4ec5fb0
a332bd8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,6 +162,13 @@ struct ProductImpl : public ScalarAggregator { | |
| if (batch[0].is_array()) { | ||
| const auto& data = batch[0].array(); | ||
| this->count += data->length - data->GetNullCount(); | ||
| this->nulls_observed = this->nulls_observed || data->GetNullCount(); | ||
|
|
||
| if (!options.skip_nulls && this->nulls_observed) { | ||
| // Short-circuit | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| VisitArrayDataInline<ArrowType>( | ||
| *data, | ||
| [&](typename TypeTraits<ArrowType>::CType value) { | ||
|
|
@@ -172,6 +179,7 @@ struct ProductImpl : public ScalarAggregator { | |
| } else { | ||
| const auto& data = *batch[0].scalar(); | ||
| this->count += data.is_valid * batch.length; | ||
| this->nulls_observed = this->nulls_observed || !data.is_valid; | ||
| if (data.is_valid) { | ||
| for (int64_t i = 0; i < batch.length; i++) { | ||
| auto value = internal::UnboxScalar<ArrowType>::Unbox(data); | ||
|
|
@@ -188,11 +196,13 @@ struct ProductImpl : public ScalarAggregator { | |
| this->count += other.count; | ||
| this->product = | ||
| static_cast<ProductType>(to_unsigned(this->product) * to_unsigned(other.product)); | ||
| this->nulls_observed = this->nulls_observed || other.nulls_observed; | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status Finalize(KernelContext*, Datum* out) override { | ||
| if (this->count < options.min_count) { | ||
| if ((!options.skip_nulls && this->nulls_observed) || | ||
| (this->count < options.min_count)) { | ||
| out->value = std::make_shared<OutputType>(); | ||
| } else { | ||
| out->value = MakeScalar(this->product); | ||
|
|
@@ -201,6 +211,7 @@ struct ProductImpl : public ScalarAggregator { | |
| } | ||
|
|
||
| size_t count = 0; | ||
| bool nulls_observed = false; | ||
| typename AccType::c_type product = 1; | ||
| ScalarAggregateOptions options; | ||
| }; | ||
|
|
@@ -268,17 +279,19 @@ struct BooleanAnyImpl : public ScalarAggregator { | |
|
|
||
| Status Consume(KernelContext*, const ExecBatch& batch) override { | ||
| // short-circuit if seen a True already | ||
| if (this->any == true) { | ||
| if (this->any == true && this->count >= options.min_count) { | ||
| return Status::OK(); | ||
| } | ||
| if (batch[0].is_scalar()) { | ||
| const auto& scalar = *batch[0].scalar(); | ||
| this->has_nulls = !scalar.is_valid; | ||
| this->any = scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value; | ||
| this->count += scalar.is_valid; | ||
| return Status::OK(); | ||
| } | ||
| const auto& data = *batch[0].array(); | ||
| this->has_nulls = data.GetNullCount() > 0; | ||
| this->count += data.length - data.GetNullCount(); | ||
| arrow::internal::OptionalBinaryBitBlockCounter counter( | ||
| data.buffers[0], data.offset, data.buffers[1], data.offset, data.length); | ||
| int64_t position = 0; | ||
|
|
@@ -297,11 +310,13 @@ struct BooleanAnyImpl : public ScalarAggregator { | |
| const auto& other = checked_cast<const BooleanAnyImpl&>(src); | ||
| this->any |= other.any; | ||
| this->has_nulls |= other.has_nulls; | ||
| this->count += other.count; | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status Finalize(KernelContext* ctx, Datum* out) override { | ||
| if (!options.skip_nulls && !this->any && this->has_nulls) { | ||
| if ((!options.skip_nulls && !this->any && this->has_nulls) || | ||
| this->count < options.min_count) { | ||
|
||
| out->value = std::make_shared<BooleanScalar>(); | ||
| } else { | ||
| out->value = std::make_shared<BooleanScalar>(this->any); | ||
|
|
@@ -311,6 +326,7 @@ struct BooleanAnyImpl : public ScalarAggregator { | |
|
|
||
| bool any = false; | ||
| bool has_nulls = false; | ||
| int64_t count = 0; | ||
| ScalarAggregateOptions options; | ||
| }; | ||
|
|
||
|
|
@@ -329,7 +345,7 @@ struct BooleanAllImpl : public ScalarAggregator { | |
|
|
||
| Status Consume(KernelContext*, const ExecBatch& batch) override { | ||
| // short-circuit if seen a false already | ||
| if (this->all == false) { | ||
| if (this->all == false && this->count >= options.min_count) { | ||
| return Status::OK(); | ||
| } | ||
| // short-circuit if seen a null already | ||
|
|
@@ -339,11 +355,13 @@ struct BooleanAllImpl : public ScalarAggregator { | |
| if (batch[0].is_scalar()) { | ||
| const auto& scalar = *batch[0].scalar(); | ||
| this->has_nulls = !scalar.is_valid; | ||
| this->count += scalar.is_valid; | ||
| this->all = !scalar.is_valid || checked_cast<const BooleanScalar&>(scalar).value; | ||
| return Status::OK(); | ||
| } | ||
| const auto& data = *batch[0].array(); | ||
| this->has_nulls = data.GetNullCount() > 0; | ||
| this->count += data.length - data.GetNullCount(); | ||
| arrow::internal::OptionalBinaryBitBlockCounter counter( | ||
| data.buffers[1], data.offset, data.buffers[0], data.offset, data.length); | ||
| int64_t position = 0; | ||
|
|
@@ -363,11 +381,13 @@ struct BooleanAllImpl : public ScalarAggregator { | |
| const auto& other = checked_cast<const BooleanAllImpl&>(src); | ||
| this->all &= other.all; | ||
| this->has_nulls |= other.has_nulls; | ||
| this->count += other.count; | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status Finalize(KernelContext*, Datum* out) override { | ||
| if (!options.skip_nulls && this->all && this->has_nulls) { | ||
| if ((!options.skip_nulls && this->all && this->has_nulls) || | ||
| this->count < options.min_count) { | ||
|
||
| out->value = std::make_shared<BooleanScalar>(); | ||
| } else { | ||
| out->value = std::make_shared<BooleanScalar>(this->all); | ||
|
|
@@ -377,6 +397,7 @@ struct BooleanAllImpl : public ScalarAggregator { | |
|
|
||
| bool all = true; | ||
| bool has_nulls = false; | ||
| int64_t count = 0; | ||
| ScalarAggregateOptions options; | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a possible optimization: if options.skip_nulls, either check the bitmask up front for missings and exit early if any, or exit after the first one is found? It looks like as it is, we still go through and count/sum/etc. all non-null values always.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, we can short-circuit as soon as nulls_observed if we have !skip_nulls. Updated, thanks for pointing this out.