diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc
index dfeae61a6cd..871afcae1fa 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -1361,46 +1361,87 @@ void SortBy(std::vector<std::string> names, Datum* aggregated_and_grouped) {
 }  // namespace
 
 TEST_P(GroupBy, CountOnly) {
-  for (bool use_threads : {true, false}) {
-    SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
-
-    auto table =
-        TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
-    [1.0, 1],
-    [null, 1]
-                        ])",
-                        R"([
-    [0.0, 2],
-    [null, 3],
-    [4.0, null],
-    [3.25, 1],
-    [0.125, 2]
-                        ])",
-                        R"([
-    [-0.25, 2],
-    [0.75, null],
-    [null, 3]
-                        ])"});
+  const std::vector<std::string> json = {
+      // Test inputs ("argument", "key")
+      R"([[1.0, 1],
+          [null, 1]])",
+      R"([[0.0, 2],
+          [null, 3],
+          [null, 2],
+          [4.0, null],
+          [3.25, 1],
+          [3.25, 1],
+          [0.125, 2]])",
+      R"([[-0.25, 2],
+          [0.75, null],
+          [null, 3]])",
+  };
+  const auto skip_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
+  const auto only_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
+  const auto count_all = std::make_shared<CountOptions>(CountOptions::ALL);
+  const auto possible_count_options = std::vector<std::shared_ptr<CountOptions>>{
+      nullptr,  // default = skip_nulls
+      skip_nulls,
+      only_nulls,
+      count_all,
+  };
+  const auto expected_results = std::vector<std::string>{
+      // Results ("key_0", "hash_count")
+      // nullptr = skip_nulls
+      R"([[1, 3],
+          [2, 3],
+          [3, 0],
+          [null, 2]])",
+      // skip_nulls
+      R"([[1, 3],
+          [2, 3],
+          [3, 0],
+          [null, 2]])",
+      // only_nulls
+      R"([[1, 1],
+          [2, 1],
+          [3, 2],
+          [null, 0]])",
+      // count_all
+      R"([[1, 4],
+          [2, 4],
+          [3, 2],
+          [null, 2]])",
+  };
+  // NOTE: the "key" column (1) does not appear in the possible run-end
+  // encoding transformations because GroupBy kernels do not support run-end
+  // encoded key arrays.
+  for (const auto& re_encode_cols : std::vector<std::vector<int>>{{}, {0}}) {
+    for (bool use_threads : {/*true, */ false}) {
+      SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
+      for (size_t i = 0; i < possible_count_options.size(); i++) {
+        SCOPED_TRACE(possible_count_options[i] ? possible_count_options[i]->ToString()
+                                               : "default");
+        auto table = TableFromJSON(
+            schema({field("argument", float64()), field("key", int64())}), json);
+
+        auto transformed_table = table;
+        if (!re_encode_cols.empty()) {
+          ASSERT_OK_AND_ASSIGN(transformed_table,
+                               RunEndEncodeTableColumns(*table, re_encode_cols));
+        }
 
-    ASSERT_OK_AND_ASSIGN(
-        Datum aggregated_and_grouped,
-        GroupByTest({table->GetColumnByName("argument")}, {table->GetColumnByName("key")},
-                    {
-                        {"hash_count", nullptr},
-                    },
-                    use_threads));
-    SortBy({"key_0"}, &aggregated_and_grouped);
+        ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
+                             GroupByTest({transformed_table->GetColumnByName("argument")},
+                                         {transformed_table->GetColumnByName("key")},
+                                         {
+                                             {"hash_count", possible_count_options[i]},
+                                         },
+                                         use_threads));
+        SortBy({"key_0"}, &aggregated_and_grouped);
 
-    AssertDatumsEqual(
-        ArrayFromJSON(struct_({field("key_0", int64()), field("hash_count", int64())}),
-                      R"([
-    [1, 2],
-    [2, 3],
-    [3, 0],
-    [null, 2]
-  ])"),
-        aggregated_and_grouped,
-        /*verbose=*/true);
+        AssertDatumsEqual(aggregated_and_grouped,
+                          ArrayFromJSON(struct_({field("key_0", int64()),
+                                                 field("hash_count", int64())}),
+                                        expected_results[i]),
+                          /*verbose=*/true);
+      }
+    }
   }
 }
 
diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc
index 18c9eb720a0..8764e9c354c 100644
--- a/cpp/src/arrow/array/data.cc
+++ b/cpp/src/arrow/array/data.cc
@@ -203,7 +203,7 @@ void ArraySpan::SetMembers(const ArrayData& data) {
     type_id = ext_type->storage_type()->id();
   }
 
-  if (data.buffers[0] == nullptr && type_id != Type::NA &&
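+  // Some layouts (e.g. run-end encoded arrays) define no top-level buffers, so
+  // the buffers vector itself may be empty rather than holding a validity slot.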
+  if ((data.buffers.size() == 0 || data.buffers[0] == nullptr) && type_id != Type::NA &&
       type_id != Type::SPARSE_UNION && type_id != Type::DENSE_UNION) {
     // This should already be zero but we make for sure
     this->null_count = 0;
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 380dde016ef..f7d7dc78028 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -45,6 +45,7 @@
 #include "arrow/util/cpu_info.h"
 #include "arrow/util/int128_internal.h"
 #include "arrow/util/int_util_overflow.h"
+#include "arrow/util/ree_util.h"
 #include "arrow/util/task_group.h"
 #include "arrow/util/tdigest.h"
 #include "arrow/util/thread_pool.h"
@@ -305,6 +306,46 @@ struct GroupedCountImpl : public GroupedAggregator {
     return Status::OK();
   }
 
+  template <bool count_valid>
+  struct RunEndEncodedCountImpl {
+    /// Count the number of valid or invalid values in a run-end-encoded array.
+    ///
+    /// \param[in] input the run-end-encoded array
+    /// \param[out] counts the counts being accumulated
+    /// \param[in] g the group ids of the values in the array
+    template <typename RunEndCType>
+    void DoCount(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
+      ree_util::RunEndEncodedArraySpan<RunEndCType> ree_span(input);
+      const auto* physical_validity = ree_util::ValuesArray(input).GetValues<uint8_t>(0);
+      auto end = ree_span.end();
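+      // Validity is stored on the physical values array, so it is checked once
+      // per run; every logical position of a matching run is then credited to
+      // that position's group.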
+      for (auto it = ree_span.begin(); it != end; ++it) {
+        const bool is_valid = bit_util::GetBit(physical_validity, it.index_into_array());
+        if (is_valid == count_valid) {
+          for (int64_t i = 0; i < it.run_length(); ++i, ++g) {
+            counts[*g] += 1;
+          }
+        } else {
+          g += it.run_length();
+        }
+      }
+    }
+
+    void operator()(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
+      auto ree_type = checked_cast<const RunEndEncodedType*>(input.type);
+      switch (ree_type->run_end_type()->id()) {
+        case Type::INT16:
+          DoCount<int16_t>(input, counts, g);
+          break;
+        case Type::INT32:
+          DoCount<int32_t>(input, counts, g);
+          break;
+        default:
+          DoCount<int64_t>(input, counts, g);
+          break;
+      }
+    }
+  };
+
   Status Consume(const ExecSpan& batch) override {
     auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
     auto g_begin = batch[1].array.GetValues<uint32_t>(1);
@@ -315,26 +356,61 @@ struct GroupedCountImpl : public GroupedAggregator {
       }
     } else if (batch[0].is_array()) {
       const ArraySpan& input = batch[0].array;
-      if (options_.mode == CountOptions::ONLY_VALID) {
+      if (options_.mode == CountOptions::ONLY_VALID) {  // ONLY_VALID
         if (input.type->id() != arrow::Type::NA) {
-          arrow::internal::VisitSetBitRunsVoid(
-              input.buffers[0].data, input.offset, input.length,
-              [&](int64_t offset, int64_t length) {
-                auto g = g_begin + offset;
-                for (int64_t i = 0; i < length; ++i, ++g) {
-                  counts[*g] += 1;
-                }
-              });
+          const uint8_t* bitmap = input.buffers[0].data;
+          if (bitmap) {
+            arrow::internal::VisitSetBitRunsVoid(
+                bitmap, input.offset, input.length, [&](int64_t offset, int64_t length) {
+                  auto g = g_begin + offset;
+                  for (int64_t i = 0; i < length; ++i, ++g) {
+                    counts[*g] += 1;
+                  }
+                });
+          } else {
+            // Arrays without a validity bitmap require special handling of nulls.
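+            // For run-end encoded input, MayHaveLogicalNulls() consults the
+            // validity of the child values array.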
+            const bool all_valid = !input.MayHaveLogicalNulls();
+            if (all_valid) {
+              for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+                counts[*g_begin] += 1;
+              }
+            } else {
+              switch (input.type->id()) {
+                case Type::RUN_END_ENCODED:
+                  RunEndEncodedCountImpl<true>{}(input, counts, g_begin);
+                  break;
+                default:  // Generic and forward-compatible version.
+                  for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+                    counts[*g_begin] += input.IsValid(i);
+                  }
+                  break;
+              }
+            }
+          }
         }
       } else {  // ONLY_NULL
         if (input.type->id() == arrow::Type::NA) {
           for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
             counts[*g_begin] += 1;
           }
-        } else if (input.MayHaveNulls()) {
-          auto end = input.offset + input.length;
-          for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
-            counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
+        } else if (input.MayHaveLogicalNulls()) {
+          if (input.HasValidityBitmap()) {
+            auto end = input.offset + input.length;
+            for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
+              counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
+            }
+          } else {
+            // Arrays without validity bitmaps require special handling of nulls.
+            switch (input.type->id()) {
+              case Type::RUN_END_ENCODED:
+                RunEndEncodedCountImpl<false>{}(input, counts, g_begin);
+                break;
+              default:  // Generic and forward-compatible version.
+                for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
+                  counts[*g_begin] += input.IsNull(i);
+                }
+                break;
+            }
           }
         }
       }
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 37c430892d0..9569375bda9 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -47,6 +47,7 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/compute/api_vector.h"
 #include "arrow/datum.h"
 #include "arrow/ipc/json_simple.h"
 #include "arrow/pretty_print.h"
@@ -427,6 +428,24 @@ std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
   return *Table::FromRecordBatches(schema, std::move(batches));
 }
 
+Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
+    const Table& table, const std::vector<int>& column_indices) {
+  const int num_columns = table.num_columns();
+  std::vector<std::shared_ptr<ChunkedArray>> encoded_columns;
+  encoded_columns.reserve(num_columns);
+  for (int i = 0; i < num_columns; i++) {
+    if (std::find(column_indices.begin(), column_indices.end(), i) !=
+        column_indices.end()) {
+      ARROW_ASSIGN_OR_RAISE(auto run_end_encoded, compute::RunEndEncode(table.column(i)));
+      DCHECK_EQ(run_end_encoded.kind(), Datum::CHUNKED_ARRAY);
+      encoded_columns.push_back(run_end_encoded.chunked_array());
+    } else {
+      encoded_columns.push_back(table.column(i));
+    }
+  }
+  return Table::Make(table.schema(), std::move(encoded_columns));
+}
+
 Result<std::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
                                                   const ChunkedArray& actual) {
   if (actual.Equals(expected)) {
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 27080562952..55bd307b12c 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -341,6 +341,10 @@ ARROW_TESTING_EXPORT std::shared_ptr<Table>
 TableFromJSON(const std::shared_ptr<Schema>&, const std::vector<std::string>& json);
 
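+// Run-end encode the table columns at the given indices; all other columns are
+// carried over to the result unchanged.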
+ARROW_TESTING_EXPORT
+Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
+    const Table& table, const std::vector<int>& column_indices);
+
 // Given an array, return a new identical array except for one validity bit
 // set to a new value.
 // This is useful to force the underlying "value" of null entries to otherwise