Skip to content
17 changes: 12 additions & 5 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,27 +136,34 @@ struct CountDistinctImpl : public ScalarAggregator {
Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
const ArrayData& arr = *batch[0].array();
this->has_nulls = arr.GetNullCount() > 0;

auto visit_null = []() { return Status::OK(); };
auto visit_value = [&](VisitorArgType arg) {
int y;
int32_t y;
return memo_table_->GetOrInsert(arg, &y);
};
RETURN_NOT_OK(VisitArraySpanInline<Type>(arr, visit_value, visit_null));
this->non_nulls += memo_table_->size();
this->has_nulls = arr.GetNullCount() > 0;

} else {
const Scalar& input = *batch[0].scalar();
this->has_nulls = !input.is_valid;

if (input.is_valid) {
this->non_nulls += batch.length;
int32_t unused;
RETURN_NOT_OK(memo_table_->GetOrInsert(UnboxScalar<Type>::Unbox(input), &unused));
}
}

this->non_nulls = memo_table_->size();

return Status::OK();
}

Status MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other_state = checked_cast<const CountDistinctImpl&>(src);
this->non_nulls += other_state.non_nulls;
RETURN_NOT_OK(this->memo_table_->MergeTable(*(other_state.memo_table_)));
this->non_nulls = this->memo_table_->size();
this->has_nulls = this->has_nulls || other_state.has_nulls;
return Status::OK();
}
Expand Down
72 changes: 72 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,11 +962,83 @@ class TestCountDistinctKernel : public ::testing::Test {
EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one);
}

void CheckChunkedArr(const std::shared_ptr<DataType>& type,
const std::vector<std::string>& json, int64_t expected_all,
bool has_nulls = true) {
Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls);
}

CountOptions only_valid{CountOptions::ONLY_VALID};
CountOptions only_null{CountOptions::ONLY_NULL};
CountOptions all{CountOptions::ALL};
};

TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) {
// Boolean
CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false);
CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3);

// Number
for (auto ty : NumericTypes()) {
CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6, 8, 9, 10]"},
8);
CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7,
/*has_nulls=*/false);
}

// Date
CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"}, 4);
CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"}, 3);

// Time
CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14, null]"}, 4);
CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000, 11000]"}, 3);

CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null, 84203999999]", "[0]"},
3);
CheckChunkedArr(time64(TimeUnit::NANO),
{"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3);

// Timestamp & Duration
for (auto u : TimeUnit::values()) {
CheckChunkedArr(duration(u), {"[123456789, null, 987654321]", "[123456789, null]"},
3);

CheckChunkedArr(duration(u),
{"[123456789, 987654321, 123456789, 123456789]", "[123456789]"}, 2,
/*has_nulls=*/false);

auto ts =
std::vector<std::string>{R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])",
R"(["2020-01-01", null])", R"(["2020-01-01", null])"};
CheckChunkedArr(timestamp(u), ts, 3);
CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3);
}

// Interval
CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null, 9012]"},
3);
CheckChunkedArr(day_time_interval(),
{"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3);
CheckChunkedArr(month_day_nano_interval(),
{"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2);

// Binary & String & Fixed binary
auto samples = std::vector<std::string>{
R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba", null])"};

CheckChunkedArr(binary(), samples, 4);
CheckChunkedArr(large_binary(), samples, 4);
CheckChunkedArr(utf8(), samples, 4);
CheckChunkedArr(large_utf8(), samples, 4);
CheckChunkedArr(fixed_size_binary(3), samples, 4);

// Decimal
samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"};
CheckChunkedArr(decimal128(21, 3), samples, 3);
CheckChunkedArr(decimal256(13, 3), samples, 3);
}

TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) {
// Boolean
Check(boolean(), "[]", 0, /*has_nulls=*/false);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ struct UnboxScalar<Type, enable_if_has_string_view<Type>> {
using T = util::string_view;
static T Unbox(const Scalar& val) {
if (!val.is_valid) return util::string_view();
return util::string_view(*checked_cast<const BaseBinaryScalar&>(val).value);
return checked_cast<const ::arrow::internal::PrimitiveScalarBase&>(val).view();
}
};

Expand Down
32 changes: 32 additions & 0 deletions cpp/src/arrow/util/hashing.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,20 @@ class ScalarMemoTable : public MemoTable {
hash_t ComputeHash(const Scalar& value) const {
return ScalarHelper<Scalar, 0>::ComputeHash(value);
}

public:
// defined here so that `HashTableType` is visible
// Merge entries from `other_table` into `this->hash_table_`.
Status MergeTable(const ScalarMemoTable& other_table) {
const HashTableType& other_hashtable = other_table.hash_table_;

other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
int32_t unused;
DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused));
});
// TODO: ARROW-17074 - implement proper error handling
return Status::OK();
}
};

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -568,6 +582,15 @@ class SmallScalarMemoTable : public MemoTable {
// (which is also 1 + the largest memo index)
int32_t size() const override { return static_cast<int32_t>(index_to_value_.size()); }

// Merge entries from `other_table` into `this`.
Status MergeTable(const SmallScalarMemoTable& other_table) {
for (const Scalar& other_val : other_table.index_to_value_) {
int32_t unused;
RETURN_NOT_OK(this->GetOrInsert(other_val, &unused));
}
return Status::OK();
}

// Copy values starting from index `start` into `out_data`
void CopyValues(int32_t start, Scalar* out_data) const {
DCHECK_GE(start, 0);
Expand Down Expand Up @@ -824,6 +847,15 @@ class BinaryMemoTable : public MemoTable {
};
return hash_table_.Lookup(h, cmp_func);
}

public:
Status MergeTable(const BinaryMemoTable& other_table) {
other_table.VisitValues(0, [this](const util::string_view& other_value) {
int32_t unused;
DCHECK_OK(this->GetOrInsert(other_value, &unused));
});
return Status::OK();
}
};

template <typename T, typename Enable = void>
Expand Down
9 changes: 9 additions & 0 deletions r/tests/testthat/test-dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ test_that("Group by any/all", {
)
})

test_that("n_distinct() with many batches", {
tf <- tempfile()
write_parquet(dplyr::starwars, tf, chunk_size = 20)

ds <- open_dataset(tf)
expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(),
ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)))
})

test_that("n_distinct() on dataset", {
# With groupby
compare_dplyr_binding(
Expand Down