diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc index de414219eb4..12d24429cb6 100644 --- a/cpp/src/arrow/acero/hash_aggregate_test.cc +++ b/cpp/src/arrow/acero/hash_aggregate_test.cc @@ -2000,6 +2000,74 @@ TEST_P(GroupBy, MinMaxScalar) { } } +TEST_P(GroupBy, MinMaxWithNaN) { + auto in_schema = schema({ + field("argument1", float32()), + field("argument2", float64()), + field("key", int64()), + }); + for (bool use_threads : {true, false}) { + SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); + + auto table = TableFromJSON(in_schema, {R"([ + [NaN, NaN, 1], + [NaN, NaN, 2], + [NaN, NaN, 3] +])", + R"([ + [NaN, NaN, 1], + [-Inf, -Inf, 2], + [Inf, Inf, 3] +])"}); + + ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, + GroupByTest( + { + table->GetColumnByName("argument1"), + table->GetColumnByName("argument1"), + table->GetColumnByName("argument1"), + table->GetColumnByName("argument2"), + table->GetColumnByName("argument2"), + table->GetColumnByName("argument2"), + }, + {table->GetColumnByName("key")}, + { + {"hash_min", nullptr}, + {"hash_max", nullptr}, + {"hash_min_max", nullptr}, + {"hash_min", nullptr}, + {"hash_max", nullptr}, + {"hash_min_max", nullptr}, + }, + use_threads)); + ValidateOutput(aggregated_and_grouped); + SortBy({"key_0"}, &aggregated_and_grouped); + + AssertDatumsEqual(ArrayFromJSON(struct_({ + field("key_0", int64()), + field("hash_min", float32()), + field("hash_max", float32()), + field("hash_min_max", struct_({ + field("min", float32()), + field("max", float32()), + })), + field("hash_min", float64()), + field("hash_max", float64()), + field("hash_min_max", struct_({ + field("min", float64()), + field("max", float64()), + })), + }), + R"([ + [1, NaN, NaN, {"min": NaN, "max": NaN}, NaN, NaN, {"min": NaN, "max": NaN}], + [2, -Inf, -Inf, {"min": -Inf, "max": -Inf}, -Inf, -Inf, {"min": -Inf, "max": -Inf}], + [3, Inf, Inf, {"min": Inf, "max": Inf}, Inf, Inf, {"min": Inf, "max": Inf}] + ])"), + aggregated_and_grouped, + /*verbose=*/true); + } +} + TEST_P(GroupBy, AnyAndAll) { for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc index 7f2bce4063d..3733f415a04 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc @@ -694,8 +694,8 @@ struct MinMaxState> { this->max = std::fmax(this->max, value); } - T min = std::numeric_limits::infinity(); - T max = -std::numeric_limits::infinity(); + T min = std::numeric_limits::quiet_NaN(); + T max = std::numeric_limits::quiet_NaN(); bool has_nulls = false; }; diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index a2bf0b97fd1..cdc62f946a9 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -1841,6 +1841,24 @@ class TestPrimitiveMinMaxKernel : public ::testing::Test { AssertMinMaxIsNull(array, options); } + void AssertMinMaxIsNaN(const Datum& array, const ScalarAggregateOptions& options) { + ASSERT_OK_AND_ASSIGN(Datum out, MinMax(array, options)); + for (const auto& val : out.scalar_as().value) { + ASSERT_TRUE(std::isnan(checked_cast(*val).value)); + } + } + + void AssertMinMaxIsNaN(const std::string& json, const ScalarAggregateOptions& options) { + auto array = ArrayFromJSON(type_singleton(), json); + AssertMinMaxIsNaN(array, options); + } + + void AssertMinMaxIsNaN(const std::vector& json, + const ScalarAggregateOptions& options) { + auto array = ChunkedArrayFromJSON(type_singleton(), json); + AssertMinMaxIsNaN(array, options); + } + std::shared_ptr type_singleton() { return default_type_instance(); } @@ -1963,6 +1981,9 @@ TYPED_TEST(TestFloatingMinMaxKernel, Floats) { this->AssertMinMaxIs("[5, Inf, 2, 3, 4]", 2.0, INFINITY, options); this->AssertMinMaxIs("[5, NaN, 2, 3, 4]", 2, 5, options); this->AssertMinMaxIs("[5, -Inf, 2, 3, 4]", -INFINITY, 5, options); + this->AssertMinMaxIs("[NaN, null, 42]", 42, 42, options); + this->AssertMinMaxIsNaN("[NaN, NaN]", options); + this->AssertMinMaxIsNaN("[NaN, null]", options); this->AssertMinMaxIs(chunked_input1, 1, 9, options); this->AssertMinMaxIs(chunked_input2, 1, 9, options); this->AssertMinMaxIs(chunked_input3, 1, 9, options); @@ -1980,6 +2001,7 @@ TYPED_TEST(TestFloatingMinMaxKernel, Floats) { this->AssertMinMaxIs("[5, -Inf, 2, 3, 4]", -INFINITY, 5, options); this->AssertMinMaxIsNull("[5, null, 2, 3, 4]", options); this->AssertMinMaxIsNull("[5, -Inf, null, 3, 4]", options); + this->AssertMinMaxIsNull("[NaN, null]", options); this->AssertMinMaxIsNull(chunked_input1, options); this->AssertMinMaxIsNull(chunked_input2, options); this->AssertMinMaxIsNull(chunked_input3, options); diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 19f7fc2e5b0..2ab5e574e22 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include #include @@ -270,52 +271,53 @@ struct GroupedCountImpl : public GroupedAggregator { // ---------------------------------------------------------------------- // MinMax implementation +// XXX: Consider making these concepts complete and moving to public header. + +template +concept CBooleanConcept = std::same_as; + +// XXX: Ideally we want to have std::floating_point = true. +template +concept CFloatingPointConcept = std::floating_point || std::same_as; + +template +concept CDecimalConcept = std::same_as || std::same_as || + std::same_as || std::same_as; + template struct AntiExtrema { static constexpr CType anti_min() { return std::numeric_limits::max(); } static constexpr CType anti_max() { return std::numeric_limits::min(); } }; -template <> -struct AntiExtrema { - static constexpr bool anti_min() { return true; } - static constexpr bool anti_max() { return false; } -}; - -template <> -struct AntiExtrema { - static constexpr float anti_min() { return std::numeric_limits::infinity(); } - static constexpr float anti_max() { return -std::numeric_limits::infinity(); } +template +struct AntiExtrema { + static constexpr CType anti_min() { return true; } + static constexpr CType anti_max() { return false; } }; -template <> -struct AntiExtrema { - static constexpr double anti_min() { return std::numeric_limits::infinity(); } - static constexpr double anti_max() { return -std::numeric_limits::infinity(); } +template +struct AntiExtrema { + static constexpr CType anti_min() { return std::numeric_limits::quiet_NaN(); } + static constexpr CType anti_max() { return std::numeric_limits::quiet_NaN(); } }; -template <> -struct AntiExtrema { - static constexpr Decimal32 anti_min() { return BasicDecimal32::GetMaxSentinel(); } - static constexpr Decimal32 anti_max() { return BasicDecimal32::GetMinSentinel(); } +template +struct AntiExtrema { + static constexpr CType anti_min() { return CType::GetMaxSentinel(); } + static constexpr CType anti_max() { return CType::GetMinSentinel(); } }; -template <> -struct AntiExtrema { - static constexpr Decimal64 anti_min() { return BasicDecimal64::GetMaxSentinel(); } - static constexpr Decimal64 anti_max() { return BasicDecimal64::GetMinSentinel(); } -}; - -template <> -struct AntiExtrema { - static constexpr Decimal128 anti_min() { return BasicDecimal128::GetMaxSentinel(); } - static constexpr Decimal128 anti_max() { return BasicDecimal128::GetMinSentinel(); } +template +struct MinMaxOp { + static constexpr CType min(CType a, CType b) { return std::min(a, b); } + static constexpr CType max(CType a, CType b) { return std::max(a, b); } }; -template <> -struct AntiExtrema { - static constexpr Decimal256 anti_min() { return BasicDecimal256::GetMaxSentinel(); } - static constexpr Decimal256 anti_max() { return BasicDecimal256::GetMinSentinel(); } +template +struct MinMaxOp { + static constexpr CType min(CType a, CType b) { return std::fmin(a, b); } + static constexpr CType max(CType a, CType b) { return std::fmax(a, b); } }; template @@ -352,8 +354,8 @@ struct GroupedMinMaxImpl final : public GroupedAggregator { VisitGroupedValues( batch, [&](uint32_t g, CType val) { - GetSet::Set(raw_mins, g, std::min(GetSet::Get(raw_mins, g), val)); - GetSet::Set(raw_maxes, g, std::max(GetSet::Get(raw_maxes, g), val)); + GetSet::Set(raw_mins, g, MinMaxOp::min(GetSet::Get(raw_mins, g), val)); + GetSet::Set(raw_maxes, g, MinMaxOp::max(GetSet::Get(raw_maxes, g), val)); bit_util::SetBit(has_values_.mutable_data(), g); }, [&](uint32_t g) { bit_util::SetBit(has_nulls_.mutable_data(), g); }); @@ -373,12 +375,12 @@ struct GroupedMinMaxImpl final : public GroupedAggregator { auto g = group_id_mapping.GetValues(1); for (uint32_t other_g = 0; static_cast(other_g) < group_id_mapping.length; ++other_g, ++g) { - GetSet::Set( - raw_mins, *g, - std::min(GetSet::Get(raw_mins, *g), GetSet::Get(other_raw_mins, other_g))); - GetSet::Set( - raw_maxes, *g, - std::max(GetSet::Get(raw_maxes, *g), GetSet::Get(other_raw_maxes, other_g))); + GetSet::Set(raw_mins, *g, + MinMaxOp::min(GetSet::Get(raw_mins, *g), + GetSet::Get(other_raw_mins, other_g))); + GetSet::Set(raw_maxes, *g, + MinMaxOp::max(GetSet::Get(raw_maxes, *g), + GetSet::Get(other_raw_maxes, other_g))); if (bit_util::GetBit(other->has_values_.data(), other_g)) { bit_util::SetBit(has_values_.mutable_data(), *g);