Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,43 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
return visitor.Create();
}

// For "min" and "max" functions: override finalize and return the actual value
template <MinOrMax min_or_max>
void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,
ScalarAggregateFunction* min_max_func) {
auto sig = KernelSignature::Make(
{InputType(ValueDescr::ANY)},
OutputType([](KernelContext*,
const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
// any[T] -> scalar[T]
return ValueDescr::Scalar(descrs.front().type);
}));

auto init = [min_max_func](
KernelContext* ctx,
const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
std::vector<ValueDescr> inputs = args.inputs;
ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
KernelInitArgs new_args{kernel, inputs, args.options};
return kernel->init(ctx, new_args);
};

auto finalize = [](KernelContext* ctx, Datum* out) -> Status {
Datum temp;
RETURN_NOT_OK(checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, &temp));
const auto& result = temp.scalar_as<StructScalar>();
DCHECK(result.is_valid);
*out = result.value[static_cast<uint8_t>(min_or_max)];
return Status::OK();
};

// Note SIMD level is always NONE, but the convenience kernel will
// dispatch to an appropriate implementation
ScalarAggregateKernel kernel(std::move(sig), std::move(init), AggregateConsume,
AggregateMerge, std::move(finalize));
DCHECK_OK(func->AddKernel(kernel));
}

// ----------------------------------------------------------------------
// Any implementation

Expand Down Expand Up @@ -663,6 +700,13 @@ const FunctionDoc min_max_doc{"Compute the minimum and maximum values of a numer
{"array"},
"ScalarAggregateOptions"};

const FunctionDoc min_or_max_doc{
"Compute the minimum or maximum values of a numeric array",
("Null values are ignored by default.\n"
"This can be changed through ScalarAggregateOptions."),
{"array"},
"ScalarAggregateOptions"};

