diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc index 0d27b9df1c1..c2b9090d317 100644 --- a/cpp/src/arrow/compute/compute-test.cc +++ b/cpp/src/arrow/compute/compute-test.cc @@ -876,6 +876,23 @@ void CheckDictEncode(FunctionContext* ctx, const shared_ptr& type, ASSERT_ARRAYS_EQUAL(expected, *result); } +template +void CheckCountValues(FunctionContext* ctx, const shared_ptr& type, + const vector& in_values, const vector& in_is_valid, + const vector& out_values, const vector& out_is_valid, + const vector& out_counts) { + shared_ptr input = _MakeArray(type, in_values, in_is_valid); + shared_ptr ex_values = _MakeArray(type, out_values, out_is_valid); + shared_ptr ex_counts = + _MakeArray(int64(), out_counts, out_is_valid); + + shared_ptr result_values; + shared_ptr result_counts; + ASSERT_OK(CountValues(ctx, Datum(input), &result_values, &result_counts)); + ASSERT_ARRAYS_EQUAL(*ex_values, *result_values); + ASSERT_ARRAYS_EQUAL(*ex_counts, *result_counts); +} + class TestHashKernel : public ComputeFixture, public TestBase {}; template @@ -903,6 +920,14 @@ TYPED_TEST(TestHashKernelPrimitive, DictEncode) { {0, 0, 0, 1, 0, 2}); } +TYPED_TEST(TestHashKernelPrimitive, CountValues) { + using T = typename TypeParam::c_type; + auto type = TypeTraits::type_singleton(); + CheckCountValues(&this->ctx_, type, {2, 1, 2, 1, 2, 3, 4}, + {true, false, true, true, true, true, false}, {2, 1, 3}, + {}, {3, 1, 1}); +} + TYPED_TEST(TestHashKernelPrimitive, PrimitiveResizeTable) { using T = typename TypeParam::c_type; // Skip this test for (u)int8 @@ -916,12 +941,14 @@ TYPED_TEST(TestHashKernelPrimitive, PrimitiveResizeTable) { vector values; vector uniques; vector indices; + vector counts; for (int64_t i = 0; i < kTotalValues * kRepeats; i++) { const auto val = static_cast(i % kTotalValues); values.push_back(val); if (i < kTotalValues) { uniques.push_back(val); + counts.push_back(kRepeats); } indices.push_back(static_cast(i % kTotalValues)); } @@ -930,6 +957,8 @@ TYPED_TEST(TestHashKernelPrimitive, PrimitiveResizeTable) { CheckUnique(&this->ctx_, type, values, {}, uniques, {}); CheckDictEncode(&this->ctx_, type, values, {}, uniques, {}, indices); + + CheckCountValues(&this->ctx_, type, values, {}, uniques, {}, counts); } TEST_F(TestHashKernel, UniqueTimeTimestamp) { @@ -944,6 +973,19 @@ TEST_F(TestHashKernel, UniqueTimeTimestamp) { {}); } +TEST_F(TestHashKernel, CountValuesTimeTimestamp) { + CheckCountValues(&this->ctx_, time32(TimeUnit::SECOND), + {2, 1, 2, 1}, {true, false, true, true}, {2, 1}, + {}, {2, 1}); + + CheckCountValues(&this->ctx_, time64(TimeUnit::NANO), {2, 1, 2, 1}, + {true, false, true, true}, {2, 1}, {}, {2, 1}); + + CheckCountValues(&this->ctx_, timestamp(TimeUnit::NANO), + {2, 1, 2, 1}, {true, false, true, true}, + {2, 1}, {}, {2, 1}); +} + TEST_F(TestHashKernel, UniqueBoolean) { CheckUnique(&this->ctx_, boolean(), {true, true, false, true}, {true, false, true, true}, {true, false}, {}); @@ -978,6 +1020,23 @@ TEST_F(TestHashKernel, DictEncodeBoolean) { {}, {0, 1, 0, 1, 0}); } +TEST_F(TestHashKernel, CountValuesBoolean) { + CheckCountValues(&this->ctx_, boolean(), {true, true, false, true}, + {true, false, true, true}, {true, false}, {}, + {2, 1}); + + CheckCountValues(&this->ctx_, boolean(), {false, true, false, true}, + {true, false, true, true}, {false, true}, {}, + {2, 1}); + + // No nulls + CheckCountValues(&this->ctx_, boolean(), {true, true, false, true}, + {}, {true, false}, {}, {3, 1}); + + CheckCountValues(&this->ctx_, boolean(), {false, true, false, true}, + {}, {false, true}, {}, {2, 2}); +} + TEST_F(TestHashKernel, UniqueBinary) { CheckUnique(&this->ctx_, binary(), {"test", "", "test2", "test"}, @@ -997,6 +1056,16 @@ TEST_F(TestHashKernel, DictEncodeBinary) { {true, false, true, true, true}, {"test", "test2", "baz"}, {}, {0, 0, 1, 0, 2}); } +TEST_F(TestHashKernel, CountValuesBinary) { + CheckCountValues( + &this->ctx_, binary(), {"test", "", "test2", "test"}, {true, false, true, true}, + {"test", "test2"}, {}, {2, 1}); + + CheckCountValues( + &this->ctx_, utf8(), {"test", "", "test2", "test"}, {true, false, true, true}, + {"test", "test2"}, {}, {2, 1}); +} + TEST_F(TestHashKernel, BinaryResizeTable) { const int64_t kTotalValues = 10000; const int64_t kRepeats = 10; @@ -1046,6 +1115,7 @@ TEST_F(TestHashKernel, FixedSizeBinaryResizeTable) { vector values; vector uniques; vector indices; + vector counts; for (int64_t i = 0; i < kTotalValues * kRepeats; i++) { int64_t index = i % kTotalValues; std::stringstream ss; @@ -1056,6 +1126,7 @@ TEST_F(TestHashKernel, FixedSizeBinaryResizeTable) { if (i < kTotalValues) { uniques.push_back(val); + counts.push_back(kRepeats); } indices.push_back(static_cast(i % kTotalValues)); } @@ -1065,6 +1136,8 @@ TEST_F(TestHashKernel, FixedSizeBinaryResizeTable) { {}); CheckDictEncode(&this->ctx_, type, values, {}, uniques, {}, indices); + CheckCountValues(&this->ctx_, type, values, {}, + uniques, {}, counts); } TEST_F(TestHashKernel, UniqueDecimal) { @@ -1084,6 +1157,15 @@ TEST_F(TestHashKernel, DictEncodeDecimal) { {}, {0, 0, 1, 0, 2}); } +TEST_F(TestHashKernel, CountValuesDecimal) { + vector values{12, 12, 11, 12}; + vector expected{12, 11}; + + CheckCountValues(&this->ctx_, decimal(2, 0), values, + {true, false, true, true}, expected, {}, + {2, 1}); +} + TEST_F(TestHashKernel, ChunkedArrayInvoke) { vector values1 = {"foo", "bar", "foo"}; vector values2 = {"bar", "baz", "quuux", "foo"}; @@ -1095,6 +1177,9 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) { vector dict_values = {"foo", "bar", "baz", "quuux"}; auto ex_dict = _MakeArray(type, dict_values, {}); + vector counts = {3, 2, 1, 1}; + auto ex_counts = _MakeArray(int64(), counts, {}); + ArrayVector arrays = {a1, a2}; auto carr = std::make_shared(arrays); @@ -1103,6 +1188,13 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) { ASSERT_OK(Unique(&this->ctx_, Datum(carr), &result)); ASSERT_ARRAYS_EQUAL(*ex_dict, *result); + // Count values + shared_ptr cv_uniques; + shared_ptr cv_counts; + ASSERT_OK(CountValues(&this->ctx_, Datum(carr), &cv_uniques, &cv_counts)); + ASSERT_ARRAYS_EQUAL(*ex_dict, *cv_uniques); + ASSERT_ARRAYS_EQUAL(*ex_counts, *cv_counts); + // Dictionary encode auto dict_type = dictionary(int32(), ex_dict); diff --git a/cpp/src/arrow/compute/kernels/hash.cc b/cpp/src/arrow/compute/kernels/hash.cc index dbce6e561c5..1854f85481d 100644 --- a/cpp/src/arrow/compute/kernels/hash.cc +++ b/cpp/src/arrow/compute/kernels/hash.cc @@ -749,6 +749,50 @@ class DictEncodeImpl : public HashTableKernel> { Int32Builder indices_builder_; }; +// ---------------------------------------------------------------------- +// Count values implementation + +template +class CountValuesImpl : public HashTableKernel> { + public: + static constexpr bool allow_expand = true; + using Base = HashTableKernel; + + CountValuesImpl(const std::shared_ptr& type, MemoryPool* pool) + : Base(type, pool) {} + + Status Reserve(const int64_t length) { + counts_.reserve(length); + return Status::OK(); + } + + void ObserveNull() {} + + void ObserveFound(const hash_slot_t slot) { counts_[slot]++; } + + void ObserveNotFound(const hash_slot_t slot) { counts_.emplace_back(1); } + + Status DoubleSize() { return Base::DoubleTableSize(); } + + Status Flush(Datum* out) override { + Int64Builder builder(Base::pool_); + std::shared_ptr result; + + for (const int64_t value : counts_) { + RETURN_NOT_OK(builder.Append(value)); + } + + RETURN_NOT_OK(builder.FinishInternal(&result)); + out->value = std::move(result); + return Status::OK(); + } + + using Base::Append; + + private: + std::vector counts_; +}; + // ---------------------------------------------------------------------- // Kernel wrapper for generic hash table kernels @@ -871,6 +915,48 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx, return Status::OK(); } +Status GetCountValuesKernel(FunctionContext* ctx, const std::shared_ptr& type, + std::unique_ptr* out) { + std::unique_ptr hasher; + +#define COUNT_VALUES_CASE(InType) \ + case InType::type_id: \ + hasher.reset(new CountValuesImpl(type, ctx->memory_pool())); \ + break + + switch (type->id()) { + COUNT_VALUES_CASE(NullType); + COUNT_VALUES_CASE(BooleanType); + COUNT_VALUES_CASE(UInt8Type); + COUNT_VALUES_CASE(Int8Type); + COUNT_VALUES_CASE(UInt16Type); + COUNT_VALUES_CASE(Int16Type); + COUNT_VALUES_CASE(UInt32Type); + COUNT_VALUES_CASE(Int32Type); + COUNT_VALUES_CASE(UInt64Type); + COUNT_VALUES_CASE(Int64Type); + COUNT_VALUES_CASE(FloatType); + COUNT_VALUES_CASE(DoubleType); + COUNT_VALUES_CASE(Date32Type); + COUNT_VALUES_CASE(Date64Type); + COUNT_VALUES_CASE(Time32Type); + COUNT_VALUES_CASE(Time64Type); + COUNT_VALUES_CASE(TimestampType); + COUNT_VALUES_CASE(BinaryType); + COUNT_VALUES_CASE(StringType); + COUNT_VALUES_CASE(FixedSizeBinaryType); + COUNT_VALUES_CASE(Decimal128Type); + default: + break; + } + +#undef COUNT_VALUES_CASE + + CHECK_IMPLEMENTED(hasher, "count-values", type); + out->reset(new HashKernelImpl(std::move(hasher))); + return Status::OK(); +} + namespace { Status InvokeHash(FunctionContext* ctx, HashKernel* func, const Datum& value, @@ -918,5 +1004,18 @@ Status DictionaryEncode(FunctionContext* ctx, const Datum& value, Datum* out) { return Status::OK(); } +Status CountValues(FunctionContext* ctx, const Datum& value, + std::shared_ptr* out_uniques, + std::shared_ptr* out_counts) { + std::unique_ptr func; + RETURN_NOT_OK(GetCountValuesKernel(ctx, value.type(), &func)); + + std::vector counts_datum; + RETURN_NOT_OK(InvokeHash(ctx, func.get(), value, &counts_datum, out_uniques)); + + *out_counts = MakeArray(counts_datum.back().array()); + return Status::OK(); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash.h b/cpp/src/arrow/compute/kernels/hash.h index 05f24294989..0f229865fef 100644 --- a/cpp/src/arrow/compute/kernels/hash.h +++ b/cpp/src/arrow/compute/kernels/hash.h @@ -51,6 +51,10 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx, const std::shared_ptr& type, std::unique_ptr* kernel); +ARROW_EXPORT +Status GetCountValuesKernel(FunctionContext* ctx, const std::shared_ptr& type, + std::unique_ptr* kernel); + /// \brief Compute unique elements from an array-like object /// \param[in] context the FunctionContext /// \param[in] datum array-like input @@ -71,6 +75,19 @@ Status Unique(FunctionContext* context, const Datum& datum, std::shared_ptr* out_uniques, + std::shared_ptr* out_counts); + // TODO(wesm): Define API for incremental dictionary encoding // TODO(wesm): Define API for regularizing DictionaryArray objects with @@ -95,11 +112,6 @@ Status DictionaryEncode(FunctionContext* context, const Datum& data, Datum* out) // Status IsIn(FunctionContext* context, const Datum& values, const Datum& member_set, // Datum* out); -// ARROW_EXPORT -// Status CountValues(FunctionContext* context, const Datum& values, -// std::shared_ptr* out_uniques, -// std::shared_ptr* out_counts); - } // namespace compute } // namespace arrow