From a4db317ead8ee752792152117b3649f8ac300cc9 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 12 Aug 2021 16:35:13 -0400
Subject: [PATCH] ARROW-13613: [C++] Add decimal support to (hash)
sum/mean/product
---
.../arrow/compute/kernels/aggregate_basic.cc | 99 ++++--
.../compute/kernels/aggregate_basic_avx2.cc | 12 +-
.../compute/kernels/aggregate_basic_avx512.cc | 12 +-
.../kernels/aggregate_basic_internal.h | 61 +++-
.../compute/kernels/aggregate_internal.h | 41 ++-
.../arrow/compute/kernels/aggregate_test.cc | 305 ++++++++++++++++++
.../arrow/compute/kernels/hash_aggregate.cc | 136 ++++++--
.../compute/kernels/hash_aggregate_test.cc | 71 ++++
docs/source/cpp/compute.rst | 108 ++++---
9 files changed, 700 insertions(+), 145 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index a3fafaae75d..2952eade96b 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -116,22 +116,18 @@ Result> CountInit(KernelContext*,
template
struct SumImplDefault : public SumImpl {
- explicit SumImplDefault(const ScalarAggregateOptions& options_) {
- this->options = options_;
- }
+ using SumImpl::SumImpl;
};
template
struct MeanImplDefault : public MeanImpl {
- explicit MeanImplDefault(const ScalarAggregateOptions& options_) {
- this->options = options_;
- }
+ using MeanImpl::MeanImpl;
};
Result> SumInit(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit visitor(
- ctx, *args.inputs[0].type,
+ ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
@@ -139,7 +135,7 @@ Result> SumInit(KernelContext* ctx,
Result> MeanInit(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit visitor(
- ctx, *args.inputs[0].type,
+ ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
@@ -156,7 +152,13 @@ struct ProductImpl : public ScalarAggregator {
using ProductType = typename TypeTraits::CType;
using OutputType = typename TypeTraits::ScalarType;
- explicit ProductImpl(const ScalarAggregateOptions& options) { this->options = options; }
+ explicit ProductImpl(const std::shared_ptr& out_type,
+ const ScalarAggregateOptions& options)
+ : out_type(out_type),
+ options(options),
+ count(0),
+ product(MultiplyTraits::one(*out_type)),
+ nulls_observed(false) {}
Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
@@ -169,11 +171,11 @@ struct ProductImpl : public ScalarAggregator {
return Status::OK();
}
- VisitArrayDataInline(
+ internal::VisitArrayValuesInline(
*data,
[&](typename TypeTraits::CType value) {
this->product =
- static_cast(to_unsigned(this->product) * to_unsigned(value));
+ MultiplyTraits::Multiply(*out_type, this->product, value);
},
[] {});
} else {
@@ -184,7 +186,7 @@ struct ProductImpl : public ScalarAggregator {
for (int64_t i = 0; i < batch.length; i++) {
auto value = internal::UnboxScalar::Unbox(data);
this->product =
- static_cast(to_unsigned(this->product) * to_unsigned(value));
+ MultiplyTraits::Multiply(*out_type, this->product, value);
}
}
}
@@ -195,7 +197,7 @@ struct ProductImpl : public ScalarAggregator {
const auto& other = checked_cast(src);
this->count += other.count;
this->product =
- static_cast(to_unsigned(this->product) * to_unsigned(other.product));
+ MultiplyTraits::Multiply(*out_type, this->product, other.product);
this->nulls_observed = this->nulls_observed || other.nulls_observed;
return Status::OK();
}
@@ -203,26 +205,27 @@ struct ProductImpl : public ScalarAggregator {
Status Finalize(KernelContext*, Datum* out) override {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
- out->value = std::make_shared();
+ out->value = std::make_shared(out_type);
} else {
- out->value = MakeScalar(this->product);
+ out->value = std::make_shared(this->product, out_type);
}
return Status::OK();
}
- size_t count = 0;
- bool nulls_observed = false;
- typename AccType::c_type product = 1;
+ std::shared_ptr out_type;
ScalarAggregateOptions options;
+ size_t count;
+ ProductType product;
+ bool nulls_observed;
};
struct ProductInit {
std::unique_ptr state;
KernelContext* ctx;
- const DataType& type;
+ const std::shared_ptr& type;
const ScalarAggregateOptions& options;
- ProductInit(KernelContext* ctx, const DataType& type,
+ ProductInit(KernelContext* ctx, const std::shared_ptr& type,
const ScalarAggregateOptions& options)
: ctx(ctx), type(type), options(options) {}
@@ -235,24 +238,32 @@ struct ProductInit {
}
Status Visit(const BooleanType&) {
- state.reset(new ProductImpl(options));
+ auto ty = TypeTraits::AccType>::type_singleton();
+ state.reset(new ProductImpl(ty, options));
return Status::OK();
}
template
enable_if_number Visit(const Type&) {
- state.reset(new ProductImpl(options));
+ auto ty = TypeTraits::AccType>::type_singleton();
+ state.reset(new ProductImpl(ty, options));
+ return Status::OK();
+ }
+
+ template
+ enable_if_decimal Visit(const Type&) {
+ state.reset(new ProductImpl(type, options));
return Status::OK();
}
Result> Create() {
- RETURN_NOT_OK(VisitTypeInline(type, this));
+ RETURN_NOT_OK(VisitTypeInline(*type, this));
return std::move(state);
}
static Result> Init(KernelContext* ctx,
const KernelInitArgs& args) {
- ProductInit visitor(ctx, *args.inputs[0].type,
+ ProductInit visitor(ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
@@ -550,7 +561,8 @@ void AddBasicAggKernels(KernelInit init,
SimdLevel::type simd_level) {
for (const auto& ty : types) {
// array[InT] -> scalar[OutT]
- auto sig = KernelSignature::Make({InputType::Array(ty)}, ValueDescr::Scalar(out_ty));
+ auto sig =
+ KernelSignature::Make({InputType::Array(ty->id())}, ValueDescr::Scalar(out_ty));
AddAggKernel(std::move(sig), init, func, simd_level);
}
}
@@ -561,7 +573,8 @@ void AddScalarAggKernels(KernelInit init,
ScalarAggregateFunction* func) {
for (const auto& ty : types) {
// scalar[InT] -> scalar[OutT]
- auto sig = KernelSignature::Make({InputType::Scalar(ty)}, ValueDescr::Scalar(out_ty));
+ auto sig =
+ KernelSignature::Make({InputType::Scalar(ty->id())}, ValueDescr::Scalar(out_ty));
AddAggKernel(std::move(sig), init, func, SimdLevel::NONE);
}
}
@@ -598,6 +611,13 @@ void AddMinMaxKernels(KernelInit init,
}
}
+Result ScalarFirstType(KernelContext*,
+ const std::vector& descrs) {
+ ValueDescr result = descrs.front();
+ result.shape = ValueDescr::SCALAR;
+ return result;
+}
+
} // namespace aggregate
namespace internal {
@@ -628,9 +648,12 @@ const FunctionDoc product_doc{
const FunctionDoc mean_doc{
"Compute the mean of a numeric array",
("Null values are ignored by default. Minimum count of non-null\n"
- "values can be set and null is returned if too few are "
- "present.\nThis can be changed through ScalarAggregateOptions.\n"
- "The result is always computed as a double, regardless of the input types."),
+ "values can be set and null is returned if too few are present.\n"
+ "This can be changed through ScalarAggregateOptions.\n"
+ "The result is a double for integer and floating point arguments,\n"
+ "and a decimal with the same bit-width/precision/scale for decimal arguments.\n"
+ "For integers and floats, NaN is returned if min_count = 0 and\n"
+ "there are no values. For decimals, null is returned instead."),
{"array"},
"ScalarAggregateOptions"};
@@ -683,6 +706,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
&default_scalar_aggregate_options);
aggregate::AddArrayScalarAggKernels(aggregate::SumInit, {boolean()}, uint64(),
func.get());
+ AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL128)},
+ OutputType(aggregate::ScalarFirstType)),
+ aggregate::SumInit, func.get(), SimdLevel::NONE);
+ AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL256)},
+ OutputType(aggregate::ScalarFirstType)),
+ aggregate::SumInit, func.get(), SimdLevel::NONE);
aggregate::AddArrayScalarAggKernels(aggregate::SumInit, SignedIntTypes(), int64(),
func.get());
aggregate::AddArrayScalarAggKernels(aggregate::SumInit, UnsignedIntTypes(), uint64(),
@@ -711,6 +740,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
func.get());
aggregate::AddArrayScalarAggKernels(aggregate::MeanInit, NumericTypes(), float64(),
func.get());
+ AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL128)},
+ OutputType(aggregate::ScalarFirstType)),
+ aggregate::MeanInit, func.get(), SimdLevel::NONE);
+ AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL256)},
+ OutputType(aggregate::ScalarFirstType)),
+ aggregate::MeanInit, func.get(), SimdLevel::NONE);
// Add the SIMD variants for mean
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
@@ -754,6 +789,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
uint64(), func.get());
aggregate::AddArrayScalarAggKernels(aggregate::ProductInit::Init, FloatingPointTypes(),
float64(), func.get());
+ AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL128)},
+ OutputType(aggregate::ScalarFirstType)),
+ aggregate::ProductInit::Init, func.get(), SimdLevel::NONE);
+ AddAggKernel(KernelSignature::Make({InputType(Type::DECIMAL256)},
+ OutputType(aggregate::ScalarFirstType)),
+ aggregate::ProductInit::Init, func.get(), SimdLevel::NONE);
DCHECK_OK(registry->AddFunction(std::move(func)));
// any
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
index 8d3e5a0409d..55e9f290e0e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
@@ -26,22 +26,18 @@ namespace aggregate {
template
struct SumImplAvx2 : public SumImpl {
- explicit SumImplAvx2(const ScalarAggregateOptions& options_) {
- this->options = options_;
- }
+ using SumImpl::SumImpl;
};
template
struct MeanImplAvx2 : public MeanImpl {
- explicit MeanImplAvx2(const ScalarAggregateOptions& options_) {
- this->options = options_;
- }
+ using MeanImpl::MeanImpl;
};
Result> SumInitAvx2(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit visitor(
- ctx, *args.inputs[0].type,
+ ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
@@ -49,7 +45,7 @@ Result> SumInitAvx2(KernelContext* ctx,
Result> MeanInitAvx2(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit visitor(
- ctx, *args.inputs[0].type,
+ ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
index 4f8ad74a086..df33dedabba 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
@@ -26,22 +26,18 @@ namespace aggregate {
template
struct SumImplAvx512 : public SumImpl {
- explicit SumImplAvx512(const ScalarAggregateOptions& options_) {
- this->options = options_;
- }
+ using SumImpl::SumImpl;
};
template
struct MeanImplAvx512 : public MeanImpl {
- explicit MeanImplAvx512(const ScalarAggregateOptions& options_) {
- this->options = options_;
- }
+ using MeanImpl::MeanImpl;
};
Result> SumInitAvx512(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit visitor(
- ctx, *args.inputs[0].type,
+ ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
@@ -49,7 +45,7 @@ Result> SumInitAvx512(KernelContext* ctx,
Result> MeanInitAvx512(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit visitor(
- ctx, *args.inputs[0].type,
+ ctx, args.inputs[0].type,
static_cast(*args.options));
return visitor.Create();
}
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
index b355a2e1b75..b97af066585 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
@@ -60,10 +60,15 @@ void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func);
template
struct SumImpl : public ScalarAggregator {
using ThisType = SumImpl;
- using CType = typename ArrowType::c_type;
+ using CType = typename TypeTraits::CType;
using SumType = typename FindAccumulatorType::Type;
+ using SumCType = typename TypeTraits::CType;
using OutputType = typename TypeTraits::ScalarType;
+ SumImpl(const std::shared_ptr& out_type,
+ const ScalarAggregateOptions& options_)
+ : out_type(out_type), options(options_) {}
+
Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
const auto& data = batch[0].array();
@@ -76,12 +81,9 @@ struct SumImpl : public ScalarAggregator {
}
if (is_boolean_type::value) {
- this->sum +=
- static_cast(BooleanArray(data).true_count());
+ this->sum += static_cast(BooleanArray(data).true_count());
} else {
- this->sum +=
- arrow::compute::detail::SumArray(
- *data);
+ this->sum += arrow::compute::detail::SumArray(*data);
}
} else {
const auto& data = *batch[0].scalar();
@@ -105,22 +107,39 @@ struct SumImpl : public ScalarAggregator {
Status Finalize(KernelContext*, Datum* out) override {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
- out->value = std::make_shared();
+ out->value = std::make_shared(out_type);
} else {
- out->value = MakeScalar(this->sum);
+ out->value = std::make_shared(this->sum, out_type);
}
return Status::OK();
}
size_t count = 0;
bool nulls_observed = false;
- typename SumType::c_type sum = 0;
+ SumCType sum = 0;
+ std::shared_ptr out_type;
ScalarAggregateOptions options;
};
template
struct MeanImpl : public SumImpl {
- Status Finalize(KernelContext*, Datum* out) override {
+ using SumImpl::SumImpl;
+
+ template
+ enable_if_decimal FinalizeImpl(Datum* out) {
+ using SumCType = typename SumImpl::SumCType;
+ using OutputType = typename SumImpl::OutputType;
+ if ((!options.skip_nulls && this->nulls_observed) ||
+ (this->count < options.min_count) || (this->count == 0)) {
+ out->value = std::make_shared(this->out_type);
+ } else {
+ const SumCType mean = this->sum / this->count;
+ out->value = std::make_shared(mean, this->out_type);
+ }
+ return Status::OK();
+ }
+ template
+ enable_if_t::value, Status> FinalizeImpl(Datum* out) {
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared();
@@ -130,17 +149,19 @@ struct MeanImpl : public SumImpl {
}
return Status::OK();
}
- ScalarAggregateOptions options;
+ Status Finalize(KernelContext*, Datum* out) override { return FinalizeImpl(out); }
+
+ using SumImpl::options;
};
template class KernelClass>
struct SumLikeInit {
std::unique_ptr state;
KernelContext* ctx;
- const DataType& type;
+ const std::shared_ptr type;
const ScalarAggregateOptions& options;
- SumLikeInit(KernelContext* ctx, const DataType& type,
+ SumLikeInit(KernelContext* ctx, const std::shared_ptr& type,
const ScalarAggregateOptions& options)
: ctx(ctx), type(type), options(options) {}
@@ -151,18 +172,26 @@ struct SumLikeInit {
}
Status Visit(const BooleanType&) {
- state.reset(new KernelClass(options));
+ auto ty = TypeTraits::SumType>::type_singleton();
+ state.reset(new KernelClass(ty, options));
return Status::OK();
}
template
enable_if_number Visit(const Type&) {
- state.reset(new KernelClass(options));
+ auto ty = TypeTraits::SumType>::type_singleton();
+ state.reset(new KernelClass(ty, options));
+ return Status::OK();
+ }
+
+ template
+ enable_if_decimal Visit(const Type&) {
+ state.reset(new KernelClass(type, options));
return Status::OK();
}
Result> Create() {
- RETURN_NOT_OK(VisitTypeInline(type, this));
+ RETURN_NOT_OK(VisitTypeInline(*type, this));
return std::move(state);
}
};
diff --git a/cpp/src/arrow/compute/kernels/aggregate_internal.h b/cpp/src/arrow/compute/kernels/aggregate_internal.h
index 3f5ba39d30e..13687d22820 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_internal.h
+++ b/cpp/src/arrow/compute/kernels/aggregate_internal.h
@@ -17,6 +17,7 @@
#pragma once
+#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit_run_reader.h"
@@ -49,6 +50,44 @@ struct FindAccumulatorType> {
using Type = DoubleType;
};
+template
+struct FindAccumulatorType> {
+ using Type = Decimal128Type;
+};
+
+template
+struct FindAccumulatorType> {
+ using Type = Decimal256Type;
+};
+
+// Helpers for implementing aggregations on decimals
+
+template
+struct MultiplyTraits {
+ using CType = typename TypeTraits::CType;
+
+ constexpr static CType one(const DataType&) { return static_cast(1); }
+
+ constexpr static CType Multiply(const DataType&, CType lhs, CType rhs) {
+ return static_cast(internal::to_unsigned(lhs) * internal::to_unsigned(rhs));
+ }
+};
+
+template
+struct MultiplyTraits> {
+ using CType = typename TypeTraits::CType;
+
+ constexpr static CType one(const DataType& ty) {
+ // Return 1 scaled to output type scale
+ return CType(1).IncreaseScaleBy(static_cast(ty).scale());
+ }
+
+ constexpr static CType Multiply(const DataType& ty, CType lhs, CType rhs) {
+ // Multiply then rescale down to output scale
+ return (lhs * rhs).ReduceScaleBy(static_cast(ty).scale());
+ }
+};
+
struct ScalarAggregator : public KernelState {
virtual Status Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
virtual Status MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
@@ -148,7 +187,7 @@ enable_if_t::value, SumType> SumArray(
return sum[root_level];
}
-// naive summation for integers
+// naive summation for integers and decimals
template
enable_if_t::value, SumType> SumArray(
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index 9893923a097..b93e33b05e1 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -494,6 +494,100 @@ TEST_F(TestSumKernelRoundOff, Basics) {
ASSERT_EQ(sum->value, 2756346749973250.0);
}
+TEST(TestDecimalSumKernel, SimpleSum) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([null])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(
+ Sum(ArrayFromJSON(ty, R"(["0.00", "1.01", "2.02", "3.03", "4.04", "5.05"])")),
+ ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+ Datum chunks =
+ ChunkedArrayFromJSON(ty, {R"(["0.00", "1.01", "2.02", "3.03", "4.04", "5.05"])"});
+ EXPECT_THAT(Sum(chunks), ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+ chunks = ChunkedArrayFromJSON(
+ ty, {R"(["0.00", "1.01", "2.02"])", R"(["3.03", "4.04", "5.05"])"});
+ EXPECT_THAT(Sum(chunks), ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+ chunks = ChunkedArrayFromJSON(
+ ty, {R"(["0.00", "1.01", "2.02"])", "[]", R"(["3.03", "4.04", "5.05"])"});
+ EXPECT_THAT(Sum(chunks), ResultWith(ScalarFromJSON(ty, R"("15.15")")));
+
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ chunks = ChunkedArrayFromJSON(ty, {});
+ EXPECT_THAT(Sum(chunks, options), ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+
+ EXPECT_THAT(
+ Sum(ArrayFromJSON(ty, R"(["1.01", null, "3.03", null, "5.05", null, "7.07"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"("16.16")")));
+
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"("5.05")")),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"(null)")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"(null)"), options),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ }
+}
+
+TEST(TestDecimalSumKernel, ScalarAggregateOptions) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum null = ScalarFromJSON(ty, R"(null)");
+ Datum zero = ScalarFromJSON(ty, R"("0.00")");
+ Datum result = ScalarFromJSON(ty, R"("14.14")");
+ Datum arr =
+ ArrayFromJSON(ty, R"(["1.01", null, "3.03", null, "3.03", null, "7.07"])");
+
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(zero));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(zero));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ arr = ArrayFromJSON(ty, R"(["1.01", "3.03", "3.03", "7.07"])");
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Sum(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Sum(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null));
+ EXPECT_THAT(Sum(null, ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null));
+ }
+}
+
//
// Product
//
@@ -591,6 +685,110 @@ TYPED_TEST(TestNumericProductKernel, ScalarAggregateOptions) {
ResultWith(null_result));
}
+TEST(TestDecimalProductKernel, SimpleProduct) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum null = ScalarFromJSON(ty, R"(null)");
+
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([])")), ResultWith(null));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([null])")), ResultWith(null));
+ EXPECT_THAT(
+ Product(ArrayFromJSON(ty, R"(["0.00", "1.00", "2.00", "3.00", "4.00", "5.00"])")),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ Datum chunks =
+ ChunkedArrayFromJSON(ty, {R"(["1.00", "2.00", "3.00", "4.00", "5.00"])"});
+ EXPECT_THAT(Product(chunks), ResultWith(ScalarFromJSON(ty, R"("120.00")")));
+ chunks =
+ ChunkedArrayFromJSON(ty, {R"(["1.00", "2.00"])", R"(["-3.00", "4.00", "5.00"])"});
+ EXPECT_THAT(Product(chunks), ResultWith(ScalarFromJSON(ty, R"("-120.00")")));
+ chunks = ChunkedArrayFromJSON(
+ ty, {R"(["1.00", "2.00"])", R"([])", R"(["-3.00", "4.00", "-5.00"])"});
+ EXPECT_THAT(Product(chunks), ResultWith(ScalarFromJSON(ty, R"("120.00")")));
+
+ const ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("1.00")")));
+ EXPECT_THAT(Product(ArrayFromJSON(ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"("1.00")")));
+ chunks = ChunkedArrayFromJSON(ty, {});
+ EXPECT_THAT(Product(chunks, options), ResultWith(ScalarFromJSON(ty, R"("1.00")")));
+
+ EXPECT_THAT(Product(ArrayFromJSON(
+ ty, R"(["1.00", null, "-3.00", null, "3.00", null, "7.00"])"),
+ options),
+ ResultWith(ScalarFromJSON(ty, R"("-63.00")")));
+
+ EXPECT_THAT(Product(ScalarFromJSON(ty, R"("5.00")")),
+ ResultWith(ScalarFromJSON(ty, R"("5.00")")));
+ EXPECT_THAT(Product(null), ResultWith(null));
+ }
+}
+
+TEST(TestDecimalProductKernel, ScalarAggregateOptions) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum null = ScalarFromJSON(ty, R"(null)");
+ Datum one = ScalarFromJSON(ty, R"("1.00")");
+ Datum result = ScalarFromJSON(ty, R"("63.00")");
+
+ Datum empty = ArrayFromJSON(ty, R"([])");
+ Datum null_arr = ArrayFromJSON(ty, R"([null])");
+ Datum arr =
+ ArrayFromJSON(ty, R"(["1.00", null, "3.00", null, "3.00", null, "7.00"])");
+
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(one));
+ EXPECT_THAT(
+ Product(null_arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ ResultWith(one));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(empty, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(null_arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ ResultWith(null));
+
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(null));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ arr = ArrayFromJSON(ty, R"(["1.00", "3.00", "3.00", "7.00"])");
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(
+ Product(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ EXPECT_THAT(Product(ScalarFromJSON(ty, R"("5.00")"),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(ScalarFromJSON(ty, R"("5.00")")));
+ EXPECT_THAT(Product(ScalarFromJSON(ty, R"("5.00")"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null));
+ EXPECT_THAT(Product(null, ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null));
+ }
+}
+
TEST(TestProductKernel, Overflow) {
EXPECT_THAT(Product(ArrayFromJSON(int64(), "[8589934592, 8589934593]")),
ResultWith(Datum(static_cast(8589934592))));
@@ -855,6 +1053,113 @@ TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMeanOverflow) {
}
}
+TEST(TestDecimalMeanKernel, SimpleMean) {
+ ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
+
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ // Decimal doesn't have NaN
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([])"), options),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([null])"), options),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"([null])")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, R"(["1.01", null, "1.01"])")),
+ ResultWith(ScalarFromJSON(ty, R"("1.01")")));
+ EXPECT_THAT(
+ Mean(ArrayFromJSON(
+ ty, R"(["1.01", "2.02", "3.03", "4.04", "5.05", "6.06", "7.07", "8.08"])")),
+ ResultWith(ScalarFromJSON(ty, R"("4.54")")));
+ EXPECT_THAT(
+ Mean(ArrayFromJSON(
+ ty, R"(["0.00", "0.00", "0.00", "0.00", "0.00", "0.00", "0.00", "0.00"])")),
+ ResultWith(ScalarFromJSON(ty, R"("0.00")")));
+ EXPECT_THAT(
+ Mean(ArrayFromJSON(
+ ty, R"(["1.01", "1.01", "1.01", "1.01", "1.01", "1.01", "1.01", "1.01"])")),
+ ResultWith(ScalarFromJSON(ty, R"("1.01")")));
+
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"("5.05")")),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"(null)")),
+ ResultWith(ScalarFromJSON(ty, R"(null)")));
+ }
+}
+
+TEST(TestDecimalMeanKernel, ScalarAggregateOptions) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ Datum result = ScalarFromJSON(ty, R"("3.03")");
+ Datum null = ScalarFromJSON(ty, R"(null)");
+ Datum arr = ArrayFromJSON(ty, R"(["1.01", null, "2.02", "2.02", null, "7.07"])");
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/1)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/0)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/3)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/5)),
+ null);
+
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)),
+ null);
+ EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"),
+ ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)),
+ null);
+
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ null);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ null);
+
+ arr = ArrayFromJSON(ty, R"(["1.01", "2.02", "2.02", "7.07"])");
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0)),
+ result);
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/3)),
+ ResultWith(result));
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4)),
+ ResultWith(result));
+ EXPECT_THAT(Mean(arr, ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/5)),
+ ResultWith(null));
+
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(ScalarFromJSON(ty, R"("5.05")")));
+ EXPECT_THAT(Mean(ScalarFromJSON(ty, R"("5.05")"),
+ ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/2)),
+ ResultWith(null));
+ EXPECT_THAT(Mean(null, ScalarAggregateOptions(/*skip_nulls=*/false)),
+ ResultWith(null));
+ }
+}
+
//
// Min / Max
//
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index ba56488dcb8..20dcd8ef331 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -919,14 +919,14 @@ struct GroupedReducingAggregator : public GroupedAggregator {
reduced_ = TypedBufferBuilder(pool_);
counts_ = TypedBufferBuilder(pool_);
no_nulls_ = TypedBufferBuilder(pool_);
- out_type_ = TypeTraits::type_singleton();
+ // out_type_ initialized by SumInit
return Status::OK();
}
Status Resize(int64_t new_num_groups) override {
auto added_groups = new_num_groups - num_groups_;
num_groups_ = new_num_groups;
- RETURN_NOT_OK(reduced_.Append(added_groups, Impl::NullValue()));
+ RETURN_NOT_OK(reduced_.Append(added_groups, Impl::NullValue(*out_type_)));
RETURN_NOT_OK(counts_.Append(added_groups, 0));
RETURN_NOT_OK(no_nulls_.Append(added_groups, true));
return Status::OK();
@@ -957,7 +957,7 @@ struct GroupedReducingAggregator : public GroupedAggregator {
auto g = group_id_mapping.GetValues(1);
for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
counts[*g] += other_counts[other_g];
- Impl::UpdateGroupWith(reduced, *g, other_reduced[other_g]);
+ Impl::UpdateGroupWith(*out_type_, reduced, *g, other_reduced[other_g]);
BitUtil::SetBitTo(
no_nulls, *g,
BitUtil::GetBit(no_nulls, *g) && BitUtil::GetBit(other_no_nulls, other_g));
@@ -1030,14 +1030,14 @@ struct GroupedSumImpl : public GroupedReducingAggregator(
+ internal::VisitArrayValuesInline(
values,
[&](typename TypeTraits::CType value) {
reduced[*g] = static_cast(to_unsigned(reduced[*g]) +
@@ -1049,17 +1049,46 @@ struct GroupedSumImpl : public GroupedReducingAggregator class Impl, typename T>
+Result> SumInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit>(ctx, args));
+ static_cast*>(impl.get())->out_type_ =
+ TypeTraits::AccType>::type_singleton();
+ return std::move(impl);
+}
+
+template
+Result> DecimalSumInit(KernelContext* ctx,
+ const KernelInitArgs& args) {
+ ARROW_ASSIGN_OR_RAISE(auto impl, HashAggregateInit(ctx, args));
+ static_cast(impl.get())->out_type_ = args.inputs[0].type;
+ return std::move(impl);
+}
+
struct GroupedSumFactory {
template ::Type>
Status Visit(const T&) {
- kernel = MakeKernel(std::move(argument_type), HashAggregateInit>);
+ kernel = MakeKernel(std::move(argument_type), SumInit);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit>);
+ return Status::OK();
+ }
+ Status Visit(const Decimal256Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit>);
return Status::OK();
}
@@ -1073,7 +1102,7 @@ struct GroupedSumFactory {
static Result Make(const std::shared_ptr& type) {
GroupedSumFactory factory;
- factory.argument_type = InputType::Array(type);
+ factory.argument_type = InputType::Array(type->id());
RETURN_NOT_OK(VisitTypeInline(*type, &factory));
return std::move(factory.kernel);
}
@@ -1089,25 +1118,29 @@ template
struct GroupedProductImpl final
: public GroupedReducingAggregator> {
using Base = GroupedReducingAggregator>;
+ using AccType = typename Base::AccType;
using c_type = typename Base::c_type;
- static c_type NullValue() { return c_type(1); }
+ static c_type NullValue(const DataType& out_type) {
+ return MultiplyTraits::one(out_type);
+ }
static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts,
uint8_t* no_nulls, const uint32_t* g) {
- VisitArrayDataInline(
+ internal::VisitArrayValuesInline(
values,
[&](typename TypeTraits::CType value) {
- reduced[*g] = static_cast(to_unsigned(reduced[*g]) *
- to_unsigned(static_cast(value)));
+ reduced[*g] = MultiplyTraits::Multiply(*values.type, reduced[*g],
+ static_cast(value));
counts[*g++] += 1;
},
[&] { BitUtil::SetBitTo(no_nulls, *g++, false); });
return Status::OK();
}
- static void UpdateGroupWith(c_type* reduced, uint32_t g, c_type value) {
- reduced[g] *= value;
+ static void UpdateGroupWith(const DataType& out_type, c_type* reduced, uint32_t g,
+ c_type value) {
+ reduced[g] = MultiplyTraits::Multiply(out_type, reduced[g], value);
}
using Base::Finish;
@@ -1116,8 +1149,19 @@ struct GroupedProductImpl final
struct GroupedProductFactory {
template ::Type>
Status Visit(const T&) {
- kernel =
- MakeKernel(std::move(argument_type), HashAggregateInit>);
+ kernel = MakeKernel(std::move(argument_type), SumInit);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit>);
return Status::OK();
}
@@ -1131,7 +1175,7 @@ struct GroupedProductFactory {
static Result Make(const std::shared_ptr& type) {
GroupedProductFactory factory;
- factory.argument_type = InputType::Array(type);
+ factory.argument_type = InputType::Array(type->id());
RETURN_NOT_OK(VisitTypeInline(*type, &factory));
return std::move(factory.kernel);
}
@@ -1147,14 +1191,16 @@ template
struct GroupedMeanImpl : public GroupedReducingAggregator> {
using Base = GroupedReducingAggregator>;
using c_type = typename Base::c_type;
+ using MeanType =
+ typename std::conditional::value, c_type, double>::type;
- static c_type NullValue() { return c_type(0); }
+ static c_type NullValue(const DataType&) { return c_type(0); }
static Status Consume(const ArrayData& values, c_type* reduced, int64_t* counts,
uint8_t* no_nulls, const uint32_t* g) {
// XXX this uses naive summation; we should switch to pairwise summation as was
// done for the scalar aggregate kernel in ARROW-11758
- VisitArrayDataInline(
+ internal::VisitArrayValuesInline(
values,
[&](typename TypeTraits::CType value) {
reduced[*g] = static_cast(to_unsigned(reduced[*g]) +
@@ -1165,7 +1211,8 @@ struct GroupedMeanImpl : public GroupedReducingAggregator* null_bitmap) {
const c_type* reduced = reduced_->data();
ARROW_ASSIGN_OR_RAISE(std::shared_ptr values,
- AllocateBuffer(num_groups * sizeof(double), pool));
- double* means = reinterpret_cast(values->mutable_data());
+ AllocateBuffer(num_groups * sizeof(MeanType), pool));
+ MeanType* means = reinterpret_cast(values->mutable_data());
for (int64_t i = 0; i < num_groups; ++i) {
if (counts[i] >= options.min_count) {
- means[i] = static_cast(reduced[i]) / counts[i];
+ means[i] = static_cast(reduced[i]) / counts[i];
continue;
}
- means[i] = 0;
+ means[i] = MeanType(0);
if ((*null_bitmap) == nullptr) {
ARROW_ASSIGN_OR_RAISE(*null_bitmap, AllocateBitmap(num_groups, pool));
@@ -1197,13 +1244,28 @@ struct GroupedMeanImpl : public GroupedReducingAggregator out_type() const override { return float64(); }
+ std::shared_ptr out_type() const override {
+ if (is_decimal_type::value) return this->out_type_;
+ return float64();
+ }
};
struct GroupedMeanFactory {
template ::Type>
Status Visit(const T&) {
- kernel = MakeKernel(std::move(argument_type), HashAggregateInit>);
+ kernel = MakeKernel(std::move(argument_type), SumInit);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal128Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit>);
+ return Status::OK();
+ }
+
+ Status Visit(const Decimal256Type&) {
+ kernel = MakeKernel(std::move(argument_type),
+ DecimalSumInit>);
return Status::OK();
}
@@ -1217,7 +1279,7 @@ struct GroupedMeanFactory {
static Result Make(const std::shared_ptr& type) {
GroupedMeanFactory factory;
- factory.argument_type = InputType::Array(type);
+ factory.argument_type = InputType::Array(type->id());
RETURN_NOT_OK(VisitTypeInline(*type, &factory));
return std::move(factory.kernel);
}
@@ -2179,10 +2241,13 @@ const FunctionDoc hash_product_doc{
{"array", "group_id_array"},
"ScalarAggregateOptions"};
-const FunctionDoc hash_mean_doc{"Average values of a numeric array",
- ("Null values are ignored."),
- {"array", "group_id_array"},
- "ScalarAggregateOptions"};
+const FunctionDoc hash_mean_doc{
+ "Average values of a numeric array",
+ ("Null values are ignored.\n"
+ "For integers and floats, NaN is returned if min_count = 0 and\n"
+ "there are no values. For decimals, null is returned instead."),
+ {"array", "group_id_array"},
+ "ScalarAggregateOptions"};
const FunctionDoc hash_stddev_doc{
"Calculate the standard deviation of a numeric array",
@@ -2249,6 +2314,9 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(AddHashAggKernels(UnsignedIntTypes(), GroupedSumFactory::Make, func.get()));
DCHECK_OK(
AddHashAggKernels(FloatingPointTypes(), GroupedSumFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
+ GroupedSumFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
@@ -2263,6 +2331,9 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
AddHashAggKernels(UnsignedIntTypes(), GroupedProductFactory::Make, func.get()));
DCHECK_OK(
AddHashAggKernels(FloatingPointTypes(), GroupedProductFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
+ GroupedProductFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
@@ -2275,6 +2346,9 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
AddHashAggKernels(UnsignedIntTypes(), GroupedMeanFactory::Make, func.get()));
DCHECK_OK(
AddHashAggKernels(FloatingPointTypes(), GroupedMeanFactory::Make, func.get()));
+ // Type parameters are ignored
+ DCHECK_OK(AddHashAggKernels({decimal128(1, 1), decimal256(1, 1)},
+ GroupedMeanFactory::Make, func.get()));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 21440248493..a160461b5dc 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -745,6 +745,77 @@ TEST(GroupBy, SumOnly) {
}
}
+TEST(GroupBy, SumMeanProductDecimal) {
+ auto in_schema = schema({
+ field("argument0", decimal128(3, 2)),
+ field("argument1", decimal256(3, 2)),
+ field("key", int64()),
+ });
+
+ for (bool use_exec_plan : {false, true}) {
+ for (bool use_threads : {true, false}) {
+ SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+
+ auto table = TableFromJSON(in_schema, {R"([
+ ["1.00", "1.00", 1],
+ [null, null, 1]
+ ])",
+ R"([
+ ["0.00", "0.00", 2],
+ [null, null, 3],
+ ["4.00", "4.00", null],
+ ["3.25", "3.25", 1],
+ ["0.12", "0.12", 2]
+ ])",
+ R"([
+ ["-0.25", "-0.25", 2],
+ ["0.75", "0.75", null],
+ [null, null, 3]
+ ])"});
+
+ ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+ GroupByTest(
+ {
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ table->GetColumnByName("argument0"),
+ table->GetColumnByName("argument1"),
+ },
+ {table->GetColumnByName("key")},
+ {
+ {"hash_sum", nullptr},
+ {"hash_sum", nullptr},
+ {"hash_mean", nullptr},
+ {"hash_mean", nullptr},
+ {"hash_product", nullptr},
+ {"hash_product", nullptr},
+ },
+ use_threads, use_exec_plan));
+ SortBy({"key_0"}, &aggregated_and_grouped);
+
+ AssertDatumsEqual(ArrayFromJSON(struct_({
+ field("hash_sum", decimal128(3, 2)),
+ field("hash_sum", decimal256(3, 2)),
+ field("hash_mean", decimal128(3, 2)),
+ field("hash_mean", decimal256(3, 2)),
+ field("hash_product", decimal128(3, 2)),
+ field("hash_product", decimal256(3, 2)),
+ field("key_0", int64()),
+ }),
+ R"([
+ ["4.25", "4.25", "2.12", "2.12", "3.25", "3.25", 1],
+ ["-0.13", "-0.13", "-0.04", "-0.04", "0.00", "0.00", 2],
+ [null, null, null, null, null, null, 3],
+ ["4.75", "4.75", "2.37", "2.37", "3.00", "3.00", null]
+ ])"),
+ aggregated_and_grouped,
+ /*verbose=*/true);
+ }
+ }
+}
+
TEST(GroupBy, MeanOnly) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 9b5e950399a..c5ecbb419d1 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -186,35 +186,35 @@ Aggregations
Scalar aggregations operate on a (chunked) array or scalar value and reduce
the input to a single output value.
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| Function name | Arity | Input types | Output type | Options class | Notes |
-+===============+=======+=============+================+==================================+=======+
-| all | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| any | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| mean | Unary | Numeric | Scalar Float64 | :struct:`ScalarAggregateOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| min_max | Unary | Numeric | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(6) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| stddev | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| tdigest | Unary | Numeric | Scalar Float64 | :struct:`TDigestOptions` | \(7) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| variance | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| Function name | Arity | Input types | Output type | Options class | Notes |
++===============+=======+=============+========================+==================================+=======+
+| all | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| any | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| mean | Unary | Numeric | Scalar Decimal/Float64 | :struct:`ScalarAggregateOptions` | |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| min_max | Unary | Numeric | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(6) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| stddev | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| tdigest | Unary | Numeric | Scalar Float64 | :struct:`TDigestOptions` | \(7) |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
+| variance | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | |
++---------------+-------+-------------+------------------------+----------------------------------+-------+
Notes:
@@ -234,7 +234,8 @@ Notes:
Note that the output can have less than *N* elements if the input has
less than *N* distinct values.
-* \(5) Output is Int64, UInt64 or Float64, depending on the input type.
+* \(5) Output is Int64, UInt64, Float64, or Decimal128/256, depending on the
+ input type.
* \(6) Output is Float64 or input type, depending on QuantileOptions.
@@ -288,27 +289,29 @@ The supported aggregation functions are as follows. All function names are
prefixed with ``hash_``, which differentiates them from their scalar
equivalents above and reflects how they are implemented internally.
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| Function name | Arity | Input types | Output type | Options class | Notes |
-+===============+=======+=============+================+==================================+=======+
-| hash_all | Unary | Boolean | Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_any | Unary | Boolean | Boolean | :struct:`ScalarAggregateOptions` | \(1) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_count | Unary | Any | Int64 | :struct:`CountOptions` | \(2) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_mean | Unary | Numeric | Float64 | :struct:`ScalarAggregateOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_min_max | Unary | Numeric | Struct | :struct:`ScalarAggregateOptions` | \(3) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_stddev | Unary | Numeric | Float64 | :struct:`VarianceOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_sum | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(4) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_tdigest | Unary | Numeric | Float64 | :struct:`TDigestOptions` | \(5) |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
-| hash_variance | Unary | Numeric | Float64 | :struct:`VarianceOptions` | |
-+---------------+-------+-------------+----------------+----------------------------------+-------+
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| Function name | Arity | Input types | Output type | Options class | Notes |
++===============+=======+=============+=================+==================================+=======+
+| hash_all | Unary | Boolean | Boolean | :struct:`ScalarAggregateOptions` | \(1) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_any | Unary | Boolean | Boolean | :struct:`ScalarAggregateOptions` | \(1) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_count | Unary | Any | Int64 | :struct:`CountOptions` | \(2) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_mean | Unary | Numeric | Decimal/Float64 | :struct:`ScalarAggregateOptions` | |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_min_max | Unary | Numeric | Struct | :struct:`ScalarAggregateOptions` | \(3) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_product | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(4) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_stddev | Unary | Numeric | Float64 | :struct:`VarianceOptions` | |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_sum | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(4) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_tdigest | Unary | Numeric | Float64 | :struct:`TDigestOptions` | \(5) |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
+| hash_variance | Unary | Numeric | Float64 | :struct:`VarianceOptions` | |
++---------------+-------+-------------+-----------------+----------------------------------+-------+
* \(1) If null values are taken into account, by setting the
:member:`ScalarAggregateOptions::skip_nulls` to false, then `Kleene logic`_
@@ -319,7 +322,8 @@ equivalents above and reflects how they are implemented internally.
* \(3) Output is a ``{"min": input type, "max": input type}`` Struct scalar.
-* \(4) Output is Int64, UInt64 or Float64, depending on the input type.
+* \(4) Output is Int64, UInt64, Float64, or Decimal128/256, depending on the
+ input type.
* \(5) T-digest computes approximate quantiles, and so only needs a
fixed amount of memory. See the `reference implementation