diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 0a567e385e7..139ab614010 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -824,6 +824,36 @@ Status AddHashAggKernels(
   return Status::OK();
 }
 
+template <typename Type, typename ConsumeValue, typename ConsumeNull>
+void VisitGroupedValues(const ExecBatch& batch, ConsumeValue&& valid_func,
+                        ConsumeNull&& null_func) {
+  auto g = batch[1].array()->GetValues<uint32_t>(1);
+  if (batch[0].is_array()) {
+    VisitArrayValuesInline<Type>(
+        *batch[0].array(),
+        [&](typename TypeTraits<Type>::CType val) { valid_func(*g++, val); },
+        [&]() { null_func(*g++); });
+    return;
+  }
+  const auto& input = *batch[0].scalar();
+  if (input.is_valid) {
+    const auto val = UnboxScalar<Type>::Unbox(input);
+    for (int64_t i = 0; i < batch.length; i++) {
+      valid_func(*g++, val);
+    }
+  } else {
+    for (int64_t i = 0; i < batch.length; i++) {
+      null_func(*g++);
+    }
+  }
+}
+
+template <typename Type, typename ConsumeValue>
+void VisitGroupedValuesNonNull(const ExecBatch& batch, ConsumeValue&& valid_func) {
+  VisitGroupedValues<Type>(batch, std::forward<ConsumeValue>(valid_func),
+                           [](uint32_t) {});
+}
+
 // ----------------------------------------------------------------------
 // Count implementation
 
@@ -856,12 +886,15 @@ struct GroupedCountImpl : public GroupedAggregator {
 
   Status Consume(const ExecBatch& batch) override {
     auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
-
-    const auto& input = batch[0].array();
-    auto g_begin = batch[1].array()->GetValues<uint32_t>(1);
-    switch (options_.mode) {
-      case CountOptions::ONLY_VALID: {
+    auto g_begin = batch[1].array()->GetValues<uint32_t>(1);
+
+    if (options_.mode == CountOptions::ALL) {
+      for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
+        counts[*g_begin] += 1;
+      }
+    } else if (batch[0].is_array()) {
+      const auto& input = batch[0].array();
+      if (options_.mode == CountOptions::ONLY_VALID) {
         arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
                                              input->length,
                                              [&](int64_t offset, int64_t length) {
@@ -870,25 +903,25 @@ struct GroupedCountImpl : public GroupedAggregator {
                                                  counts[*g] += 1;
                                                }
                                              });
-        break;
-      }
-      case CountOptions::ONLY_NULL: {
+      } else {  // ONLY_NULL
         if (input->MayHaveNulls()) {
           auto end = input->offset + input->length;
           for (int64_t i = input->offset; i < end; ++i, ++g_begin) {
             counts[*g_begin] += !BitUtil::GetBit(input->buffers[0]->data(), i);
           }
         }
-        break;
       }
-      case CountOptions::ALL: {
+    } else {
+      const auto& input = *batch[0].scalar();
+      if (options_.mode == CountOptions::ONLY_VALID) {
+        for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
+          counts[*g_begin] += input.is_valid;
+        }
+      } else {  // ONLY_NULL
        for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
-          counts[*g_begin] += 1;
+          counts[*g_begin] += !input.is_valid;
        }
-        break;
      }
-      default:
-        DCHECK(false) << "unreachable";
    }
    return Status::OK();
  }
@@ -911,12 +944,13 @@ struct GroupedCountImpl : public GroupedAggregator {
 template <typename Type, typename Impl>
 struct GroupedReducingAggregator : public GroupedAggregator {
   using AccType = typename FindAccumulatorType<Type>::Type;
-  using c_type = typename TypeTraits<AccType>::CType;
+  using CType = typename TypeTraits<AccType>::CType;
+  using InputCType = typename TypeTraits<Type>::CType;
 
   Status Init(ExecContext* ctx, const FunctionOptions* options) override {
     pool_ = ctx->memory_pool();
     options_ = checked_cast<const ScalarAggregateOptions&>(*options);
-    reduced_ = TypedBufferBuilder<c_type>(pool_);
+    reduced_ = TypedBufferBuilder<CType>(pool_);
     counts_ = TypedBufferBuilder<int64_t>(pool_);
     no_nulls_ = TypedBufferBuilder<bool>(pool_);
     // out_type_ initialized by SumInit
@@ -933,31 +967,36 @@ struct GroupedReducingAggregator : public GroupedAggregator {
   }
 
   Status Consume(const ExecBatch& batch) override {
-    c_type* reduced = reduced_.mutable_data();
+    CType* reduced = reduced_.mutable_data();
     int64_t* counts = counts_.mutable_data();
     uint8_t* no_nulls = no_nulls_.mutable_data();
-    auto g = batch[1].array()->GetValues<uint32_t>(1);
-
-    return Impl::Consume(*batch[0].array(), reduced, counts, no_nulls, g);
+    VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, InputCType value) {
+          reduced[g] = Impl::Reduce(*out_type_, reduced[g], value);
+          counts[g]++;
+        },
+        [&](uint32_t g) { BitUtil::SetBitTo(no_nulls, g, false); });
+    return Status::OK();
   }
 
   Status Merge(GroupedAggregator&& raw_other,
                const ArrayData& group_id_mapping) override {
     auto other = checked_cast<GroupedReducingAggregator<Type, Impl>*>(&raw_other);
 
-    c_type* reduced = reduced_.mutable_data();
+    CType* reduced = reduced_.mutable_data();
     int64_t* counts = counts_.mutable_data();
     uint8_t* no_nulls = no_nulls_.mutable_data();
 
-    const c_type* other_reduced = other->reduced_.data();
+    const CType* other_reduced = other->reduced_.data();
     const int64_t* other_counts = other->counts_.data();
     const uint8_t* other_no_nulls = no_nulls_.mutable_data();
 
     auto g = group_id_mapping.GetValues<uint32_t>(1);
     for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
       counts[*g] += other_counts[other_g];
-      Impl::UpdateGroupWith(*out_type_, reduced, *g, other_reduced[other_g]);
+      reduced[*g] = Impl::Reduce(*out_type_, reduced[*g], other_reduced[other_g]);
       BitUtil::SetBitTo(
           no_nulls, *g,
           BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
@@ -969,7 +1008,7 @@ struct GroupedReducingAggregator : public GroupedAggregator {
   static Result<std::shared_ptr<Buffer>> Finish(MemoryPool* pool,
                                                 const ScalarAggregateOptions& options,
                                                 const int64_t* counts,
-                                                TypedBufferBuilder<c_type>* reduced,
+                                                TypedBufferBuilder<CType>* reduced,
                                                 int64_t num_groups, int64_t* null_count,
                                                 std::shared_ptr<Buffer>* null_bitmap) {
     for (int64_t i = 0; i < num_groups; ++i) {
@@ -1014,7 +1053,7 @@ struct GroupedReducingAggregator : public GroupedAggregator {
 
   int64_t num_groups_ = 0;
   ScalarAggregateOptions options_;
-  TypedBufferBuilder<c_type> reduced_;
+  TypedBufferBuilder<CType> reduced_;
   TypedBufferBuilder<int64_t> counts_;
   TypedBufferBuilder<bool> no_nulls_;
   std::shared_ptr<DataType> out_type_;
@@ -1027,31 +1066,20 @@
 template <typename Type>
 struct GroupedSumImpl : public GroupedReducingAggregator<Type, GroupedSumImpl<Type>> {
   using Base = GroupedReducingAggregator<Type, GroupedSumImpl<Type>>;
-  using c_type = typename Base::c_type;
+  using CType = typename Base::CType;
+  using InputCType = typename Base::InputCType;
 
   // Default value for a group
-  static c_type NullValue(const DataType&) { return c_type(0); }
+  static CType NullValue(const DataType&) { return CType(0); }
 
-  // Update all groups
-  static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts,
-                        uint8_t* no_nulls, const uint32_t* g) {
-    // XXX this uses naive summation; we should switch to pairwise summation as was
-    // done for the scalar aggregate kernel in ARROW-11758
-    internal::VisitArrayValuesInline<Type>(
-        values,
-        [&](typename TypeTraits<Type>::CType value) {
-          reduced[*g] = static_cast<c_type>(to_unsigned(reduced[*g]) +
-                                            to_unsigned(static_cast<c_type>(value)));
-          counts[*g++] += 1;
-        },
-        [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
-    return Status::OK();
+  template <typename T = Type>
+  static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
+                                           const InputCType v) {
+    return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
   }
 
-  // Update a single group during merge
-  static void UpdateGroupWith(const DataType&, c_type* reduced, uint32_t g,
-                              c_type value) {
-    reduced[g] += value;
+  static CType Reduce(const DataType&, const CType u, const CType v) {
+    return static_cast<CType>(to_unsigned(u) + to_unsigned(v));
   }
 
   using Base::Finish;
@@ -1119,28 +1147,21 @@ struct GroupedProductImpl final
     : public GroupedReducingAggregator<Type, GroupedProductImpl<Type>> {
   using Base = GroupedReducingAggregator<Type, GroupedProductImpl<Type>>;
   using AccType = typename Base::AccType;
-  using c_type = typename Base::c_type;
+  using CType = typename Base::CType;
+  using InputCType = typename Base::InputCType;
 
-  static c_type NullValue(const DataType& out_type) {
+  static CType NullValue(const DataType& out_type) {
     return MultiplyTraits<AccType>::one(out_type);
   }
 
-  static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts,
-                        uint8_t* no_nulls, const uint32_t* g) {
-    internal::VisitArrayValuesInline<Type>(
-        values,
-        [&](typename TypeTraits<Type>::CType value) {
-          reduced[*g] = MultiplyTraits<AccType>::Multiply(*values.type, reduced[*g],
-                                                          static_cast<c_type>(value));
-          counts[*g++] += 1;
-        },
-        [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
-    return Status::OK();
+  template <typename T = Type>
+  static enable_if_number<T, CType> Reduce(const DataType& out_type, const CType u,
+                                           const InputCType v) {
+    return MultiplyTraits<AccType>::Multiply(out_type, u, static_cast<CType>(v));
   }
 
-  static void UpdateGroupWith(const DataType& out_type, c_type* reduced, uint32_t g,
-                              c_type value) {
-    reduced[g] = MultiplyTraits<AccType>::Multiply(out_type, reduced[g], value);
+  static CType Reduce(const DataType& out_type, const CType u, const CType v) {
+    return MultiplyTraits<AccType>::Multiply(out_type, u, v);
   }
 
   using Base::Finish;
@@ -1190,39 +1211,30 @@ struct GroupedProductFactory {
 template <typename Type>
 struct GroupedMeanImpl : public GroupedReducingAggregator<Type, GroupedMeanImpl<Type>> {
   using Base = GroupedReducingAggregator<Type, GroupedMeanImpl<Type>>;
-  using c_type = typename Base::c_type;
+  using CType = typename Base::CType;
+  using InputCType = typename Base::InputCType;
   using MeanType =
-      typename std::conditional<is_decimal_type<Type>::value, c_type, double>::type;
+      typename std::conditional<is_decimal_type<Type>::value, CType, double>::type;
 
-  static c_type NullValue(const DataType&) { return c_type(0); }
+  static CType NullValue(const DataType&) { return CType(0); }
 
-  static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts,
-                        uint8_t* no_nulls, const uint32_t* g) {
-    // XXX this uses naive summation; we should switch to pairwise summation as was
-    // done for the scalar aggregate kernel in ARROW-11758
-    internal::VisitArrayValuesInline<Type>(
-        values,
-        [&](typename TypeTraits<Type>::CType value) {
-          reduced[*g] = static_cast<c_type>(to_unsigned(reduced[*g]) +
-                                            to_unsigned(static_cast<c_type>(value)));
-          counts[*g++] += 1;
-        },
-        [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
-    return Status::OK();
+  template <typename T = Type>
+  static enable_if_number<T, CType> Reduce(const DataType&, const CType u,
+                                           const InputCType v) {
+    return static_cast<CType>(to_unsigned(u) + to_unsigned(static_cast<CType>(v)));
  }
 
-  static void UpdateGroupWith(const DataType&, c_type* reduced, uint32_t g,
-                              c_type value) {
-    reduced[g] += value;
+  static CType Reduce(const DataType&, const CType u, const CType v) {
+    return static_cast<CType>(to_unsigned(u) + to_unsigned(v));
  }
 
   static Result<std::shared_ptr<Buffer>> Finish(MemoryPool* pool,
                                                 const ScalarAggregateOptions& options,
                                                 const int64_t* counts,
-                                                TypedBufferBuilder<c_type>* reduced_,
+                                                TypedBufferBuilder<CType>* reduced_,
                                                 int64_t num_groups, int64_t* null_count,
                                                 std::shared_ptr<Buffer>* null_bitmap) {
-    const c_type* reduced = reduced_->data();
+    const CType* reduced = reduced_->data();
     ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> values,
                           AllocateBuffer(num_groups * sizeof(MeanType), pool));
     MeanType* means = reinterpret_cast<MeanType*>(values->mutable_data());
@@ -1325,36 +1337,40 @@ struct GroupedVarStdImpl : public GroupedAggregator {
     using SumType =
        typename std::conditional<is_floating_type<Type>::value, double, int128_t>::type;
 
-    int64_t* counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
-    double* means = reinterpret_cast<double*>(means_.mutable_data());
-    double* m2s = reinterpret_cast<double*>(m2s_.mutable_data());
+    GroupedVarStdImpl<Type> state;
+    RETURN_NOT_OK(state.Init(ctx_, &options_));
+    RETURN_NOT_OK(state.Resize(num_groups_));
+    int64_t* counts = reinterpret_cast<int64_t*>(state.counts_.mutable_data());
+    double* means = reinterpret_cast<double*>(state.means_.mutable_data());
+    double* m2s = reinterpret_cast<double*>(state.m2s_.mutable_data());
 
     // XXX this uses naive summation; we should switch to pairwise summation as was
     // done for the scalar aggregate kernel in ARROW-11567
     std::vector<SumType> sums(num_groups_);
-    auto g = batch[1].array()->GetValues<uint32_t>(1);
-    VisitArrayDataInline<Type>(
-        *batch[0].array(),
-        [&](typename TypeTraits<Type>::CType value) {
-          sums[*g] += value;
-          counts[*g] += 1;
-          ++g;
-        },
-        [&] { ++g; });
+    VisitGroupedValuesNonNull<Type>(
+        batch, [&](uint32_t g, typename TypeTraits<Type>::CType value) {
+          sums[g] += value;
+          counts[g]++;
+        });
 
     for (int64_t i = 0; i < num_groups_; i++) {
       means[i] = static_cast<double>(sums[i]) / counts[i];
     }
 
-    g = batch[1].array()->GetValues<uint32_t>(1);
-    VisitArrayDataInline<Type>(
-        *batch[0].array(),
-        [&](typename TypeTraits<Type>::CType value) {
+    VisitGroupedValuesNonNull<Type>(
+        batch, [&](uint32_t g, typename TypeTraits<Type>::CType value) {
           const double v = static_cast<double>(value);
-          m2s[*g] += (v - means[*g]) * (v - means[*g]);
-          ++g;
-        },
-        [&] { ++g; });
+          m2s[g] += (v - means[g]) * (v - means[g]);
+        });
+
+    ARROW_ASSIGN_OR_RAISE(auto mapping,
+                          AllocateBuffer(num_groups_ * sizeof(uint32_t), pool_));
+    for (uint32_t i = 0; static_cast<int64_t>(i) < num_groups_; i++) {
+      reinterpret_cast<uint32_t*>(mapping->mutable_data())[i] = i;
+    }
+    ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
+                               /*null_count=*/0);
+    RETURN_NOT_OK(this->Merge(std::move(state), group_id_mapping));
 
     return Status::OK();
   }
@@ -1369,7 +1385,10 @@ struct GroupedVarStdImpl : public GroupedAggregator {
     // for int32: -2^62 <= sum < 2^62
     constexpr int64_t max_length = 1ULL << (63 - sizeof(CType) * 8);
 
-    const auto& array = *batch[0].array();
+    if (batch[0].is_scalar() && !batch[0].scalar()->is_valid) {
+      return Status::OK();
+    }
+
     const auto g = batch[1].array()->GetValues<uint32_t>(1);
 
     std::vector<IntegerVarStd<Type>> var_std(num_groups_);
@@ -1382,8 +1401,6 @@ struct GroupedVarStdImpl : public GroupedAggregator {
     ArrayData group_id_mapping(uint32(), num_groups_, {nullptr, std::move(mapping)},
                                /*null_count=*/0);
 
-    const CType* values = array.GetValues<CType>(1);
-
     for (int64_t start_index = 0; start_index < batch.length; start_index += max_length) {
       // process in chunks that overflow will never happen
@@ -1397,16 +1414,26 @@ struct GroupedVarStdImpl : public GroupedAggregator {
       double* other_means = reinterpret_cast<double*>(state.means_.mutable_data());
       double* other_m2s = reinterpret_cast<double*>(state.m2s_.mutable_data());
 
-      arrow::internal::VisitSetBitRunsVoid(
-          array.buffers[0], array.offset + start_index,
-          std::min(max_length, batch.length - start_index),
-          [&](int64_t pos, int64_t len) {
-            for (int64_t i = 0; i < len; ++i) {
-              const int64_t index = start_index + pos + i;
-              const auto value = values[index];
-              var_std[g[index]].ConsumeOne(value);
-            }
-          });
+      if (batch[0].is_array()) {
+        const auto& array = *batch[0].array();
+        const CType* values = array.GetValues<CType>(1);
+        arrow::internal::VisitSetBitRunsVoid(
+            array.buffers[0], array.offset + start_index,
+            std::min(max_length, batch.length - start_index),
+            [&](int64_t pos, int64_t len) {
+              for (int64_t i = 0; i < len; ++i) {
+                const int64_t index = start_index + pos + i;
+                const auto value = values[index];
+                var_std[g[index]].ConsumeOne(value);
+              }
+            });
+      } else {
+        const auto value = UnboxScalar<Type>::Unbox(*batch[0].scalar());
+        for (int64_t i = 0; i < std::min(max_length, batch.length - start_index); ++i) {
+          const int64_t index = start_index + i;
+          var_std[g[index]].ConsumeOne(value);
+        }
+      }
 
       for (int64_t i = 0; i < num_groups_; i++) {
         if (var_std[i].count == 0) continue;
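Aside on the var/std hunks above (a standalone sketch, not part of the patch): the retained chunking by max_length = 1ULL << (63 - sizeof(CType) * 8) works because every signed CType value has magnitude below 2^(8*sizeof(CType)-1), so a chunk of at most 2^(63-8*sizeof(CType)) rows keeps the partial sum strictly inside the "-2^62 <= sum < 2^62" bound the comment quotes for int32. The snippet below only illustrates that arithmetic and uses no Arrow APIs.

#include <cstdint>
#include <iostream>

int main() {
  using CType = int32_t;
  // At most 2^(63 - 8 * sizeof(CType)) rows per chunk; for int32 that is 2^31 rows.
  constexpr int64_t max_length = int64_t{1} << (63 - sizeof(CType) * 8);
  // |value| < 2^31 for int32, so a chunk's partial sum stays below
  // 2^31 * 2^31 = 2^62 in magnitude and never overflows a 64-bit (or wider) accumulator.
  std::cout << "max chunk length for int32: " << max_length << "\n";  // 2147483648
  return 0;
}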
@@ -1546,14 +1573,8 @@ struct GroupedTDigestImpl : public GroupedAggregator {
   }
 
   Status Consume(const ExecBatch& batch) override {
-    auto g = batch[1].array()->GetValues<uint32_t>(1);
-    VisitArrayDataInline<Type>(
-        *batch[0].array(),
-        [&](typename TypeTraits<Type>::CType value) {
-          this->tdigests_[*g].NanAdd(value);
-          ++g;
-        },
-        [&] { ++g; });
+    VisitGroupedValuesNonNull<Type>(
+        batch, [&](uint32_t g, CType value) { tdigests_[g].NanAdd(value); });
     return Status::OK();
   }
 
@@ -1696,18 +1717,17 @@ struct GroupedMinMaxImpl : public GroupedAggregator {
   }
 
   Status Consume(const ExecBatch& batch) override {
-    auto g = batch[1].array()->GetValues<uint32_t>(1);
     auto raw_mins = reinterpret_cast<CType*>(mins_.mutable_data());
     auto raw_maxes = reinterpret_cast<CType*>(maxes_.mutable_data());
 
-    VisitArrayValuesInline<Type>(
-        *batch[0].array(),
-        [&](CType val) {
-          raw_maxes[*g] = std::max(raw_maxes[*g], val);
-          raw_mins[*g] = std::min(raw_mins[*g], val);
-          BitUtil::SetBit(has_values_.mutable_data(), *g++);
+    VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, CType val) {
+          raw_maxes[g] = std::max(raw_maxes[g], val);
+          raw_mins[g] = std::min(raw_mins[g], val);
+          BitUtil::SetBit(has_values_.mutable_data(), g);
         },
-        [&] { BitUtil::SetBit(has_nulls_.mutable_data(), *g++); });
+        [&](uint32_t g) { BitUtil::SetBit(has_nulls_.mutable_data(), g); });
     return Status::OK();
   }
 
@@ -1815,7 +1835,7 @@ struct GroupedBooleanAggregator : public GroupedAggregator {
   Status Init(ExecContext* ctx, const FunctionOptions* options) override {
     options_ = checked_cast<const ScalarAggregateOptions&>(*options);
     pool_ = ctx->memory_pool();
-    seen_ = TypedBufferBuilder<bool>(pool_);
+    reduced_ = TypedBufferBuilder<bool>(pool_);
     no_nulls_ = TypedBufferBuilder<bool>(pool_);
     counts_ = TypedBufferBuilder<int64_t>(pool_);
     return Status::OK();
   }
 
@@ -1824,39 +1844,54 @@ struct GroupedBooleanAggregator : public GroupedAggregator {
   Status Resize(int64_t new_num_groups) override {
     auto added_groups = new_num_groups - num_groups_;
     num_groups_ = new_num_groups;
-    RETURN_NOT_OK(seen_.Append(added_groups, Impl::NullValue()));
+    RETURN_NOT_OK(reduced_.Append(added_groups, Impl::NullValue()));
     RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
     return counts_.Append(added_groups, 0);
   }
 
   Status Consume(const ExecBatch& batch) override {
-    uint8_t* seen = seen_.mutable_data();
+    uint8_t* reduced = reduced_.mutable_data();
     uint8_t* no_nulls = no_nulls_.mutable_data();
     int64_t* counts = counts_.mutable_data();
-    const auto& input = *batch[0].array();
     auto g = batch[1].array()->GetValues<uint32_t>(1);
 
-    if (input.MayHaveNulls()) {
-      const uint8_t* bitmap = input.buffers[1]->data();
-      arrow::internal::VisitBitBlocksVoid(
-          input.buffers[0], input.offset, input.length,
-          [&](int64_t position) {
-            counts[*g]++;
-            Impl::UpdateGroupWith(seen, *g, BitUtil::GetBit(bitmap, position));
-            g++;
-          },
-          [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
+    if (batch[0].is_array()) {
+      const auto& input = *batch[0].array();
+      if (input.MayHaveNulls()) {
+        const uint8_t* bitmap = input.buffers[1]->data();
+        arrow::internal::VisitBitBlocksVoid(
+            input.buffers[0], input.offset, input.length,
+            [&](int64_t position) {
+              counts[*g]++;
+              Impl::UpdateGroupWith(reduced, *g, BitUtil::GetBit(bitmap, position));
+              g++;
+            },
+            [&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
+      } else {
+        arrow::internal::VisitBitBlocksVoid(
+            input.buffers[1], input.offset, input.length,
+            [&](int64_t) {
+              Impl::UpdateGroupWith(reduced, *g, true);
+              counts[*g++]++;
+            },
+            [&]() {
+              Impl::UpdateGroupWith(reduced, *g, false);
+              counts[*g++]++;
+            });
+      }
     } else {
-      arrow::internal::VisitBitBlocksVoid(
-          input.buffers[1], input.offset, input.length,
-          [&](int64_t) {
-            Impl::UpdateGroupWith(seen, *g, true);
-            counts[*g++]++;
-          },
-          [&]() {
-            Impl::UpdateGroupWith(seen, *g, false);
-            counts[*g++]++;
-          });
+      const auto& input = *batch[0].scalar();
+      if (input.is_valid) {
+        const bool value = UnboxScalar<BooleanType>::Unbox(input);
+        for (int64_t i = 0; i < batch.length; i++) {
+          Impl::UpdateGroupWith(reduced, *g, value);
+          counts[*g++]++;
+        }
+      } else {
+        for (int64_t i = 0; i < batch.length; i++) {
+          BitUtil::SetBitTo(no_nulls, *g++, false);
+        }
+      }
     }
     return Status::OK();
   }
 
@@ -1865,18 +1900,18 @@ struct GroupedBooleanAggregator : public GroupedAggregator {
                const ArrayData& group_id_mapping) override {
     auto other = checked_cast<GroupedBooleanAggregator<Impl>*>(&raw_other);
 
-    uint8_t* seen = seen_.mutable_data();
+    uint8_t* reduced = reduced_.mutable_data();
     uint8_t* no_nulls = no_nulls_.mutable_data();
     int64_t* counts = counts_.mutable_data();
 
-    const uint8_t* other_seen = other->seen_.mutable_data();
+    const uint8_t* other_reduced = other->reduced_.mutable_data();
     const uint8_t* other_no_nulls = other->no_nulls_.mutable_data();
     const int64_t* other_counts = other->counts_.mutable_data();
 
     auto g = group_id_mapping.GetValues<uint32_t>(1);
     for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
       counts[*g] += other_counts[other_g];
-      Impl::UpdateGroupWith(seen, *g, BitUtil::GetBit(other_seen, other_g));
+      Impl::UpdateGroupWith(reduced, *g, BitUtil::GetBit(other_reduced, other_g));
       BitUtil::SetBitTo(
          no_nulls, *g,
          BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
@@ -1901,11 +1936,11 @@ struct GroupedBooleanAggregator : public GroupedAggregator {
         BitUtil::SetBitTo(null_bitmap->mutable_data(), i, false);
       }
 
-    ARROW_ASSIGN_OR_RAISE(auto seen, seen_.Finish());
+    ARROW_ASSIGN_OR_RAISE(auto reduced, reduced_.Finish());
     if (!options_.skip_nulls) {
       null_count = kUnknownNullCount;
       ARROW_ASSIGN_OR_RAISE(auto no_nulls, no_nulls_.Finish());
-      Impl::AdjustForMinCount(no_nulls->mutable_data(), seen->data(), num_groups_);
+      Impl::AdjustForMinCount(no_nulls->mutable_data(), reduced->data(), num_groups_);
       if (null_bitmap) {
         arrow::internal::BitmapAnd(null_bitmap->data(), /*left_offset=*/0,
                                    no_nulls->data(), /*right_offset=*/0, num_groups_,
@@ -1916,14 +1951,14 @@ struct GroupedBooleanAggregator : public GroupedAggregator {
     }
 
     return ArrayData::Make(out_type(), num_groups_,
-                           {std::move(null_bitmap), std::move(seen)}, null_count);
+                           {std::move(null_bitmap), std::move(reduced)}, null_count);
   }
 
   std::shared_ptr<DataType> out_type() const override { return boolean(); }
 
   int64_t num_groups_ = 0;
   ScalarAggregateOptions options_;
-  TypedBufferBuilder<bool> seen_, no_nulls_;
+  TypedBufferBuilder<bool> reduced_, no_nulls_;
   TypedBufferBuilder<int64_t> counts_;
   MemoryPool* pool_;
 };
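Before the test-side changes, a standalone sketch of the pattern the kernel hunks above converge on: VisitGroupedValues only dispatches between an array argument and a scalar argument that is conceptually repeated for every row of the batch, handing (group id, value) pairs (or a group id alone for nulls) to the caller's lambdas. The types below (GroupId, Column, VisitGrouped) are illustrative stand-ins, not Arrow's ExecBatch/Datum API.

#include <cstdint>
#include <iostream>
#include <optional>
#include <variant>
#include <vector>

using GroupId = uint32_t;
// An argument column is either one optional value per row ("array" case) or a single
// optional value shared by every row ("scalar" case).
using Column = std::variant<std::vector<std::optional<int>>, std::optional<int>>;

template <typename OnValue, typename OnNull>
void VisitGrouped(const Column& column, const std::vector<GroupId>& groups,
                  OnValue&& on_value, OnNull&& on_null) {
  if (const auto* values = std::get_if<std::vector<std::optional<int>>>(&column)) {
    for (size_t i = 0; i < values->size(); ++i) {
      if ((*values)[i]) {
        on_value(groups[i], *(*values)[i]);
      } else {
        on_null(groups[i]);
      }
    }
    return;
  }
  // Scalar case: visit the same value (or null) once per row of the batch.
  const auto& scalar = std::get<std::optional<int>>(column);
  for (GroupId g : groups) {
    if (scalar) {
      on_value(g, *scalar);
    } else {
      on_null(g);
    }
  }
}

int main() {
  std::vector<GroupId> groups = {0, 1, 0, 1};
  std::vector<int64_t> sums(2, 0);
  std::vector<int64_t> null_counts(2, 0);
  Column scalar_column = std::optional<int>{5};  // a valid scalar argument
  VisitGrouped(scalar_column, groups, [&](GroupId g, int v) { sums[g] += v; },
               [&](GroupId g) { null_counts[g] += 1; });
  std::cout << sums[0] << " " << sums[1] << "\n";  // 10 10
  return 0;
}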
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index df2222a4eef..812806fe9fb 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -119,42 +119,20 @@ Result<Datum> NaiveGroupBy(std::vector<Datum> arguments, std::vector<Datum> keys
   return StructArray::Make(std::move(out_columns), std::move(out_names));
 }
 
-Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
-                                   const std::vector<Datum>& keys,
+Result<Datum> GroupByUsingExecPlan(const BatchesWithSchema& input,
+                                   const std::vector<std::string>& key_names,
+                                   const std::vector<std::string>& arg_names,
                                    const std::vector<internal::Aggregate>& aggregates,
                                    bool use_threads, ExecContext* ctx) {
-  using arrow::compute::detail::ExecBatchIterator;
-
-  FieldVector scan_fields(arguments.size() + keys.size());
-  std::vector<FieldRef> keys_str(keys.size());
-  std::vector<FieldRef> arguments_str(arguments.size());
-  std::vector<std::string> names(arguments.size());
-  for (size_t i = 0; i < arguments.size(); ++i) {
-    auto name = std::string("agg_") + std::to_string(i);
+  std::vector<FieldRef> keys(key_names.size());
+  std::vector<FieldRef> targets(aggregates.size());
+  std::vector<std::string> names(aggregates.size());
+  for (size_t i = 0; i < aggregates.size(); ++i) {
     names[i] = aggregates[i].function;
-    scan_fields[i] = field(name, arguments[i].type());
-    arguments_str[i] = FieldRef(std::move(name));
-  }
-  for (size_t i = 0; i < keys.size(); ++i) {
-    auto name = std::string("key_") + std::to_string(i);
-    scan_fields[arguments.size() + i] = field(name, keys[i].type());
-    keys_str[i] = FieldRef(std::move(name));
+    targets[i] = FieldRef(arg_names[i]);
   }
-
-  std::vector<ExecBatch> scan_batches;
-  std::vector<Datum> inputs;
-  for (const auto& argument : arguments) {
-    inputs.push_back(argument);
-  }
-  for (const auto& key : keys) {
-    inputs.push_back(key);
-  }
-  ARROW_ASSIGN_OR_RAISE(auto batch_iterator,
-                        ExecBatchIterator::Make(inputs, ctx->exec_chunksize()));
-  ExecBatch batch;
-  while (batch_iterator->Next(&batch)) {
-    if (batch.length == 0) continue;
-    scan_batches.push_back(batch);
+  for (size_t i = 0; i < key_names.size(); ++i) {
+    keys[i] = FieldRef(key_names[i]);
   }
 
   ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(ctx));
@@ -162,16 +140,11 @@ Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
   RETURN_NOT_OK(
       Declaration::Sequence(
           {
-              {"source", SourceNodeOptions{schema(std::move(scan_fields)),
-                                           MakeVectorGenerator(arrow::internal::MapVector(
-                                               [](ExecBatch batch) {
-                                                 return util::make_optional(
-                                                     std::move(batch));
-                                               },
-                                               std::move(scan_batches)))}},
+              {"source",
+               SourceNodeOptions{input.schema, input.gen(use_threads, /*slow=*/false)}},
              {"aggregate",
-               AggregateNodeOptions{std::move(aggregates), std::move(arguments_str),
-                                    std::move(names), std::move(keys_str)}},
+               AggregateNodeOptions{std::move(aggregates), std::move(targets),
+                                    std::move(names), std::move(keys)}},
              {"sink", SinkNodeOptions{&sink_gen}},
          })
          .AddToPlan(plan.get()));
@@ -190,11 +163,11 @@ Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
             std::move(collected));
       });
 
-  std::vector<ExecBatch> output_batches =
-      start_and_collect.MoveResult().MoveValueUnsafe();
+  ARROW_ASSIGN_OR_RAISE(std::vector<ExecBatch> output_batches,
+                        start_and_collect.MoveResult());
 
-  ArrayVector out_arrays(arguments.size() + keys.size());
-  for (size_t i = 0; i < arguments.size() + keys.size(); ++i) {
+  ArrayVector out_arrays(aggregates.size() + key_names.size());
+  for (size_t i = 0; i < out_arrays.size(); ++i) {
     std::vector<std::shared_ptr<Array>> arrays(output_batches.size());
     for (size_t j = 0; j < output_batches.size(); ++j) {
      arrays[j] = output_batches[j].values[i].make_array();
@@ -206,6 +179,44 @@ Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
                          plan->sources()[0]->outputs()[0]->output_schema()->fields());
 }
 
+/// Simpler overload where you can give the columns as datums
+Result<Datum> GroupByUsingExecPlan(const std::vector<Datum>& arguments,
+                                   const std::vector<Datum>& keys,
+                                   const std::vector<internal::Aggregate>& aggregates,
+                                   bool use_threads, ExecContext* ctx) {
+  using arrow::compute::detail::ExecBatchIterator;
+
+  FieldVector scan_fields(arguments.size() + keys.size());
+  std::vector<std::string> key_names(keys.size());
+  std::vector<std::string> arg_names(arguments.size());
+  for (size_t i = 0; i < arguments.size(); ++i) {
+    auto name = std::string("agg_") + std::to_string(i);
+    scan_fields[i] = field(name, arguments[i].type());
+    arg_names[i] = std::move(name);
+  }
+  for (size_t i = 0; i < keys.size(); ++i) {
+    auto name = std::string("key_") + std::to_string(i);
+    scan_fields[arguments.size() + i] = field(name, keys[i].type());
+    key_names[i] = std::move(name);
+  }
+
+  std::vector<Datum> inputs = arguments;
+  inputs.reserve(inputs.size() + keys.size());
+  inputs.insert(inputs.end(), keys.begin(), keys.end());
+
+  ARROW_ASSIGN_OR_RAISE(auto batch_iterator,
+                        ExecBatchIterator::Make(inputs, ctx->exec_chunksize()));
+  BatchesWithSchema input;
+  input.schema = schema(std::move(scan_fields));
+  ExecBatch batch;
+  while (batch_iterator->Next(&batch)) {
+    if (batch.length == 0) continue;
+    input.batches.push_back(std::move(batch));
+  }
+
+  return GroupByUsingExecPlan(input, key_names, arg_names, aggregates, use_threads, ctx);
+}
+
 void ValidateGroupBy(const std::vector<internal::Aggregate>& aggregates,
                      std::vector<Datum> arguments, std::vector<Datum> keys) {
   ASSERT_OK_AND_ASSIGN(Datum expected, NaiveGroupBy(arguments, keys, aggregates));
@@ -697,6 +708,46 @@ TEST(GroupBy, CountOnly) {
   }
 }
 
+TEST(GroupBy, CountScalar) {
+  BatchesWithSchema input;
+  input.batches = {
+      ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+                        "[[1, 1], [1, 1], [1, 2], [1, 3]]"),
+      ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()},
+                        "[[null, 1], [null, 1], [null, 2], [null, 3]]"),
+      ExecBatchFromJSON({int32(), int64()}, "[[2, 1], [3, 2], [4, 3]]"),
+  };
+  input.schema = schema({field("argument", int32()), field("key", int64())});
+
+  CountOptions skip_nulls(CountOptions::ONLY_VALID);
+  CountOptions keep_nulls(CountOptions::ONLY_NULL);
+  CountOptions count_all(CountOptions::ALL);
+  for (bool use_threads : {true, false}) {
+    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+    ASSERT_OK_AND_ASSIGN(
+        Datum actual,
+        GroupByUsingExecPlan(input, {"key"}, {"argument", "argument", "argument"},
+                             {
+                                 {"hash_count", &skip_nulls},
+                                 {"hash_count", &keep_nulls},
+                                 {"hash_count", &count_all},
+                             },
+                             use_threads, default_exec_context()));
+    Datum expected = ArrayFromJSON(struct_({
+                                       field("hash_count", int64()),
+                                       field("hash_count", int64()),
+                                       field("hash_count", int64()),
+                                       field("key", int64()),
+                                   }),
+                                   R"([
+    [3, 2, 5, 1],
+    [2, 1, 3, 2],
+    [2, 1, 3, 3]
+  ])");
+    AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+  }
+}
+
 TEST(GroupBy, SumOnly) {
   for (bool use_exec_plan : {false, true}) {
     for (bool use_threads : {true, false}) {
"parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN( + Datum actual, + GroupByUsingExecPlan(input, {"key"}, {"argument", "argument", "argument"}, + { + {"hash_sum", nullptr}, + {"hash_mean", nullptr}, + {"hash_product", nullptr}, + }, + use_threads, default_exec_context())); + Datum expected = ArrayFromJSON(struct_({ + field("hash_sum", int64()), + field("hash_mean", float64()), + field("hash_product", int64()), + field("key", int64()), + }), + R"([ + [4, 1.333333, 2, 1], + [4, 2, 3, 2], + [5, 2.5, 4, 3] + ])"); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST(GroupBy, VarianceAndStddev) { auto batch = RecordBatchFromJSON( schema({field("argument", int32()), field("key", int64())}), R"([ @@ -1032,6 +1120,55 @@ TEST(GroupBy, TDigest) { /*verbose=*/true); } +TEST(GroupBy, StddevVarianceTDigestScalar) { + BatchesWithSchema input; + input.batches = { + ExecBatchFromJSON( + {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()}, + "[[1, 1.0, 1], [1, 1.0, 1], [1, 1.0, 2], [1, 1.0, 3]]"), + ExecBatchFromJSON( + {ValueDescr::Scalar(int32()), ValueDescr::Scalar(float32()), int64()}, + "[[null, null, 1], [null, null, 1], [null, null, 2], [null, null, 3]]"), + ExecBatchFromJSON({int32(), float32(), int64()}, + "[[2, 2.0, 1], [3, 3.0, 2], [4, 4.0, 3]]"), + }; + input.schema = schema( + {field("argument", int32()), field("argument1", float32()), field("key", int64())}); + + for (bool use_threads : {false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN(Datum actual, + GroupByUsingExecPlan(input, {"key"}, + {"argument", "argument", "argument", + "argument1", "argument1", "argument1"}, + { + {"hash_stddev", nullptr}, + {"hash_variance", nullptr}, + {"hash_tdigest", nullptr}, + {"hash_stddev", nullptr}, + {"hash_variance", nullptr}, + {"hash_tdigest", nullptr}, + }, + use_threads, default_exec_context())); + Datum expected = + ArrayFromJSON(struct_({ + field("hash_stddev", float64()), + field("hash_variance", float64()), + field("hash_tdigest", fixed_size_list(float64(), 1)), + field("hash_stddev", float64()), + field("hash_variance", float64()), + field("hash_tdigest", fixed_size_list(float64(), 1)), + field("key", int64()), + }), + R"([ + [0.4714045, 0.222222, [1.0], 0.4714045, 0.222222, [1.0], 1], + [1.0, 1.0, [1.0], 1.0, 1.0, [1.0], 2], + [1.5, 2.25, [1.0], 1.5, 2.25, [1.0], 3] + ])"); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST(GroupBy, MinMaxOnly) { for (bool use_exec_plan : {false, true}) { for (bool use_threads : {true, false}) { @@ -1153,6 +1290,39 @@ TEST(GroupBy, MinMaxDecimal) { } } +TEST(GroupBy, MinMaxScalar) { + BatchesWithSchema input; + input.batches = { + ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()}, + "[[-1, 1], [-1, 1], [-1, 2], [-1, 3]]"), + ExecBatchFromJSON({ValueDescr::Scalar(int32()), int64()}, + "[[null, 1], [null, 1], [null, 2], [null, 3]]"), + ExecBatchFromJSON({int32(), int64()}, "[[2, 1], [3, 2], [4, 3]]"), + }; + input.schema = schema({field("argument", int32()), field("key", int64())}); + + for (bool use_threads : {true, false}) { + SCOPED_TRACE(use_threads ? 
"parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN( + Datum actual, + GroupByUsingExecPlan(input, {"key"}, {"argument", "argument", "argument"}, + {{"hash_min_max", nullptr}}, use_threads, + default_exec_context())); + Datum expected = + ArrayFromJSON(struct_({ + field("hash_min_max", + struct_({field("min", int32()), field("max", int32())})), + field("key", int64()), + }), + R"([ + [{"min": -1, "max": 2}, 1], + [{"min": -1, "max": 3}, 2], + [{"min": -1, "max": 4}, 3] + ])"); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST(GroupBy, AnyAndAll) { ScalarAggregateOptions options(/*skip_nulls=*/false); for (bool use_threads : {true, false}) { @@ -1239,6 +1409,47 @@ TEST(GroupBy, AnyAndAll) { } } +TEST(GroupBy, AnyAllScalar) { + BatchesWithSchema input; + input.batches = { + ExecBatchFromJSON({ValueDescr::Scalar(boolean()), int64()}, + "[[true, 1], [true, 1], [true, 2], [true, 3]]"), + ExecBatchFromJSON({ValueDescr::Scalar(boolean()), int64()}, + "[[null, 1], [null, 1], [null, 2], [null, 3]]"), + ExecBatchFromJSON({boolean(), int64()}, "[[true, 1], [false, 2], [null, 3]]"), + }; + input.schema = schema({field("argument", boolean()), field("key", int64())}); + + ScalarAggregateOptions keep_nulls(/*skip_nulls=*/false, /*min_count=*/0); + for (bool use_threads : {true, false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + ASSERT_OK_AND_ASSIGN( + Datum actual, + GroupByUsingExecPlan(input, {"key"}, + {"argument", "argument", "argument", "argument"}, + { + {"hash_any", nullptr}, + {"hash_all", nullptr}, + {"hash_any", &keep_nulls}, + {"hash_all", &keep_nulls}, + }, + use_threads, default_exec_context())); + Datum expected = ArrayFromJSON(struct_({ + field("hash_any", boolean()), + field("hash_all", boolean()), + field("hash_any", boolean()), + field("hash_all", boolean()), + field("key", int64()), + }), + R"([ + [true, true, true, null, 1], + [true, false, true, false, 2], + [true, true, true, null, 3] + ])"); + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } +} + TEST(GroupBy, CountDistinct) { for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");