117 changes: 79 additions & 38 deletions cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -1361,46 +1361,87 @@ void SortBy(std::vector<std::string> names, Datum* aggregated_and_grouped) {
} // namespace

TEST_P(GroupBy, CountOnly) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");

auto table =
TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([
[1.0, 1],
[null, 1]
])",
R"([
[0.0, 2],
[null, 3],
[4.0, null],
[3.25, 1],
[0.125, 2]
])",
R"([
[-0.25, 2],
[0.75, null],
[null, 3]
])"});
const std::vector<std::string> json = {
// Test inputs ("argument", "key")
R"([[1.0, 1],
[null, 1]])",
R"([[0.0, 2],
[null, 3],
[null, 2],
[4.0, null],
[3.25, 1],
[3.25, 1],
[0.125, 2]])",
R"([[-0.25, 2],
[0.75, null],
[null, 3]])",
};
const auto skip_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_VALID);
const auto only_nulls = std::make_shared<CountOptions>(CountOptions::ONLY_NULL);
const auto count_all = std::make_shared<CountOptions>(CountOptions::ALL);
const auto possible_count_options = std::vector<std::shared_ptr<CountOptions>>{
nullptr, // default = skip_nulls
skip_nulls,
only_nulls,
count_all,
};
const auto expected_results = std::vector<std::string>{
// Results ("key_0", "hash_count")
// nullptr = skip_nulls
R"([[1, 3],
[2, 3],
[3, 0],
[null, 2]])",
// skip_nulls
R"([[1, 3],
[2, 3],
[3, 0],
[null, 2]])",
// only_nulls
R"([[1, 1],
[2, 1],
[3, 2],
[null, 0]])",
// count_all
R"([[1, 4],
[2, 4],
[3, 2],
[null, 2]])",
};
// NOTE: the "key" column (1) does not appear in the possible run-end
// encoding transformations because GroupBy kernels do not support run-end
// encoded key arrays.
for (const auto& re_encode_cols : std::vector<std::vector<int>>{{}, {0}}) {
for (bool use_threads : {true, false}) {
SCOPED_TRACE(use_threads ? "parallel/merged" : "serial");
for (size_t i = 0; i < possible_count_options.size(); i++) {
SCOPED_TRACE(possible_count_options[i] ? possible_count_options[i]->ToString()
: "default");
auto table = TableFromJSON(
schema({field("argument", float64()), field("key", int64())}), json);

auto transformed_table = table;
if (!re_encode_cols.empty()) {
ASSERT_OK_AND_ASSIGN(transformed_table,
RunEndEncodeTableColumns(*table, re_encode_cols));
}

ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
GroupByTest({table->GetColumnByName("argument")}, {table->GetColumnByName("key")},
{
{"hash_count", nullptr},
},
use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
GroupByTest({transformed_table->GetColumnByName("argument")},
{transformed_table->GetColumnByName("key")},
{
{"hash_count", possible_count_options[i]},
},
use_threads));
SortBy({"key_0"}, &aggregated_and_grouped);

AssertDatumsEqual(
ArrayFromJSON(struct_({field("key_0", int64()), field("hash_count", int64())}),
R"([
[1, 2],
[2, 3],
[3, 0],
[null, 2]
])"),
aggregated_and_grouped,
/*verbose=*/true);
AssertDatumsEqual(aggregated_and_grouped,
ArrayFromJSON(struct_({field("key_0", int64()),
field("hash_count", int64())}),
expected_results[i]),
/*verbose=*/true);
}
}
}
}

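As a sanity check on the `expected_results` above: the rows with key 1 contribute the values [1.0, null, 3.25, 3.25], so ONLY_VALID counts 3, ONLY_NULL counts 1, and ALL counts 4 — exactly the `[1, 3]`, `[1, 1]`, and `[1, 4]` entries. A minimal standalone sketch of that arithmetic (plain C++ with `std::optional` standing in for a nullable cell; the `CountMode` names are illustrative, not Arrow API):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

enum class CountMode { kOnlyValid, kOnlyNull, kAll };

// Count the cells of one group under each CountOptions-like mode.
int64_t Count(const std::vector<std::optional<double>>& cells, CountMode mode) {
  int64_t n = 0;
  for (const auto& cell : cells) {
    switch (mode) {
      case CountMode::kOnlyValid: n += cell.has_value(); break;
      case CountMode::kOnlyNull:  n += !cell.has_value(); break;
      case CountMode::kAll:       n += 1; break;
    }
  }
  return n;
}

int main() {
  // The four cells with key == 1 across the three JSON batches above.
  const std::vector<std::optional<double>> key1 = {1.0, std::nullopt, 3.25, 3.25};
  std::cout << Count(key1, CountMode::kOnlyValid) << "\n";  // 3
  std::cout << Count(key1, CountMode::kOnlyNull) << "\n";   // 1
  std::cout << Count(key1, CountMode::kAll) << "\n";        // 4
}
```
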
2 changes: 1 addition & 1 deletion cpp/src/arrow/array/data.cc
@@ -203,7 +203,7 @@ void ArraySpan::SetMembers(const ArrayData& data) {
type_id = ext_type->storage_type()->id();
}

if (data.buffers[0] == nullptr && type_id != Type::NA &&
if ((data.buffers.size() == 0 || data.buffers[0] == nullptr) && type_id != Type::NA &&
type_id != Type::SPARSE_UNION && type_id != Type::DENSE_UNION) {
// This should already be zero but we make sure
this->null_count = 0;
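For context on this one-line guard: a run-end-encoded array stores its data in two child arrays (run ends and values) and owns no top-level buffers, so `data.buffers` can legitimately be empty rather than holding a null validity pointer, and indexing `data.buffers[0]` unconditionally would read past the end of the vector. A hedged standalone sketch of the predicate the new condition implements (not Arrow code):

```cpp
#include <cstdint>
#include <vector>

// Mirrors the guard above: an array has no validity bitmap either when the
// buffers vector has no slot for one (e.g. a run-end-encoded array) or when
// the slot exists but is null (an array with no nulls).
bool HasNoValidityBitmap(const std::vector<const uint8_t*>& buffers) {
  return buffers.empty() || buffers[0] == nullptr;
}
```
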
102 changes: 89 additions & 13 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -45,6 +45,7 @@
#include "arrow/util/cpu_info.h"
#include "arrow/util/int128_internal.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/ree_util.h"
#include "arrow/util/task_group.h"
#include "arrow/util/tdigest.h"
#include "arrow/util/thread_pool.h"
@@ -305,6 +306,46 @@ struct GroupedCountImpl : public GroupedAggregator {
return Status::OK();
}

template <bool count_valid>
struct RunEndEncodedCountImpl {
/// Count the number of valid or invalid values in a run-end-encoded array.
///
/// \param[in] input the run-end-encoded array
/// \param[out] counts the counts being accumulated
/// \param[in] g the group ids of the values in the array
template <typename RunEndCType>
void DoCount(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
ree_util::RunEndEncodedArraySpan<RunEndCType> ree_span(input);
const auto* physical_validity = ree_util::ValuesArray(input).GetValues<uint8_t>(0);
auto end = ree_span.end();
for (auto it = ree_span.begin(); it != end; ++it) {
const bool is_valid = bit_util::GetBit(physical_validity, it.index_into_array());
if (is_valid == count_valid) {
for (int64_t i = 0; i < it.run_length(); ++i, ++g) {
counts[*g] += 1;
}
} else {
g += it.run_length();
}
}
}

void operator()(const ArraySpan& input, int64_t* counts, const uint32_t* g) {
auto ree_type = checked_cast<const RunEndEncodedType*>(input.type);
switch (ree_type->run_end_type()->id()) {
case Type::INT16:
DoCount<int16_t>(input, counts, g);
break;
case Type::INT32:
DoCount<int32_t>(input, counts, g);
break;
default:
DoCount<int64_t>(input, counts, g);
break;
}
}
};

Status Consume(const ExecSpan& batch) override {
auto counts = reinterpret_cast<int64_t*>(counts_.mutable_data());
auto g_begin = batch[1].array.GetValues<uint32_t>(1);
@@ -315,26 +356,61 @@
}
} else if (batch[0].is_array()) {
const ArraySpan& input = batch[0].array;
if (options_.mode == CountOptions::ONLY_VALID) {
if (options_.mode == CountOptions::ONLY_VALID) { // ONLY_VALID
if (input.type->id() != arrow::Type::NA) {
arrow::internal::VisitSetBitRunsVoid(
input.buffers[0].data, input.offset, input.length,
[&](int64_t offset, int64_t length) {
auto g = g_begin + offset;
for (int64_t i = 0; i < length; ++i, ++g) {
counts[*g] += 1;
}
});
const uint8_t* bitmap = input.buffers[0].data;
if (bitmap) {
arrow::internal::VisitSetBitRunsVoid(
bitmap, input.offset, input.length, [&](int64_t offset, int64_t length) {
auto g = g_begin + offset;
for (int64_t i = 0; i < length; ++i, ++g) {
counts[*g] += 1;
}
});
} else {
// Arrays without validity bitmaps require special handling of nulls.
const bool all_valid = !input.MayHaveLogicalNulls();
if (all_valid) {
for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
counts[*g_begin] += 1;
}
} else {
switch (input.type->id()) {
case Type::RUN_END_ENCODED:
RunEndEncodedCountImpl<true>{}(input, counts, g_begin);
break;
default: // Generic and forward-compatible version.
for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
counts[*g_begin] += input.IsValid(i);
}
break;
}
}
}
}
} else { // ONLY_NULL
if (input.type->id() == arrow::Type::NA) {
for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
counts[*g_begin] += 1;
}
} else if (input.MayHaveNulls()) {
auto end = input.offset + input.length;
for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
} else if (input.MayHaveLogicalNulls()) {
if (input.HasValidityBitmap()) {
auto end = input.offset + input.length;
for (int64_t i = input.offset; i < end; ++i, ++g_begin) {
counts[*g_begin] += !bit_util::GetBit(input.buffers[0].data, i);
}
} else {
// Arrays without validity bitmaps require special handling of nulls.
switch (input.type->id()) {
case Type::RUN_END_ENCODED:
RunEndEncodedCountImpl<false>{}(input, counts, g_begin);
break;
default: // Generic and forward-compatible version.
for (int64_t i = 0; i < input.length; ++i, ++g_begin) {
counts[*g_begin] += input.IsNull(i);
}
break;
}
}
}
}
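The core idea of `RunEndEncodedCountImpl` is that validity is decided once per physical run and then applied to every logical row the run covers, advancing the group-id cursor by the run length either way. A standalone sketch of that traversal, with plain vectors standing in for Arrow's run-end and validity buffers (illustrative names, not Arrow API):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Count valid (or null) logical values per group in a run-end-encoded column.
// run_ends[k] is the exclusive logical end of physical run k; valid[k] says
// whether that run's value is non-null; g holds one group id per logical row.
template <bool count_valid>
void CountRuns(const std::vector<int32_t>& run_ends, const std::vector<bool>& valid,
               const std::vector<uint32_t>& g, std::vector<int64_t>* counts) {
  int32_t run_start = 0;
  for (size_t k = 0; k < run_ends.size(); ++k) {
    const int32_t run_end = run_ends[k];
    if (valid[k] == count_valid) {
      // Validity was decided once for the whole run; bump every covered group.
      for (int32_t i = run_start; i < run_end; ++i) (*counts)[g[i]] += 1;
    }  // else: skip the run, advancing past run_end - run_start rows at once
    run_start = run_end;
  }
}

int main() {
  // Logical column [1.0, 1.0, null, null, null, 2.0] as three physical runs.
  const std::vector<int32_t> run_ends = {2, 5, 6};
  const std::vector<bool> valid = {true, false, true};
  const std::vector<uint32_t> g = {0, 1, 1, 0, 0, 1};  // group id per logical row
  std::vector<int64_t> counts(2, 0);
  CountRuns</*count_valid=*/true>(run_ends, valid, g, &counts);
  std::cout << counts[0] << " " << counts[1] << "\n";  // prints "1 2"
}
```
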
19 changes: 19 additions & 0 deletions cpp/src/arrow/testing/gtest_util.cc
@@ -47,6 +47,7 @@

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/compute/api_vector.h"
#include "arrow/datum.h"
#include "arrow/ipc/json_simple.h"
#include "arrow/pretty_print.h"
@@ -427,6 +428,24 @@ std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
return *Table::FromRecordBatches(schema, std::move(batches));
}

Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
const Table& table, const std::vector<int>& column_indices) {
const int num_columns = table.num_columns();
std::vector<std::shared_ptr<ChunkedArray>> encoded_columns;
encoded_columns.reserve(num_columns);
for (int i = 0; i < num_columns; i++) {
if (std::find(column_indices.begin(), column_indices.end(), i) !=
column_indices.end()) {
ARROW_ASSIGN_OR_RAISE(auto run_end_encoded, compute::RunEndEncode(table.column(i)));
DCHECK_EQ(run_end_encoded.kind(), Datum::CHUNKED_ARRAY);
encoded_columns.push_back(run_end_encoded.chunked_array());
} else {
encoded_columns.push_back(table.column(i));
}
}
return Table::Make(table.schema(), std::move(encoded_columns));
}

Result<std::optional<std::string>> PrintArrayDiff(const ChunkedArray& expected,
const ChunkedArray& actual) {
if (actual.Equals(expected)) {
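A hedged usage sketch of the new helper as it might appear in a test: columns listed in the second argument are run-end-encoded via `compute::RunEndEncode` (whose default run-end type is int32), everything else passes through unchanged. The snippet below is illustrative and not taken from this PR's tests:

```cpp
#include "arrow/api.h"
#include "arrow/testing/gtest_util.h"

arrow::Status Demo() {
  auto table = arrow::TableFromJSON(
      arrow::schema({arrow::field("argument", arrow::float64()),
                     arrow::field("key", arrow::int64())}),
      {R"([[1.0, 1], [1.0, 1], [null, 2]])"});
  // Run-end encode only column 0 ("argument"); column 1 ("key") is unchanged.
  ARROW_ASSIGN_OR_RAISE(auto encoded,
                        arrow::RunEndEncodeTableColumns(*table, /*column_indices=*/{0}));
  // encoded->column(0) now holds chunks of type run_end_encoded(int32, float64).
  return arrow::Status::OK();
}
```
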
4 changes: 4 additions & 0 deletions cpp/src/arrow/testing/gtest_util.h
@@ -341,6 +341,10 @@ ARROW_TESTING_EXPORT
std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>&,
const std::vector<std::string>& json);

ARROW_TESTING_EXPORT
Result<std::shared_ptr<Table>> RunEndEncodeTableColumns(
const Table& table, const std::vector<int>& column_indices);

// Given an array, return a new identical array except for one validity bit
// set to a new value.
// This is useful to force the underlying "value" of null entries to otherwise