const FunctionDoc any_doc{"Test whether any element in a boolean array evaluates to true",
("Null values are ignored by default.\n"
"If null values are taken into account by setting "
Expand Down Expand Up @@ -781,6 +825,18 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
}
#endif

auto min_max_func = func.get();
DCHECK_OK(registry->AddFunction(std::move(func)));

// Add min/max as convenience functions
func = std::make_shared<ScalarAggregateFunction>("min", Arity::Unary(), &min_or_max_doc,
&default_scalar_aggregate_options);
aggregate::AddMinOrMaxAggKernel<MinOrMax::Min>(func.get(), min_max_func);
DCHECK_OK(registry->AddFunction(std::move(func)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the comment of using enum is accepted, then use enum as template parameter.


func = std::make_shared<ScalarAggregateFunction>("max", Arity::Unary(), &min_or_max_doc,
&default_scalar_aggregate_options);
aggregate::AddMinOrMaxAggKernel<MinOrMax::Max>(func.get(), min_max_func);
DCHECK_OK(registry->AddFunction(std::move(func)));

func = std::make_shared<ScalarAggregateFunction>(
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ struct ScalarAggregator : public KernelState {
// kernel implementations together
enum class VarOrStd : bool { Var, Std };

// Helper to differentiate between min/max calculation so we can fold
// kernel implementations together
enum class MinOrMax : uint8_t { Min = 0, Max };

void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);
Expand Down
25 changes: 21 additions & 4 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1177,11 +1177,23 @@ class TestPrimitiveMinMaxKernel : public ::testing::Test {
ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options));
const StructScalar& value = out.scalar_as<StructScalar>();

const auto& out_min = checked_cast<const ScalarType&>(*value.value[0]);
ASSERT_EQ(expected_min, out_min.value);
{
const auto& out_min = checked_cast<const ScalarType&>(*value.value[0]);
ASSERT_EQ(expected_min, out_min.value);

const auto& out_max = checked_cast<const ScalarType&>(*value.value[1]);
ASSERT_EQ(expected_max, out_max.value);
const auto& out_max = checked_cast<const ScalarType&>(*value.value[1]);
ASSERT_EQ(expected_max, out_max.value);
}

{
ASSERT_OK_AND_ASSIGN(out, CallFunction("min", {array}, &options));
const auto& out_min = out.scalar_as<ScalarType>();
ASSERT_EQ(expected_min, out_min.value);

ASSERT_OK_AND_ASSIGN(out, CallFunction("max", {array}, &options));
const auto& out_max = out.scalar_as<ScalarType>();
ASSERT_EQ(expected_max, out_max.value);
}
}

void AssertMinMaxIs(const std::string& json, c_type expected_min, c_type expected_max,
Expand Down Expand Up @@ -1427,6 +1439,11 @@ TEST(TestDecimalMinMaxKernel, Decimals) {
EXPECT_THAT(MinMax(chunked_input3, options),
ResultWith(ScalarFromJSON(ty, R"({"min": "1.01", "max": "9.42"})")));

EXPECT_THAT(CallFunction("min", {chunked_input1}, &options),
ResultWith(ScalarFromJSON(item_ty, R"("1.01")")));
EXPECT_THAT(CallFunction("max", {chunked_input1}, &options),
ResultWith(ScalarFromJSON(item_ty, R"("9.42")")));

EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, "null"), options),
ResultWith(ScalarFromJSON(ty, R"({"min": null, "max": null})")));
EXPECT_THAT(MinMax(ScalarFromJSON(item_ty, R"("1.00")"), options),
Expand Down
96 changes: 73 additions & 23 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -770,38 +770,34 @@ Result<std::unique_ptr<KernelState>> HashAggregateInit(KernelContext* ctx,
return std::move(impl);
}

Status HashAggregateResize(KernelContext* ctx, int64_t num_groups) {
return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
}
Status HashAggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
}
Status HashAggregateMerge(KernelContext* ctx, KernelState&& other,
const ArrayData& group_id_mapping) {
return checked_cast<GroupedAggregator*>(ctx->state())
->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
}
Status HashAggregateFinalize(KernelContext* ctx, Datum* out) {
return checked_cast<GroupedAggregator*>(ctx->state())->Finalize().Value(out);
}

HashAggregateKernel MakeKernel(InputType argument_type, KernelInit init) {
HashAggregateKernel kernel;

kernel.init = std::move(init);

kernel.signature = KernelSignature::Make(
{std::move(argument_type), InputType::Array(Type::UINT32)},
OutputType(
[](KernelContext* ctx, const std::vector<ValueDescr>&) -> Result<ValueDescr> {
return checked_cast<GroupedAggregator*>(ctx->state())->out_type();
}));

kernel.resize = [](KernelContext* ctx, int64_t num_groups) {
return checked_cast<GroupedAggregator*>(ctx->state())->Resize(num_groups);
};

kernel.consume = [](KernelContext* ctx, const ExecBatch& batch) {
return checked_cast<GroupedAggregator*>(ctx->state())->Consume(batch);
};

kernel.merge = [](KernelContext* ctx, KernelState&& other,
const ArrayData& group_id_mapping) {
return checked_cast<GroupedAggregator*>(ctx->state())
->Merge(checked_cast<GroupedAggregator&&>(other), group_id_mapping);
};

kernel.finalize = [](KernelContext* ctx, Datum* out) {
ARROW_ASSIGN_OR_RAISE(*out,
checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
return Status::OK();
};

kernel.resize = HashAggregateResize;
kernel.consume = HashAggregateConsume;
kernel.merge = HashAggregateMerge;
kernel.finalize = HashAggregateFinalize;
return kernel;
}

Expand Down Expand Up @@ -1929,6 +1925,35 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
return std::move(impl);
}

template <MinOrMax min_or_max>
HashAggregateKernel MakeMinOrMaxKernel(HashAggregateFunction* min_max_func) {
HashAggregateKernel kernel;
kernel.init = [min_max_func](
KernelContext* ctx,
const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
std::vector<ValueDescr> inputs = args.inputs;
ARROW_ASSIGN_OR_RAISE(auto kernel, min_max_func->DispatchBest(&inputs));
KernelInitArgs new_args{kernel, inputs, args.options};
return kernel->init(ctx, new_args);
};
kernel.signature = KernelSignature::Make(
{InputType(ValueDescr::ANY), InputType::Array(Type::UINT32)},
OutputType([](KernelContext* ctx,
const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
return ValueDescr::Array(descrs[0].type);
}));
Comment on lines +1943 to +1944
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Here OutputType is expressed as a lambda function, so should MinMaxType be also a lambda.

kernel.resize = HashAggregateResize;
kernel.consume = HashAggregateConsume;
kernel.merge = HashAggregateMerge;
kernel.finalize = [](KernelContext* ctx, Datum* out) {
ARROW_ASSIGN_OR_RAISE(Datum temp,
checked_cast<GroupedAggregator*>(ctx->state())->Finalize());
*out = temp.array_as<StructArray>()->field(static_cast<uint8_t>(min_or_max));
return Status::OK();
};
return kernel;
}

struct GroupedMinMaxFactory {
template <typename T>
enable_if_physical_integer<T, Status> Visit(const T&) {
Expand Down Expand Up @@ -2628,6 +2653,13 @@ const FunctionDoc hash_min_max_doc{
{"array", "group_id_array"},
"ScalarAggregateOptions"};

const FunctionDoc hash_min_or_max_doc{
"Compute the minimum or maximum values of a numeric array",
("Null values are ignored by default.\n"
"This can be changed through ScalarAggregateOptions."),
{"array", "group_id_array"},
"ScalarAggregateOptions"};

const FunctionDoc hash_any_doc{"Test whether any element evaluates to true",
("Null values are ignored."),
{"array", "group_id_array"},
Expand Down Expand Up @@ -2750,6 +2782,7 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

HashAggregateFunction* min_max_func = nullptr;
{
auto func = std::make_shared<HashAggregateFunction>(
"hash_min_max", Arity::Binary(), &hash_min_max_doc,
Expand All @@ -2760,6 +2793,23 @@ void RegisterHashAggregateBasic(FunctionRegistry* registry) {
DCHECK_OK(AddHashAggKernels(
{null(), boolean(), decimal128(1, 1), decimal256(1, 1), month_interval()},
GroupedMinMaxFactory::Make, func.get()));
min_max_func = func.get();
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
auto func = std::make_shared<HashAggregateFunction>(
"hash_min", Arity::Binary(), &hash_min_or_max_doc,
&default_scalar_aggregate_options);
DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel<MinOrMax::Min>(min_max_func)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

{
auto func = std::make_shared<HashAggregateFunction>(
"hash_max", Arity::Binary(), &hash_min_or_max_doc,
&default_scalar_aggregate_options);
DCHECK_OK(func->AddKernel(MakeMinOrMaxKernel<MinOrMax::Max>(min_max_func)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}

Expand Down
53 changes: 53 additions & 0 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,59 @@ TEST(GroupBy, MinMaxDecimal) {
}
}

TEST(GroupBy, MinOrMax) {
auto table =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe include tests of only hash_min and only hash_max.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure there's a need to test the kernels separately, since they're being run separately by the nodes - I group them mostly because it's a lot of vertical visual noise to duplicate these tests otherwise (we should probably clean them up at some point).

Copy link
Contributor

@edponce edponce Sep 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree (thus the "maybe"), if it is already hitting those code paths independently which seems it is.

TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
[1.0, 1],
[null, 1]
])",
R"([
[0.0, 2],
[null, 3],
[4.0, null],
[3.25, 1],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extend test case to get min/max between NaN, Inf, null, value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

[0.125, 2]
])",
R"([
[-0.25, 2],
[0.75, null],
[null, 3]
])",
R"([
[NaN, 4],
[null, 4],
[Inf, 4],
[-Inf, 4],
[0.0, 4]
])"});

ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
GroupByTest({table->GetColumnByName("argument"),
table->GetColumnByName("argument")},
{table->GetColumnByName("key")},
{
{"hash_min", nullptr},
{"hash_max", nullptr},
},
/*use_threads=*/true, /*use_exec_plan=*/true));
SortBy({"key_0"}, &aggregated_and_grouped);

