From 68c3e3e004207f25b4a3a8215923651f22de080d Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Tue, 12 Jul 2022 08:47:03 -0700 Subject: [PATCH 01/11] [ARROW-16807]: fixed merge logic in count_distinct The bug described in ARROW-16807 is that when merging states for count distinct, the non_null counts are simply added (no overload that I saw). The fix was to change this into a proper merge of the data in the MemoTable. There are 3 derived classes from MemoTable. This adds a `MaybeInsert` and `MergeTable` function to each to support the merge logic. --- .../arrow/compute/kernels/aggregate_basic.cc | 3 +- cpp/src/arrow/util/hashing.h | 97 +++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 57cee87f00d..80596d36941 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -156,7 +156,8 @@ struct CountDistinctImpl : public ScalarAggregator { Status MergeFrom(KernelContext*, KernelState&& src) override { const auto& other_state = checked_cast(src); - this->non_nulls += other_state.non_nulls; + this->memo_table_->MergeTable(*(other_state.memo_table_)); + this->non_nulls = this->memo_table_->size(); this->has_nulls = this->has_nulls || other_state.has_nulls; return Status::OK(); } diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index d2c0178b008..6f4c4b8cca1 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -428,6 +428,24 @@ class ScalarMemoTable : public MemoTable { value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); } + Status MaybeInsert(const Scalar& value) { + auto cmp_func = [value](const Payload* payload) -> bool { + return ScalarHelper::CompareScalars(value, payload->value); + }; + + hash_t val_hash = ComputeHash(value); + auto hash_entry = hash_table_.Lookup(val_hash, cmp_func); + + // Insert if it wasn't found; otherwise, we're done + if (!hash_entry.second) { + RETURN_NOT_OK( + hash_table_.Insert(hash_entry.first, val_hash, { value, size() }) + ); + } + + return Status::OK(); + } + int32_t GetNull() const { return null_index_; } template @@ -485,6 +503,22 @@ class ScalarMemoTable : public MemoTable { hash_t ComputeHash(const Scalar& value) const { return ScalarHelper::ComputeHash(value); } + + public: + // defined here so that `HashTableType` is visible + // Merge entries from `other_table` into `this->hash_table_`. + void MergeTable(ScalarMemoTable &other_table) { + HashTableType &other_hashtable = other_table.hash_table_; + + other_hashtable.VisitEntries( + [=](const HashTableEntry *other_entry) { + ARROW_WARN_NOT_OK( + this->MaybeInsert(other_entry->payload.value) + ,"Merging ScalarMemoTable" + ); + } + ); + } }; // ---------------------------------------------------------------------- @@ -545,6 +579,22 @@ class SmallScalarMemoTable : public MemoTable { value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); } + Status MaybeInsert(const Scalar& value) { + auto value_index = AsIndex(value); + auto memo_index = value_to_index_[value_index]; + + if (memo_index == kKeyNotFound) { + memo_index = static_cast(index_to_value_.size()); + + index_to_value_.push_back(value); + value_to_index_[value_index] = memo_index; + + DCHECK_LT(memo_index, cardinality + 1); + } + + return Status::OK(); + } + int32_t GetNull() const { return value_to_index_[cardinality]; } template @@ -568,6 +618,16 @@ class SmallScalarMemoTable : public MemoTable { // (which is also 1 + the largest memo index) int32_t size() const override { return static_cast(index_to_value_.size()); } + // Merge entries from `other_table` into `this`. + void MergeTable(SmallScalarMemoTable &other_table) { + for (const Scalar &other_val : other_table.index_to_value_) { + auto insert_status = this->MaybeInsert(other_val); + if (not insert_status.ok()) { + ARROW_WARN_NOT_OK(insert_status, "Merging SmallScalarMemoTable"); + } + } + } + // Copy values starting from index `start` into `out_data` void CopyValues(int32_t start, Scalar* out_data) const { DCHECK_GE(start, 0); @@ -683,6 +743,30 @@ class BinaryMemoTable : public MemoTable { return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {}); } + Status MaybeInsert(const util::string_view& value) { + const void *val_data = value.data(); + auto val_length = static_cast(value.length()); + + hash_t val_hash = ComputeStringHash<0>(val_data, val_length); + auto hash_entry = Lookup(val_hash, val_data, val_length); + + if (!hash_entry.second) { + // Insert string value + RETURN_NOT_OK( + binary_builder_.Append(static_cast(val_data), val_length) + ); + + // Insert hash entry + RETURN_NOT_OK( + hash_table_.Insert( + const_cast(hash_entry.first), val_hash, { size() } + ) + ); + } + + return Status::OK(); + } + // The number of entries in the memo table // (which is also 1 + the largest memo index) int32_t size() const override { @@ -824,6 +908,19 @@ class BinaryMemoTable : public MemoTable { }; return hash_table_.Lookup(h, cmp_func); } + + public: + void MergeTable(BinaryMemoTable &other_table) { + other_table.VisitValues( + 0 + ,[=](const util::string_view &other_value) { + ARROW_WARN_NOT_OK( + this->MaybeInsert(other_value) + ,"Merging BinaryMemoTable" + ); + } + ); + } }; template From 797ad4f9b98f4f6115e91fcfd5b1d770711f4cae Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Tue, 12 Jul 2022 08:48:01 -0700 Subject: [PATCH 02/11] [ARROW-16807]: new tests on chunked arrays This adds a unittest to run count_distinct on chunked arrays, which were incorrectly handled before. --- .../arrow/compute/kernels/aggregate_test.cc | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index aa54fe5f3e2..319e7fb5b12 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -962,11 +962,76 @@ class TestCountDistinctKernel : public ::testing::Test { EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); } + void CheckChunkedArr( const std::shared_ptr &type + ,const std::vector &json + ,int64_t expected_all + ,bool has_nulls = true) { + Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls); + } + CountOptions only_valid{CountOptions::ONLY_VALID}; CountOptions only_null{CountOptions::ONLY_NULL}; CountOptions all{CountOptions::ALL}; }; +TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) { + // Boolean + CheckChunkedArr(boolean(), {"[]" , "[]"}, 0, /*has_nulls=*/false); + CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3); + + // Number + for (auto ty : NumericTypes()) { + CheckChunkedArr(ty,{"[1, 1, null, 2]","[5, 8, 9, 9, null, 10]","[6, 6, 8, 9, 10]"},8); + CheckChunkedArr(ty,{"[1, 1, 8, 2]","[5, 8, 9, 9, 10]","[10, 6, 6]"},7,/*has_nulls=*/false); + } + + // Date + CheckChunkedArr(date32(),{"[0, 11016]", "[0, null, 14241, 14241, null]"},4); + CheckChunkedArr(date64(),{"[0, null]", "[0, null, 0, 0, 1262217600000]"},3); + + // Time + CheckChunkedArr(time32(TimeUnit::SECOND),{"[ 0, 11, 0, null]", "[14, 14, null]"},4); + CheckChunkedArr(time32(TimeUnit::MILLI),{"[ 0, 11000, 0]", "[null, 11000, 11000]"},3); + + CheckChunkedArr(time64(TimeUnit::MICRO),{"[84203999999, 0, null, 84203999999]", "[0]"},3); + CheckChunkedArr(time64(TimeUnit::NANO),{"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"},3); + + // Timestamp & Duration + for (auto u : TimeUnit::values()) { + CheckChunkedArr(duration(u),{"[123456789, null, 987654321]", "[123456789, null]"},3); + + CheckChunkedArr(duration(u), + {"[123456789, 987654321, 123456789, 123456789]","[123456789]"},2, + /*has_nulls=*/false); + + auto ts = std::vector {R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])", + R"(["2020-01-01", null])",R"(["2020-01-01", null])"}; + CheckChunkedArr(timestamp(u), ts, 3); + CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3); + } + + // Interval + CheckChunkedArr(month_interval(),{"[9012, 5678, null, 9012]", "[5678, null, 9012]"},3); + CheckChunkedArr(day_time_interval(),{"[[0, 1], [0, 1]]","[null, [0, 1], [1234, 5678]]"},3); + CheckChunkedArr(month_day_nano_interval(),{"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"},2); + + // Binary & String & Fixed binary + auto samples = std::vector {R"([null, "abc", null])", + R"(["abc", "abc", "cba"])", + R"(["bca", "cba", null])"}; + + CheckChunkedArr( binary(), samples, 4); + CheckChunkedArr( large_binary(), samples, 4); + CheckChunkedArr( utf8(), samples, 4); + CheckChunkedArr( large_utf8(), samples, 4); + CheckChunkedArr(fixed_size_binary(3), samples, 4); + + // Decimal + samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"}; + CheckChunkedArr(decimal128(21, 3), samples, 3); + CheckChunkedArr(decimal256(13, 3), samples, 3); +} + TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) { // Boolean Check(boolean(), "[]", 0, /*has_nulls=*/false); From cb18b294cd26b8b35c708009e922ffd1697136a5 Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Tue, 12 Jul 2022 09:24:31 -0700 Subject: [PATCH 03/11] [ARROW-16807]: Adding an R test This is a test to reproduce the issue seen in ARROW-16807, which calls `count_distinct` (via `n_distinct`) on a dataset with many chunks. If `summarize` is called before `collect`, then count_distinct merges distinct counts across chunks incorrectly. Calling collect before summarize does not expose the bug. --- r/tests/testthat/test-dplyr-summarize.R | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 5ad7425ee87..13c5ce530fa 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -218,6 +218,14 @@ test_that("Group by any/all", { ) }) +test_that("n_distinct() with many batches", { + tf <- tempfile() + write_parquet(dplyr::starwars, tf, chunk_size = 20) + + ds <- open_dataset(tf) + expect_true(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect() == 5) +}) + test_that("n_distinct() on dataset", { # With groupby compare_dplyr_binding( From 8452dfdfa0b69bc97d77a62984d53b0a2ab77631 Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Tue, 12 Jul 2022 09:29:51 -0700 Subject: [PATCH 04/11] [ARROW-16807]: minor addition to test Added an extra assert to be sure that collect and summarize produce the same results when commuted. --- r/tests/testthat/test-dplyr-summarize.R | 1 + 1 file changed, 1 insertion(+) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 13c5ce530fa..596be9671e8 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -224,6 +224,7 @@ test_that("n_distinct() with many batches", { ds <- open_dataset(tf) expect_true(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect() == 5) + expect_true(ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)) == 5) }) test_that("n_distinct() on dataset", { From 1debb4ebb37a911c19fc050a22e48576455b9037 Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Tue, 12 Jul 2022 09:40:32 -0700 Subject: [PATCH 05/11] [ARROW-16807]: fixing style --- .../arrow/compute/kernels/aggregate_test.cc | 67 ++++++++++--------- cpp/src/arrow/util/hashing.h | 64 +++++++----------- 2 files changed, 61 insertions(+), 70 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 319e7fb5b12..abd5b5210ae 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -962,10 +962,9 @@ class TestCountDistinctKernel : public ::testing::Test { EXPECT_THAT(CallFunction("count_distinct", {input}, &all), one); } - void CheckChunkedArr( const std::shared_ptr &type - ,const std::vector &json - ,int64_t expected_all - ,bool has_nulls = true) { + void CheckChunkedArr(const std::shared_ptr& type, + const std::vector& json, int64_t expected_all, + bool has_nulls = true) { Check(ChunkedArrayFromJSON(type, json), expected_all, has_nulls); } @@ -976,58 +975,66 @@ class TestCountDistinctKernel : public ::testing::Test { TEST_F(TestCountDistinctKernel, AllChunkedArrayTypesWithNulls) { // Boolean - CheckChunkedArr(boolean(), {"[]" , "[]"}, 0, /*has_nulls=*/false); + CheckChunkedArr(boolean(), {"[]", "[]"}, 0, /*has_nulls=*/false); CheckChunkedArr(boolean(), {"[true, null]", "[false, null, false]", "[true]"}, 3); // Number for (auto ty : NumericTypes()) { - CheckChunkedArr(ty,{"[1, 1, null, 2]","[5, 8, 9, 9, null, 10]","[6, 6, 8, 9, 10]"},8); - CheckChunkedArr(ty,{"[1, 1, 8, 2]","[5, 8, 9, 9, 10]","[10, 6, 6]"},7,/*has_nulls=*/false); + CheckChunkedArr(ty, {"[1, 1, null, 2]", "[5, 8, 9, 9, null, 10]", "[6, 6, 8, 9, 10]"}, + 8); + CheckChunkedArr(ty, {"[1, 1, 8, 2]", "[5, 8, 9, 9, 10]", "[10, 6, 6]"}, 7, + /*has_nulls=*/false); } // Date - CheckChunkedArr(date32(),{"[0, 11016]", "[0, null, 14241, 14241, null]"},4); - CheckChunkedArr(date64(),{"[0, null]", "[0, null, 0, 0, 1262217600000]"},3); + CheckChunkedArr(date32(), {"[0, 11016]", "[0, null, 14241, 14241, null]"}, 4); + CheckChunkedArr(date64(), {"[0, null]", "[0, null, 0, 0, 1262217600000]"}, 3); // Time - CheckChunkedArr(time32(TimeUnit::SECOND),{"[ 0, 11, 0, null]", "[14, 14, null]"},4); - CheckChunkedArr(time32(TimeUnit::MILLI),{"[ 0, 11000, 0]", "[null, 11000, 11000]"},3); + CheckChunkedArr(time32(TimeUnit::SECOND), {"[ 0, 11, 0, null]", "[14, 14, null]"}, 4); + CheckChunkedArr(time32(TimeUnit::MILLI), {"[ 0, 11000, 0]", "[null, 11000, 11000]"}, 3); - CheckChunkedArr(time64(TimeUnit::MICRO),{"[84203999999, 0, null, 84203999999]", "[0]"},3); - CheckChunkedArr(time64(TimeUnit::NANO),{"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"},3); + CheckChunkedArr(time64(TimeUnit::MICRO), {"[84203999999, 0, null, 84203999999]", "[0]"}, + 3); + CheckChunkedArr(time64(TimeUnit::NANO), + {"[11715003000000, 0, null, 0, 0]", "[0, 0, null]"}, 3); // Timestamp & Duration for (auto u : TimeUnit::values()) { - CheckChunkedArr(duration(u),{"[123456789, null, 987654321]", "[123456789, null]"},3); + CheckChunkedArr(duration(u), {"[123456789, null, 987654321]", "[123456789, null]"}, + 3); CheckChunkedArr(duration(u), - {"[123456789, 987654321, 123456789, 123456789]","[123456789]"},2, + {"[123456789, 987654321, 123456789, 123456789]", "[123456789]"}, 2, /*has_nulls=*/false); - auto ts = std::vector {R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])", - R"(["2020-01-01", null])",R"(["2020-01-01", null])"}; - CheckChunkedArr(timestamp(u), ts, 3); + auto ts = + std::vector{R"(["2009-12-31T04:20:20", "2009-12-31T04:20:20"])", + R"(["2020-01-01", null])", R"(["2020-01-01", null])"}; + CheckChunkedArr(timestamp(u), ts, 3); CheckChunkedArr(timestamp(u, "Pacific/Marquesas"), ts, 3); } // Interval - CheckChunkedArr(month_interval(),{"[9012, 5678, null, 9012]", "[5678, null, 9012]"},3); - CheckChunkedArr(day_time_interval(),{"[[0, 1], [0, 1]]","[null, [0, 1], [1234, 5678]]"},3); - CheckChunkedArr(month_day_nano_interval(),{"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"},2); + CheckChunkedArr(month_interval(), {"[9012, 5678, null, 9012]", "[5678, null, 9012]"}, + 3); + CheckChunkedArr(day_time_interval(), + {"[[0, 1], [0, 1]]", "[null, [0, 1], [1234, 5678]]"}, 3); + CheckChunkedArr(month_day_nano_interval(), + {"[[0, 1, 2]]", "[[0, 1, 2], null, [0, 1, 2]]"}, 2); // Binary & String & Fixed binary - auto samples = std::vector {R"([null, "abc", null])", - R"(["abc", "abc", "cba"])", - R"(["bca", "cba", null])"}; - - CheckChunkedArr( binary(), samples, 4); - CheckChunkedArr( large_binary(), samples, 4); - CheckChunkedArr( utf8(), samples, 4); - CheckChunkedArr( large_utf8(), samples, 4); + auto samples = std::vector{ + R"([null, "abc", null])", R"(["abc", "abc", "cba"])", R"(["bca", "cba", null])"}; + + CheckChunkedArr(binary(), samples, 4); + CheckChunkedArr(large_binary(), samples, 4); + CheckChunkedArr(utf8(), samples, 4); + CheckChunkedArr(large_utf8(), samples, 4); CheckChunkedArr(fixed_size_binary(3), samples, 4); // Decimal - samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"}; + samples = {R"(["12345.679", "98765.421"])", R"([null, "12345.679", "98765.421"])"}; CheckChunkedArr(decimal128(21, 3), samples, 3); CheckChunkedArr(decimal256(13, 3), samples, 3); } diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index 6f4c4b8cca1..b10daa9ae8b 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -433,14 +433,12 @@ class ScalarMemoTable : public MemoTable { return ScalarHelper::CompareScalars(value, payload->value); }; - hash_t val_hash = ComputeHash(value); - auto hash_entry = hash_table_.Lookup(val_hash, cmp_func); + hash_t val_hash = ComputeHash(value); + auto hash_entry = hash_table_.Lookup(val_hash, cmp_func); // Insert if it wasn't found; otherwise, we're done if (!hash_entry.second) { - RETURN_NOT_OK( - hash_table_.Insert(hash_entry.first, val_hash, { value, size() }) - ); + RETURN_NOT_OK(hash_table_.Insert(hash_entry.first, val_hash, {value, size()})); } return Status::OK(); @@ -507,17 +505,13 @@ class ScalarMemoTable : public MemoTable { public: // defined here so that `HashTableType` is visible // Merge entries from `other_table` into `this->hash_table_`. - void MergeTable(ScalarMemoTable &other_table) { - HashTableType &other_hashtable = other_table.hash_table_; - - other_hashtable.VisitEntries( - [=](const HashTableEntry *other_entry) { - ARROW_WARN_NOT_OK( - this->MaybeInsert(other_entry->payload.value) - ,"Merging ScalarMemoTable" - ); - } - ); + void MergeTable(ScalarMemoTable& other_table) { + HashTableType& other_hashtable = other_table.hash_table_; + + other_hashtable.VisitEntries([=](const HashTableEntry* other_entry) { + ARROW_WARN_NOT_OK(this->MaybeInsert(other_entry->payload.value), + "Merging ScalarMemoTable"); + }); } }; @@ -581,7 +575,7 @@ class SmallScalarMemoTable : public MemoTable { Status MaybeInsert(const Scalar& value) { auto value_index = AsIndex(value); - auto memo_index = value_to_index_[value_index]; + auto memo_index = value_to_index_[value_index]; if (memo_index == kKeyNotFound) { memo_index = static_cast(index_to_value_.size()); @@ -619,8 +613,8 @@ class SmallScalarMemoTable : public MemoTable { int32_t size() const override { return static_cast(index_to_value_.size()); } // Merge entries from `other_table` into `this`. - void MergeTable(SmallScalarMemoTable &other_table) { - for (const Scalar &other_val : other_table.index_to_value_) { + void MergeTable(SmallScalarMemoTable& other_table) { + for (const Scalar& other_val : other_table.index_to_value_) { auto insert_status = this->MaybeInsert(other_val); if (not insert_status.ok()) { ARROW_WARN_NOT_OK(insert_status, "Merging SmallScalarMemoTable"); @@ -744,24 +738,20 @@ class BinaryMemoTable : public MemoTable { } Status MaybeInsert(const util::string_view& value) { - const void *val_data = value.data(); - auto val_length = static_cast(value.length()); + const void* val_data = value.data(); + auto val_length = static_cast(value.length()); - hash_t val_hash = ComputeStringHash<0>(val_data, val_length); - auto hash_entry = Lookup(val_hash, val_data, val_length); + hash_t val_hash = ComputeStringHash<0>(val_data, val_length); + auto hash_entry = Lookup(val_hash, val_data, val_length); if (!hash_entry.second) { // Insert string value RETURN_NOT_OK( - binary_builder_.Append(static_cast(val_data), val_length) - ); + binary_builder_.Append(static_cast(val_data), val_length)); // Insert hash entry - RETURN_NOT_OK( - hash_table_.Insert( - const_cast(hash_entry.first), val_hash, { size() } - ) - ); + RETURN_NOT_OK(hash_table_.Insert(const_cast(hash_entry.first), + val_hash, {size()})); } return Status::OK(); @@ -910,16 +900,10 @@ class BinaryMemoTable : public MemoTable { } public: - void MergeTable(BinaryMemoTable &other_table) { - other_table.VisitValues( - 0 - ,[=](const util::string_view &other_value) { - ARROW_WARN_NOT_OK( - this->MaybeInsert(other_value) - ,"Merging BinaryMemoTable" - ); - } - ); + void MergeTable(BinaryMemoTable& other_table) { + other_table.VisitValues(0, [=](const util::string_view& other_value) { + ARROW_WARN_NOT_OK(this->MaybeInsert(other_value), "Merging BinaryMemoTable"); + }); } }; From 63470e29943f4b4e70ca2faf2cb949d14864e9fa Mon Sep 17 00:00:00 2001 From: Aldrin Montana Date: Wed, 13 Jul 2022 11:45:25 -0700 Subject: [PATCH 06/11] [ARROW-16807]: minor update to R test --- r/tests/testthat/test-dplyr-summarize.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 596be9671e8..9c713666097 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -223,8 +223,8 @@ test_that("n_distinct() with many batches", { write_parquet(dplyr::starwars, tf, chunk_size = 20) ds <- open_dataset(tf) - expect_true(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect() == 5) - expect_true(ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE)) == 5) + expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), + ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))) }) test_that("n_distinct() on dataset", { From 821f39f6dac89ed2e3b6fc6d6712e354930b410d Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Wed, 13 Jul 2022 16:21:18 -0700 Subject: [PATCH 07/11] [ARROW-16807]: fixed path for scalar input The path for count_distinct with scalar inputs didn't update state, and instead only added a single count. To fix this, we use MaybeInsert and UnboxScalar to insert the new value. The non_null count can then be set the same way for both vector and scalar inputs --- cpp/src/arrow/compute/kernels/aggregate_basic.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 80596d36941..a988395c50f 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -136,21 +136,26 @@ struct CountDistinctImpl : public ScalarAggregator { Status Consume(KernelContext*, const ExecBatch& batch) override { if (batch[0].is_array()) { const ArrayData& arr = *batch[0].array(); + this->has_nulls = arr.GetNullCount() > 0; + auto visit_null = []() { return Status::OK(); }; auto visit_value = [&](VisitorArgType arg) { int y; return memo_table_->GetOrInsert(arg, &y); }; RETURN_NOT_OK(VisitArraySpanInline(arr, visit_value, visit_null)); - this->non_nulls += memo_table_->size(); - this->has_nulls = arr.GetNullCount() > 0; + } else { const Scalar& input = *batch[0].scalar(); this->has_nulls = !input.is_valid; + if (input.is_valid) { - this->non_nulls += batch.length; + RETURN_NOT_OK(memo_table_->MaybeInsert(UnboxScalar::Unbox(input))); } } + + this->non_nulls = memo_table_->size(); + return Status::OK(); } From 7372a0622e10980be4265700ddf754320fc195af Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Wed, 13 Jul 2022 16:24:31 -0700 Subject: [PATCH 08/11] [ARROW-16807]: updated UnboxScalar for string_view The UnboxScalar implementation that has `enable_if_has_string_view` is also true for Decimal128 and Decimal256 types. Usually the implementation templated for these types would be called, but The BinaryMemoTable only has a type for BinaryTypes. This means that for count_distinct with Decimal columns, the UnboxScalar implementation that gets called is the one for string_view type --- cpp/src/arrow/compute/kernels/codegen_internal.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 1d5f5dd9bd5..d8d2192382d 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -343,6 +343,22 @@ struct UnboxScalar> { using T = util::string_view; static T Unbox(const Scalar& val) { if (!val.is_valid) return util::string_view(); + + switch (val.type->id()) { + case arrow::Type::DECIMAL128: { + return util::string_view(checked_cast(val).view()); + break; + } + + case arrow::Type::DECIMAL256: { + return util::string_view(checked_cast(val).view()); + break; + } + + default: + break; + } + return util::string_view(*checked_cast(val).value); } }; From 5d04cb098282269e1cad569e45ea9a5cacacaf07 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 14 Jul 2022 08:36:18 -0500 Subject: [PATCH 09/11] Simpler static UnboxScalar solution, remove MaybeInsert --- .../arrow/compute/kernels/aggregate_basic.cc | 7 +- .../arrow/compute/kernels/codegen_internal.h | 18 +---- cpp/src/arrow/util/hashing.h | 79 ++++--------------- 3 files changed, 20 insertions(+), 84 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index a988395c50f..fec483318ef 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -140,7 +140,7 @@ struct CountDistinctImpl : public ScalarAggregator { auto visit_null = []() { return Status::OK(); }; auto visit_value = [&](VisitorArgType arg) { - int y; + int32_t y; return memo_table_->GetOrInsert(arg, &y); }; RETURN_NOT_OK(VisitArraySpanInline(arr, visit_value, visit_null)); @@ -150,7 +150,8 @@ struct CountDistinctImpl : public ScalarAggregator { this->has_nulls = !input.is_valid; if (input.is_valid) { - RETURN_NOT_OK(memo_table_->MaybeInsert(UnboxScalar::Unbox(input))); + int32_t unused; + RETURN_NOT_OK(memo_table_->GetOrInsert(UnboxScalar::Unbox(input), &unused)); } } @@ -161,7 +162,7 @@ struct CountDistinctImpl : public ScalarAggregator { Status MergeFrom(KernelContext*, KernelState&& src) override { const auto& other_state = checked_cast(src); - this->memo_table_->MergeTable(*(other_state.memo_table_)); + RETURN_NOT_OK(this->memo_table_->MergeTable(*(other_state.memo_table_))); this->non_nulls = this->memo_table_->size(); this->has_nulls = this->has_nulls || other_state.has_nulls; return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index d8d2192382d..f008314e8be 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -343,23 +343,7 @@ struct UnboxScalar> { using T = util::string_view; static T Unbox(const Scalar& val) { if (!val.is_valid) return util::string_view(); - - switch (val.type->id()) { - case arrow::Type::DECIMAL128: { - return util::string_view(checked_cast(val).view()); - break; - } - - case arrow::Type::DECIMAL256: { - return util::string_view(checked_cast(val).view()); - break; - } - - default: - break; - } - - return util::string_view(*checked_cast(val).value); + return checked_cast(val).view(); } }; diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index b10daa9ae8b..9673a522629 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -428,22 +428,6 @@ class ScalarMemoTable : public MemoTable { value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); } - Status MaybeInsert(const Scalar& value) { - auto cmp_func = [value](const Payload* payload) -> bool { - return ScalarHelper::CompareScalars(value, payload->value); - }; - - hash_t val_hash = ComputeHash(value); - auto hash_entry = hash_table_.Lookup(val_hash, cmp_func); - - // Insert if it wasn't found; otherwise, we're done - if (!hash_entry.second) { - RETURN_NOT_OK(hash_table_.Insert(hash_entry.first, val_hash, {value, size()})); - } - - return Status::OK(); - } - int32_t GetNull() const { return null_index_; } template @@ -505,13 +489,15 @@ class ScalarMemoTable : public MemoTable { public: // defined here so that `HashTableType` is visible // Merge entries from `other_table` into `this->hash_table_`. - void MergeTable(ScalarMemoTable& other_table) { + Status MergeTable(ScalarMemoTable& other_table) { HashTableType& other_hashtable = other_table.hash_table_; - other_hashtable.VisitEntries([=](const HashTableEntry* other_entry) { - ARROW_WARN_NOT_OK(this->MaybeInsert(other_entry->payload.value), - "Merging ScalarMemoTable"); + other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) { + int32_t unused; + DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused)); }); + // TODO: implement proper (and perhaps more performant) error handling + return Status::OK(); } }; @@ -573,22 +559,6 @@ class SmallScalarMemoTable : public MemoTable { value, [](int32_t i) {}, [](int32_t i) {}, out_memo_index); } - Status MaybeInsert(const Scalar& value) { - auto value_index = AsIndex(value); - auto memo_index = value_to_index_[value_index]; - - if (memo_index == kKeyNotFound) { - memo_index = static_cast(index_to_value_.size()); - - index_to_value_.push_back(value); - value_to_index_[value_index] = memo_index; - - DCHECK_LT(memo_index, cardinality + 1); - } - - return Status::OK(); - } - int32_t GetNull() const { return value_to_index_[cardinality]; } template @@ -613,13 +583,12 @@ class SmallScalarMemoTable : public MemoTable { int32_t size() const override { return static_cast(index_to_value_.size()); } // Merge entries from `other_table` into `this`. - void MergeTable(SmallScalarMemoTable& other_table) { + Status MergeTable(SmallScalarMemoTable& other_table) { for (const Scalar& other_val : other_table.index_to_value_) { - auto insert_status = this->MaybeInsert(other_val); - if (not insert_status.ok()) { - ARROW_WARN_NOT_OK(insert_status, "Merging SmallScalarMemoTable"); - } + int32_t unused; + RETURN_NOT_OK(this->GetOrInsert(other_val, &unused)); } + return Status::OK(); } // Copy values starting from index `start` into `out_data` @@ -737,26 +706,6 @@ class BinaryMemoTable : public MemoTable { return GetOrInsertNull([](int32_t i) {}, [](int32_t i) {}); } - Status MaybeInsert(const util::string_view& value) { - const void* val_data = value.data(); - auto val_length = static_cast(value.length()); - - hash_t val_hash = ComputeStringHash<0>(val_data, val_length); - auto hash_entry = Lookup(val_hash, val_data, val_length); - - if (!hash_entry.second) { - // Insert string value - RETURN_NOT_OK( - binary_builder_.Append(static_cast(val_data), val_length)); - - // Insert hash entry - RETURN_NOT_OK(hash_table_.Insert(const_cast(hash_entry.first), - val_hash, {size()})); - } - - return Status::OK(); - } - // The number of entries in the memo table // (which is also 1 + the largest memo index) int32_t size() const override { @@ -900,10 +849,12 @@ class BinaryMemoTable : public MemoTable { } public: - void MergeTable(BinaryMemoTable& other_table) { - other_table.VisitValues(0, [=](const util::string_view& other_value) { - ARROW_WARN_NOT_OK(this->MaybeInsert(other_value), "Merging BinaryMemoTable"); + Status MergeTable(BinaryMemoTable& other_table) { + other_table.VisitValues(0, [this](const util::string_view& other_value) { + int32_t unused; + DCHECK_OK(this->GetOrInsert(other_value, &unused)); }); + return Status::OK(); } }; From 91d2cab1932e8186156df15ed84867963fcd6145 Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Thu, 14 Jul 2022 09:12:33 -0700 Subject: [PATCH 10/11] [ARROW-16807]: retabbed R test --- r/tests/testthat/test-dplyr-summarize.R | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 9c713666097..e3cb82a6e1d 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -219,12 +219,12 @@ test_that("Group by any/all", { }) test_that("n_distinct() with many batches", { - tf <- tempfile() - write_parquet(dplyr::starwars, tf, chunk_size = 20) + tf <- tempfile() + write_parquet(dplyr::starwars, tf, chunk_size = 20) - ds <- open_dataset(tf) - expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), - ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))) + ds <- open_dataset(tf) + expect_equal(ds %>% summarise(n_distinct(sex, na.rm = FALSE)) %>% collect(), + ds %>% collect() %>% summarise(n_distinct(sex, na.rm = FALSE))) }) test_that("n_distinct() on dataset", { From 7c4d70193e3a000d4b51a66cc1553dd0b7a9cf2b Mon Sep 17 00:00:00 2001 From: Aldrin M Date: Fri, 15 Jul 2022 09:48:00 -0700 Subject: [PATCH 11/11] [ARROW-16807]: added const and JIRA reference --- cpp/src/arrow/util/hashing.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index 9673a522629..ca5a6c766bd 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -489,14 +489,14 @@ class ScalarMemoTable : public MemoTable { public: // defined here so that `HashTableType` is visible // Merge entries from `other_table` into `this->hash_table_`. - Status MergeTable(ScalarMemoTable& other_table) { - HashTableType& other_hashtable = other_table.hash_table_; + Status MergeTable(const ScalarMemoTable& other_table) { + const HashTableType& other_hashtable = other_table.hash_table_; other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) { int32_t unused; DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused)); }); - // TODO: implement proper (and perhaps more performant) error handling + // TODO: ARROW-17074 - implement proper error handling return Status::OK(); } }; @@ -583,7 +583,7 @@ class SmallScalarMemoTable : public MemoTable { int32_t size() const override { return static_cast(index_to_value_.size()); } // Merge entries from `other_table` into `this`. - Status MergeTable(SmallScalarMemoTable& other_table) { + Status MergeTable(const SmallScalarMemoTable& other_table) { for (const Scalar& other_val : other_table.index_to_value_) { int32_t unused; RETURN_NOT_OK(this->GetOrInsert(other_val, &unused)); @@ -849,7 +849,7 @@ class BinaryMemoTable : public MemoTable { } public: - Status MergeTable(BinaryMemoTable& other_table) { + Status MergeTable(const BinaryMemoTable& other_table) { other_table.VisitValues(0, [this](const util::string_view& other_value) { int32_t unused; DCHECK_OK(this->GetOrInsert(other_value, &unused));