diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 57cee87f00d..fec483318ef 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -136,27 +136,34 @@ struct CountDistinctImpl : public ScalarAggregator { Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { const ArrayData& arr = *batch[0].array(); + this->has_nulls = arr.GetNullCount() > 0; + auto visit_null = []() { return Status::OK(); }; auto visit_value = [&](VisitorArgType arg) { - int y; + int32_t y; return memo_table_->GetOrInsert(arg, &y); }; RETURN_NOT_OK(VisitArraySpanInline(arr, visit_value, visit_null)); - this->non_nulls += memo_table_->size(); - this->has_nulls = arr.GetNullCount() > 0; + } else { const Scalar& input = *batch[0].scalar(); this->has_nulls = !input.is_valid; + if (input.is_valid) { - this->non_nulls += batch.length; + int32_t unused; + RETURN_NOT_OK(memo_table_->GetOrInsert(UnboxScalar::Unbox(input), &unused)); } } + + this->non_nulls = memo_table_->size(); + return Status::OK(); } Status MergeFrom(KernelContext*, KernelState&& src) override { const auto& other_state = checked_cast(src); - this->non_nulls += other_state.non_nulls; + RETURN_NOT_OK(this->memo_table_->MergeTable(*(other_state.memo_table_))); + this->non_nulls = this->memo_table_->size(); this->has_nulls = this->has_nulls || other_state.has_nulls; return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index aa54fe5f3e2..abd5b5210ae 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -962,11 +962,83 @@ class TestCountDistinctKernel : public ::testing::Test { EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); } + void CheckChunkedArr(const std::shared_ptr& type, + const std::vector& json, int64_t expected_all, + bool has_nulls = true) { + Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls); + } + CountOptions only_valid{CountOptions::ONLY_VALID}; CountOptions only_null{CountOptions::ONLY_NULL}; CountOptions all{CountOptions::ALL}; }; +TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) { + // Boolean + CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false); + CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3); + + // Number + for (auto ty : NumericTypes()) { + CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6, 8, 9, 10]"}, + 8); + CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7, + /*has_nulls=*/false); + } + + // Date + CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"}, 4); + CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"}, 3); + + // Time + CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14, null]"}, 4); + CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000, 11000]"}, 3); + + CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null, 84203999999]", "[0]"}, + 3); + CheckChunkedArr(time64(TimeUnit::NANO), + {"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3); + + // Timestamp & Duration + for (auto u : TimeUnit::values()) { + CheckChunkedArr(duration(u), {"[123456789, null, 987654321]", "[123456789, null]"}, + 3); + + CheckChunkedArr(duration(u), + {"[123456789, 987654321, 123456789, 123456789]", "[123456789]"}, 2, + /*has_nulls=*/false); + + auto ts = + std::vector{R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])", + R"(["2020-01-01", null])", R"(["2020-01-01", null])"}; + CheckChunkedArr(timestamp(u), ts, 3); + CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3); + } + + // Interval + CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null, 9012]"}, + 3); + CheckChunkedArr(day_time_interval(), + {"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3); + CheckChunkedArr(month_day_nano_interval(), + {"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2); + + // Binary & String & Fixed binary + auto samples = std::vector{ + R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba", null])"}; + + CheckChunkedArr(binary(), samples, 4); + CheckChunkedArr(large_binary(), samples, 4); + CheckChunkedArr(utf8(), samples, 4); + CheckChunkedArr(large_utf8(), samples, 4); + CheckChunkedArr(fixed_size_binary(3), samples, 4); + + // Decimal + samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"}; + CheckChunkedArr(decimal128(21, 3), samples, 3); + CheckChunkedArr(decimal256(13, 3), samples, 3); +} + TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) { // Boolean Check(boolean(), "[]", 0, /*has_nulls=*/false); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 1d5f5dd9bd5..f008314e8be 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -343,7 +343,7 @@ struct UnboxScalar> { using T = util::string_view; static T Unbox(const Scalar& val) { if (!val.is_valid) return util::string_view(); - return util::string_view(*checked_cast(val).value); + return checked_cast(val).view(); } }; diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index d2c0178b008..ca5a6c766bd 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -485,6 +485,20 @@ class ScalarMemoTable : public MemoTable { hash_t ComputeHash(const Scalar& value) const { return ScalarHelper::ComputeHash(value); } + + public: + // defined here so that `HashTableType` is visible + // Merge entries from `other_table` into `this->hash_table_`. + Status MergeTable(const ScalarMemoTable& other_table) { + const HashTableType& other_hashtable = other_table.hash_table_; + + other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) { + int32_t unused; + DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused)); + }); + // TODO: ARROW-17074 - implement proper error handling + return Status::OK(); + } }; // ---------------------------------------------------------------------- @@ -568,6 +582,15 @@ class SmallScalarMemoTable : public MemoTable { // (which is also 1 + the largest memo index) int32_t size() const override { return static_cast(index_to_value_.size()); } + // Merge entries from `other_table` into `this`. + Status MergeTable(const SmallScalarMemoTable& other_table) { + for (const Scalar& other_val : other_table.index_to_value_) { + int32_t unused; + RETURN_NOT_OK(this->GetOrInsert(other_val, &unused)); + } + return Status::OK(); + } + // Copy values starting from index `start` into `out_data` void CopyValues(int32_t start, Scalar* out_data) const { DCHECK_GE(start, 0); @@ -824,6 +847,15 @@ class BinaryMemoTable : public MemoTable { }; return hash_table_.Lookup(h, cmp_func); } + + public: + Status MergeTable(const BinaryMemoTable& other_table) { + other_table.VisitValues(0, [this](const util::string_view& other_value) { + int32_t unused; + DCHECK_OK(this->GetOrInsert(other_value, &unused)); + }); + return Status::OK(); + } }; template diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 5ad7425ee87..e3cb82a6e1d 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -218,6 +218,15 @@ test_that("Group by any/all", { ) }) +test_that("n_distinct() with many batches", { + tf <- tempfile() + write_parquet(dplyr::starwars, tf, chunk_size = 20) + + ds <- open_dataset(tf) + expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), + ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))) +}) + test_that("n_distinct() on dataset", { # With groupby compare_dplyr_binding(