AssertDatumsEqual(ArrayFromJSON(struct_({
field("hash_min", float64()),
field("hash_max", float64()),
field("key_0", int64()),
}),
R"([
[1.0, 3.25, 1],
[-0.25, 0.125, 2],
[null, null, 3],
[-Inf, Inf, 4],
[0.75, 4.0, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
}

TEST(GroupBy, MinMaxScalar) {
BatchesWithSchema input;
input.batches = {
Expand Down
8 changes: 8 additions & 0 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,12 @@ the input to a single output value.
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| mean | Unary | Numeric | Scalar Decimal/Float64 | :struct:`ScalarAggregateOptions` | |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| min | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| min_max | Unary | Non-nested types | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) |
+---------------+-------+------------------+------------------------+----------------------------------+-------+
| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) |
Expand Down Expand Up @@ -307,8 +311,12 @@ equivalents above and reflects how they are implemented internally.
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_distinct | Unary | Any | Input type | :struct:`CountOptions` | \(2) |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_max | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_mean | Unary | Numeric | Decimal/Float64 | :struct:`ScalarAggregateOptions` | |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_min | Unary | Non-nested, non-binary/string-like | Input type | :struct:`ScalarAggregateOptions` | |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_min_max | Unary | Non-nested, non-binary/string-like | Struct | :struct:`ScalarAggregateOptions` | \(3) |
+---------------------+-------+------------------------------------+-----------------+----------------------------------+-------+
| hash_product | Unary | Numeric | Numeric | :struct:`ScalarAggregateOptions` | \(4) |
Expand Down