diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc index f6f66d27cfe..9451e00b010 100644 --- a/cpp/src/arrow/array-test.cc +++ b/cpp/src/arrow/array-test.cc @@ -494,6 +494,22 @@ void TestPrimitiveBuilder::Check(const std::unique_ptr ASSERT_EQ(0, builder->null_count()); } +TEST(NumericBuilderAccessors, TestSettersGetters) { + int64_t datum = 42; + int64_t new_datum = 43; + NumericBuilder builder(int64(), default_memory_pool()); + + builder.Reset(); + ASSERT_OK(builder.Append(datum)); + ASSERT_EQ(builder.GetValue(0), datum); + + // Now update the value. + builder[0] = new_datum; + + ASSERT_EQ(builder.GetValue(0), new_datum); + ASSERT_EQ(((const NumericBuilder&)builder)[0], new_datum); +} + typedef ::testing::Types Primitives; diff --git a/cpp/src/arrow/array/builder_primitive.h b/cpp/src/arrow/array/builder_primitive.h index 5a9b69483af..95cfaa793d6 100644 --- a/cpp/src/arrow/array/builder_primitive.h +++ b/cpp/src/arrow/array/builder_primitive.h @@ -92,6 +92,12 @@ class NumericBuilder : public ArrayBuilder { return ArrayBuilder::Resize(capacity); } + value_type operator[](int64_t index) const { return GetValue(index); } + + value_type& operator[](int64_t index) { + return reinterpret_cast(data_builder_.mutable_data())[index]; + } + /// \brief Append a sequence of elements in one shot /// \param[in] values a contiguous C array of values /// \param[in] length the number of values to append diff --git a/cpp/src/arrow/compute/kernels/boolean-test.cc b/cpp/src/arrow/compute/kernels/boolean-test.cc index 439e0db9b6f..5e1da1be657 100644 --- a/cpp/src/arrow/compute/kernels/boolean-test.cc +++ b/cpp/src/arrow/compute/kernels/boolean-test.cc @@ -129,7 +129,6 @@ TEST_F(TestBooleanKernel, Invert) { } TEST_F(TestBooleanKernel, InvertEmptyArray) { - auto type = boolean(); std::vector> data_buffers(2); Datum input; input.value = ArrayData::Make(boolean(), 0 /* length */, std::move(data_buffers), diff --git a/cpp/src/arrow/compute/kernels/hash-test.cc b/cpp/src/arrow/compute/kernels/hash-test.cc index 84eec8ba1d7..d4ffa55cb0f 100644 --- a/cpp/src/arrow/compute/kernels/hash-test.cc +++ b/cpp/src/arrow/compute/kernels/hash-test.cc @@ -43,6 +43,8 @@ #include "arrow/compute/kernels/util-internal.h" #include "arrow/compute/test-util.h" +#include "arrow/ipc/json-simple.h" + using std::shared_ptr; using std::vector; @@ -61,9 +63,47 @@ void CheckUnique(FunctionContext* ctx, const shared_ptr& type, shared_ptr result; ASSERT_OK(Unique(ctx, input, &result)); + // TODO: We probably shouldn't rely on array ordering. ASSERT_ARRAYS_EQUAL(*expected, *result); } +template +void CheckValueCountsNull(FunctionContext* ctx, const shared_ptr& type) { + std::vector> data_buffers(2); + Datum input; + input.value = + ArrayData::Make(type, 0 /* length */, std::move(data_buffers), 0 /* null_count */); + + shared_ptr ex_values = ArrayFromJSON(type, "[]"); + shared_ptr ex_counts = ArrayFromJSON(int64(), "[]"); + + shared_ptr result; + ASSERT_OK(ValueCounts(ctx, input, &result)); + auto result_struct = std::dynamic_pointer_cast(result); + ASSERT_NE(result_struct->GetFieldByName(kValuesFieldName), nullptr); + // TODO: We probably shouldn't rely on value ordering. + ASSERT_ARRAYS_EQUAL(*ex_values, *result_struct->GetFieldByName(kValuesFieldName)); + ASSERT_ARRAYS_EQUAL(*ex_counts, *result_struct->GetFieldByName(kCountsFieldName)); +} + +template +void CheckValueCounts(FunctionContext* ctx, const shared_ptr& type, + const vector& in_values, const vector& in_is_valid, + const vector& out_values, const vector& out_is_valid, + const vector& out_counts) { + shared_ptr input = _MakeArray(type, in_values, in_is_valid); + shared_ptr ex_values = _MakeArray(type, out_values, out_is_valid); + shared_ptr ex_counts = + _MakeArray(int64(), out_counts, out_is_valid); + + shared_ptr result; + ASSERT_OK(ValueCounts(ctx, input, &result)); + auto result_struct = std::dynamic_pointer_cast(result); + // TODO: We probably shouldn't rely on value ordering. + ASSERT_ARRAYS_EQUAL(*ex_values, *result_struct->field(kValuesFieldIndex)); + ASSERT_ARRAYS_EQUAL(*ex_counts, *result_struct->field(kCountsFieldIndex)); +} + template void CheckDictEncode(FunctionContext* ctx, const shared_ptr& type, const vector& in_values, const vector& in_is_valid, @@ -104,6 +144,16 @@ TYPED_TEST(TestHashKernelPrimitive, Unique) { {3, 1}, {}); } +TYPED_TEST(TestHashKernelPrimitive, ValueCounts) { + using T = typename TypeParam::c_type; + auto type = TypeTraits::type_singleton(); + CheckValueCounts(&this->ctx_, type, {2, 1, 2, 1, 2, 3, 4}, + {true, false, true, true, true, true, false}, {2, 1, 3}, + {}, {3, 1, 1}); + CheckValueCounts(&this->ctx_, type, {}, {}, {}, {}, {}); + CheckValueCountsNull(&this->ctx_, type); +} + TYPED_TEST(TestHashKernelPrimitive, DictEncode) { using T = typename TypeParam::c_type; auto type = TypeTraits::type_singleton(); @@ -121,19 +171,21 @@ TYPED_TEST(TestHashKernelPrimitive, PrimitiveResizeTable) { vector values; vector uniques; vector indices; + vector counts; for (int64_t i = 0; i < kTotalValues * kRepeats; i++) { const auto val = static_cast(i % kTotalValues); values.push_back(val); if (i < kTotalValues) { uniques.push_back(val); + counts.push_back(kRepeats); } indices.push_back(static_cast(i % kTotalValues)); } auto type = TypeTraits::type_singleton(); CheckUnique(&this->ctx_, type, values, {}, uniques, {}); - + CheckValueCounts(&this->ctx_, type, values, {}, uniques, {}, counts); CheckDictEncode(&this->ctx_, type, values, {}, uniques, {}, indices); } @@ -149,6 +201,19 @@ TEST_F(TestHashKernel, UniqueTimeTimestamp) { {}); } +TEST_F(TestHashKernel, ValueCountsTimeTimestamp) { + CheckValueCounts(&this->ctx_, time32(TimeUnit::SECOND), + {2, 1, 2, 1}, {true, false, true, true}, {2, 1}, + {}, {2, 1}); + + CheckValueCounts(&this->ctx_, time64(TimeUnit::NANO), {2, 1, 2, 1}, + {true, false, true, true}, {2, 1}, {}, {2, 1}); + + CheckValueCounts(&this->ctx_, timestamp(TimeUnit::NANO), + {2, 1, 2, 1}, {true, false, true, true}, + {2, 1}, {}, {2, 1}); +} + TEST_F(TestHashKernel, UniqueBoolean) { CheckUnique(&this->ctx_, boolean(), {true, true, false, true}, {true, false, true, true}, {true, false}, {}); @@ -164,6 +229,23 @@ TEST_F(TestHashKernel, UniqueBoolean) { {false, true}, {}); } +TEST_F(TestHashKernel, ValueCountsBoolean) { + CheckValueCounts(&this->ctx_, boolean(), {true, true, false, true}, + {true, false, true, true}, {true, false}, {}, + {2, 1}); + + CheckValueCounts(&this->ctx_, boolean(), {false, true, false, true}, + {true, false, true, true}, {false, true}, {}, + {2, 1}); + + // No nulls + CheckValueCounts(&this->ctx_, boolean(), {true, true, false, true}, + {}, {true, false}, {}, {3, 1}); + + CheckValueCounts(&this->ctx_, boolean(), {false, true, false, true}, + {}, {false, true}, {}, {2, 2}); +} + TEST_F(TestHashKernel, DictEncodeBoolean) { CheckDictEncode( &this->ctx_, boolean(), {true, true, false, true, false}, @@ -192,6 +274,16 @@ TEST_F(TestHashKernel, UniqueBinary) { {true, false, true, true}, {"test", "test2"}, {}); } +TEST_F(TestHashKernel, ValueCountsBinary) { + CheckValueCounts( + &this->ctx_, binary(), {"test", "", "test2", "test"}, {true, false, true, true}, + {"test", "test2"}, {}, {2, 1}); + + CheckValueCounts( + &this->ctx_, utf8(), {"test", "", "test2", "test"}, {true, false, true, true}, + {"test", "test2"}, {}, {2, 1}); +} + TEST_F(TestHashKernel, DictEncodeBinary) { CheckDictEncode( &this->ctx_, binary(), {"test", "", "test2", "test", "baz"}, @@ -214,6 +306,7 @@ TEST_F(TestHashKernel, BinaryResizeTable) { vector values; vector uniques; vector indices; + vector counts; char buf[20] = "test"; for (int32_t i = 0; i < kTotalValues * kRepeats; i++) { @@ -224,15 +317,21 @@ TEST_F(TestHashKernel, BinaryResizeTable) { if (i < kTotalValues) { uniques.push_back(values.back()); + counts.push_back(kRepeats); } indices.push_back(index); } CheckUnique(&this->ctx_, binary(), values, {}, uniques, {}); + CheckValueCounts(&this->ctx_, binary(), values, {}, uniques, + {}, counts); + CheckDictEncode(&this->ctx_, binary(), values, {}, uniques, {}, indices); CheckUnique(&this->ctx_, utf8(), values, {}, uniques, {}); + CheckValueCounts(&this->ctx_, utf8(), values, {}, uniques, {}, + counts); CheckDictEncode(&this->ctx_, utf8(), values, {}, uniques, {}, indices); } @@ -291,6 +390,15 @@ TEST_F(TestHashKernel, UniqueDecimal) { {true, false, true, true}, expected, {}); } +TEST_F(TestHashKernel, ValueCountsDecimal) { + vector values{12, 12, 11, 12}; + vector expected{12, 11}; + + CheckValueCounts(&this->ctx_, decimal(2, 0), values, + {true, false, true, true}, expected, {}, + {2, 1}); +} + TEST_F(TestHashKernel, DictEncodeDecimal) { vector values{12, 12, 11, 12, 13}; vector expected{12, 11, 13}; @@ -300,6 +408,20 @@ TEST_F(TestHashKernel, DictEncodeDecimal) { {}, {0, 0, 1, 0, 2}); } +/* TODO(ARROW-4124): Determine if we wan to do something that is reproducable with floats. +TEST_F(TestHashKernel, ValueCountsFloat) { + + // No nulls + CheckValueCounts(&this->ctx_, float32(), {1.0f, 0.0f, -0.0f, +std::nan("1"), std::nan("2") }, + {}, {0.0f, 1.0f, std::nan("1")}, {}, {}); + + CheckValueCounts(&this->ctx_, float64(), {1.0f, 0.0f, -0.0f, +std::nan("1"), std::nan("2") }, + {}, {0.0f, 1.0f, std::nan("1")}, {}, {}); +} +*/ + TEST_F(TestHashKernel, ChunkedArrayInvoke) { vector values1 = {"foo", "bar", "foo"}; vector values2 = {"bar", "baz", "quuux", "foo"}; @@ -311,6 +433,9 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) { vector dict_values = {"foo", "bar", "baz", "quuux"}; auto ex_dict = _MakeArray(type, dict_values, {}); + vector counts = {3, 2, 1, 1}; + auto ex_counts = _MakeArray(int64(), counts, {}); + ArrayVector arrays = {a1, a2}; auto carr = std::make_shared(arrays); @@ -329,6 +454,14 @@ TEST_F(TestHashKernel, ChunkedArrayInvoke) { std::make_shared(dict_type, i2)}; auto dict_carr = std::make_shared(dict_arrays); + // Unique counts + shared_ptr counts_array; + ASSERT_OK(ValueCounts(&this->ctx_, carr, &counts_array)); + auto counts_struct = std::dynamic_pointer_cast(counts_array); + ASSERT_ARRAYS_EQUAL(*ex_dict, *counts_struct->field(0)); + ASSERT_ARRAYS_EQUAL(*ex_counts, *counts_struct->field(1)); + + // Dictionary encode Datum encoded_out; ASSERT_OK(DictionaryEncode(&this->ctx_, carr, &encoded_out)); ASSERT_EQ(Datum::CHUNKED_ARRAY, encoded_out.kind()); diff --git a/cpp/src/arrow/compute/kernels/hash.cc b/cpp/src/arrow/compute/kernels/hash.cc index f443282b1b8..2a3031fc7bc 100644 --- a/cpp/src/arrow/compute/kernels/hash.cc +++ b/cpp/src/arrow/compute/kernels/hash.cc @@ -87,16 +87,81 @@ class UniqueAction : public ActionBase { template void ObserveFound(Index index) {} + template + void ObserveNotFound(Index index, Status* err_status) { + ARROW_LOG(FATAL) << "ObserveNotFound with err_status should not be called"; + } + template void ObserveNotFound(Index index) {} Status Flush(Datum* out) { return Status::OK(); } std::shared_ptr out_type() const { return type_; } + + Status FlushFinal(Datum* out) { return Status::OK(); } }; // ---------------------------------------------------------------------- -// Dictionary encode implementation +// Count values implementation (see HashKernel for description of methods) + +class ValueCountsAction : ActionBase { + public: + using ActionBase::ActionBase; + + ValueCountsAction(const std::shared_ptr& type, MemoryPool* pool) + : ActionBase(type, pool), count_builder_(pool) {} + + Status Reserve(const int64_t length) { + // builder size is independent of input array size. + return Status::OK(); + } + + Status Reset() { + count_builder_.Reset(); + return Status::OK(); + } + + // Don't do anything on flush because we don't want to finalize the builder + // or incur the cost of memory copies. + Status Flush(Datum* out) { return Status::OK(); } + + std::shared_ptr out_type() const { return type_; } + + // Return the counts corresponding the MemoTable keys. + Status FlushFinal(Datum* out) { + std::shared_ptr result; + RETURN_NOT_OK(count_builder_.FinishInternal(&result)); + out->value = std::move(result); + return Status::OK(); + } + + void ObserveNull() {} + + template + void ObserveFound(Index slot) { + count_builder_[slot]++; + } + + template + void ObserveNotFound(Index slot) { + ARROW_LOG(FATAL) << "ObserveNotFound without err_status should not be called"; + } + + template + void ObserveNotFound(Index slot, Status* status) { + Status s = count_builder_.Append(1); + if (ARROW_PREDICT_FALSE(!s.ok())) { + *status = s; + } + } + + private: + Int64Builder count_builder_; +}; + +// ---------------------------------------------------------------------- +// Dictionary encode implementation (see HashKernel for description of methods) class DictEncodeAction : public ActionBase { public: @@ -119,7 +184,12 @@ class DictEncodeAction : public ActionBase { template void ObserveNotFound(Index index) { - return ObserveFound(index); + ObserveFound(index); + } + + template + void ObserveNotFound(Index index, Status* err_status) { + ARROW_LOG(FATAL) << "ObserveNotFound with err_status should not be called"; } Status Flush(Datum* out) { @@ -130,11 +200,33 @@ class DictEncodeAction : public ActionBase { } std::shared_ptr out_type() const { return int32(); } + Status FlushFinal(Datum* out) { return Status::OK(); } private: Int32Builder indices_builder_; }; +/// \brief Invoke hash table kernel on input array, returning any output +/// values. Implementations should be thread-safe +/// +/// This interface is implemented below using visitor pattern on "Action" +/// implementations. It is not consolidate to keep the contract clearer. +class HashKernel : public UnaryKernel { + public: + // Reset for another run. + virtual Status Reset() = 0; + // Prepare the Action for the given input (e.g. reserve appropriately sized + // data structures) and visit the given input with Action. + virtual Status Append(FunctionContext* ctx, const ArrayData& input) = 0; + // Flush out accumulated results from the last invocation of Call. + virtual Status Flush(Datum* out) = 0; + // Flush out accumulated results across all invocations of Call. The kernel + // should not be used until after Reset() is called. + virtual Status FlushFinal(Datum* out) = 0; + // Get the values (keys) acummulated in the dictionary so far. + virtual Status GetDictionary(std::shared_ptr* out) = 0; +}; + // ---------------------------------------------------------------------- // Base class for all hash kernel implementations @@ -161,7 +253,7 @@ class HashKernelImpl : public HashKernel { // Base class for all "regular" hash kernel implementations // (NullType has a separate implementation) -template +template class RegularHashKernelImpl : public HashKernelImpl { public: RegularHashKernelImpl(const std::shared_ptr& type, MemoryPool* pool) @@ -179,6 +271,8 @@ class RegularHashKernelImpl : public HashKernelImpl { Status Flush(Datum* out) override { return action_.Flush(out); } + Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } + Status GetDictionary(std::shared_ptr* out) override { return DictionaryTraits::GetDictionaryArrayData(pool_, type_, *memo_table_, 0 /* start_offset */, out); @@ -186,15 +280,26 @@ class RegularHashKernelImpl : public HashKernelImpl { Status VisitNull() { action_.ObserveNull(); - return Status::OK(); + return Status::Status::OK(); } Status VisitValue(const Scalar& value) { auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); }; - auto on_not_found = [this](int32_t memo_index) { - action_.ObserveNotFound(memo_index); - }; - memo_table_->GetOrInsert(value, on_found, on_not_found); + + if (with_error_status) { + Status status; + auto on_not_found = [this, &status](int32_t memo_index) { + action_.ObserveNotFound(memo_index, &status); + }; + memo_table_->GetOrInsert(value, on_found, on_not_found); + return status; + } else { + auto on_not_found = [this](int32_t memo_index) { + action_.ObserveNotFound(memo_index); + }; + + memo_table_->GetOrInsert(value, on_found, on_not_found); + } return Status::OK(); } @@ -229,6 +334,7 @@ class NullHashKernelImpl : public HashKernelImpl { } Status Flush(Datum* out) override { return action_.Flush(out); } + Status FlushFinal(Datum* out) override { return action_.FlushFinal(out); } Status GetDictionary(std::shared_ptr* out) override { // TODO(wesm): handle null being a valid dictionary value @@ -248,74 +354,80 @@ class NullHashKernelImpl : public HashKernelImpl { // ---------------------------------------------------------------------- // Kernel wrapper for generic hash table kernels -template +template struct HashKernelTraits {}; -template -struct HashKernelTraits> { +template +struct HashKernelTraits> { using HashKernelImpl = NullHashKernelImpl; }; -template -struct HashKernelTraits> { - using HashKernelImpl = RegularHashKernelImpl; +template +struct HashKernelTraits> { + using HashKernelImpl = + RegularHashKernelImpl; }; -template -struct HashKernelTraits> { - using HashKernelImpl = RegularHashKernelImpl; +template +struct HashKernelTraits> { + using HashKernelImpl = RegularHashKernelImpl; }; -template -struct HashKernelTraits> { - using HashKernelImpl = RegularHashKernelImpl; +template +struct HashKernelTraits> { + using HashKernelImpl = + RegularHashKernelImpl; }; -template -struct HashKernelTraits> { - using HashKernelImpl = RegularHashKernelImpl; +template +struct HashKernelTraits> { + using HashKernelImpl = + RegularHashKernelImpl; }; } // namespace +#define PROCESS_SUPPORTED_HASH_TYPES(PROCESS) \ + PROCESS(NullType) \ + PROCESS(BooleanType) \ + PROCESS(UInt8Type) \ + PROCESS(Int8Type) \ + PROCESS(UInt16Type) \ + PROCESS(Int16Type) \ + PROCESS(UInt32Type) \ + PROCESS(Int32Type) \ + PROCESS(UInt64Type) \ + PROCESS(Int64Type) \ + PROCESS(FloatType) \ + PROCESS(DoubleType) \ + PROCESS(Date32Type) \ + PROCESS(Date64Type) \ + PROCESS(Time32Type) \ + PROCESS(Time64Type) \ + PROCESS(TimestampType) \ + PROCESS(BinaryType) \ + PROCESS(StringType) \ + PROCESS(FixedSizeBinaryType) \ + PROCESS(Decimal128Type) + Status GetUniqueKernel(FunctionContext* ctx, const std::shared_ptr& type, std::unique_ptr* out) { std::unique_ptr kernel; - -#define UNIQUE_CASE(InType) \ - case InType::type_id: \ - kernel.reset(new typename HashKernelTraits::HashKernelImpl( \ - type, ctx->memory_pool())); \ - break - switch (type->id()) { - UNIQUE_CASE(NullType); - UNIQUE_CASE(BooleanType); - UNIQUE_CASE(UInt8Type); - UNIQUE_CASE(Int8Type); - UNIQUE_CASE(UInt16Type); - UNIQUE_CASE(Int16Type); - UNIQUE_CASE(UInt32Type); - UNIQUE_CASE(Int32Type); - UNIQUE_CASE(UInt64Type); - UNIQUE_CASE(Int64Type); - UNIQUE_CASE(FloatType); - UNIQUE_CASE(DoubleType); - UNIQUE_CASE(Date32Type); - UNIQUE_CASE(Date64Type); - UNIQUE_CASE(Time32Type); - UNIQUE_CASE(Time64Type); - UNIQUE_CASE(TimestampType); - UNIQUE_CASE(BinaryType); - UNIQUE_CASE(StringType); - UNIQUE_CASE(FixedSizeBinaryType); - UNIQUE_CASE(Decimal128Type); +#define PROCESS(InType) \ + case InType::type_id: \ + kernel.reset(new \ + typename HashKernelTraits::HashKernelImpl( \ + type, ctx->memory_pool())); \ + break; + + PROCESS_SUPPORTED_HASH_TYPES(PROCESS) +#undef PROCESS default: break; } -#undef UNIQUE_CASE - CHECK_IMPLEMENTED(kernel, "unique", type); RETURN_NOT_OK(kernel->Reset()); *out = std::move(kernel); @@ -327,35 +439,16 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx, std::unique_ptr* out) { std::unique_ptr kernel; -#define DICTIONARY_ENCODE_CASE(InType) \ - case InType::type_id: \ - kernel.reset(new \ - typename HashKernelTraits::HashKernelImpl( \ - type, ctx->memory_pool())); \ - break - switch (type->id()) { - DICTIONARY_ENCODE_CASE(NullType); - DICTIONARY_ENCODE_CASE(BooleanType); - DICTIONARY_ENCODE_CASE(UInt8Type); - DICTIONARY_ENCODE_CASE(Int8Type); - DICTIONARY_ENCODE_CASE(UInt16Type); - DICTIONARY_ENCODE_CASE(Int16Type); - DICTIONARY_ENCODE_CASE(UInt32Type); - DICTIONARY_ENCODE_CASE(Int32Type); - DICTIONARY_ENCODE_CASE(UInt64Type); - DICTIONARY_ENCODE_CASE(Int64Type); - DICTIONARY_ENCODE_CASE(FloatType); - DICTIONARY_ENCODE_CASE(DoubleType); - DICTIONARY_ENCODE_CASE(Date32Type); - DICTIONARY_ENCODE_CASE(Date64Type); - DICTIONARY_ENCODE_CASE(Time32Type); - DICTIONARY_ENCODE_CASE(Time64Type); - DICTIONARY_ENCODE_CASE(TimestampType); - DICTIONARY_ENCODE_CASE(BinaryType); - DICTIONARY_ENCODE_CASE(StringType); - DICTIONARY_ENCODE_CASE(FixedSizeBinaryType); - DICTIONARY_ENCODE_CASE(Decimal128Type); +#define PROCESS(InType) \ + case InType::type_id: \ + kernel.reset( \ + new typename HashKernelTraits::HashKernelImpl( \ + type, ctx->memory_pool())); \ + break; + + PROCESS_SUPPORTED_HASH_TYPES(PROCESS) +#undef PROCESS default: break; } @@ -368,6 +461,30 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx, return Status::OK(); } +Status GetValueCountsKernel(FunctionContext* ctx, const std::shared_ptr& type, + std::unique_ptr* out) { + std::unique_ptr kernel; + + switch (type->id()) { +#define PROCESS(InType) \ + case InType::type_id: \ + kernel.reset( \ + new typename HashKernelTraits::HashKernelImpl( \ + type, ctx->memory_pool())); \ + break; + + PROCESS_SUPPORTED_HASH_TYPES(PROCESS) +#undef PROCESS + default: + break; + } + + CHECK_IMPLEMENTED(kernel, "count-values", type); + RETURN_NOT_OK(kernel->Reset()); + *out = std::move(kernel); + return Status::OK(); +} + namespace { Status InvokeHash(FunctionContext* ctx, HashKernel* func, const Datum& value, @@ -415,5 +532,31 @@ Status DictionaryEncode(FunctionContext* ctx, const Datum& value, Datum* out) { return Status::OK(); } +const char kValuesFieldName[] = "values"; +const char kCountsFieldName[] = "counts"; +const int32_t kValuesFieldIndex = 0; +const int32_t kCountsFieldIndex = 1; +Status ValueCounts(FunctionContext* ctx, const Datum& value, + std::shared_ptr* counts) { + std::unique_ptr func; + RETURN_NOT_OK(GetValueCountsKernel(ctx, value.type(), &func)); + + // Calls return nothing for counts. + std::vector unused_output; + std::shared_ptr uniques; + RETURN_NOT_OK(InvokeHash(ctx, func.get(), value, &unused_output, &uniques)); + + Datum value_counts; + RETURN_NOT_OK(func->FlushFinal(&value_counts)); + + auto data_type = std::make_shared(std::vector>{ + std::make_shared(kValuesFieldName, uniques->type()), + std::make_shared(kCountsFieldName, int64())}); + *counts = std::make_shared( + data_type, uniques->length(), + std::vector>{uniques, MakeArray(value_counts.array())}); + return Status::OK(); +} +#undef PROCESS_SUPPORTED_HASH_TYPES } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash.h b/cpp/src/arrow/compute/kernels/hash.h index 6bbe3cfb447..edc7c493e46 100644 --- a/cpp/src/arrow/compute/kernels/hash.h +++ b/cpp/src/arrow/compute/kernels/hash.h @@ -34,29 +34,10 @@ namespace compute { class FunctionContext; -/// \brief Invoke hash table kernel on input array, returning any output -/// values. Implementations should be thread-safe -class ARROW_EXPORT HashKernel : public UnaryKernel { - public: - // XXX why are those methods exposed? - virtual Status Reset() = 0; - virtual Status Append(FunctionContext* ctx, const ArrayData& input) = 0; - virtual Status Flush(Datum* out) = 0; - virtual Status GetDictionary(std::shared_ptr* out) = 0; -}; - -/// \since 0.8.0 -/// \note API not yet finalized -ARROW_EXPORT -Status GetUniqueKernel(FunctionContext* ctx, const std::shared_ptr& type, - std::unique_ptr* kernel); - -ARROW_EXPORT -Status GetDictionaryEncodeKernel(FunctionContext* ctx, - const std::shared_ptr& type, - std::unique_ptr* kernel); - /// \brief Compute unique elements from an array-like object +/// +/// Note if a null occurs in the input it will NOT be included in the output. +/// /// \param[in] context the FunctionContext /// \param[in] datum array-like input /// \param[out] out result as Array @@ -66,6 +47,29 @@ Status GetDictionaryEncodeKernel(FunctionContext* ctx, ARROW_EXPORT Status Unique(FunctionContext* context, const Datum& datum, std::shared_ptr* out); +// Constants for accessing the output of ValueCounts +ARROW_EXPORT extern const char kValuesFieldName[]; +ARROW_EXPORT extern const char kCountsFieldName[]; +ARROW_EXPORT extern const int32_t kValuesFieldIndex; +ARROW_EXPORT extern const int32_t kCountsFieldIndex; +/// \brief Return counts of unique elements from an array-like object. +/// +/// Note that the counts do not include counts for nulls in the array. These can be +/// obtained separately from metadata. +/// +/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values +/// which can lead to unexpected results if the input Array has these values. +/// +/// \param[in] context the FunctionContext +/// \param[in] value array-like input +/// \param[out] counts An array of structs. +/// +/// \since 0.13.0 +/// \note API not yet finalized +ARROW_EXPORT +Status ValueCounts(FunctionContext* context, const Datum& value, + std::shared_ptr* counts); + /// \brief Dictionary-encode values in an array-like object /// \param[in] context the FunctionContext /// \param[in] data array-like input @@ -81,11 +85,6 @@ Status DictionaryEncode(FunctionContext* context, const Datum& data, Datum* out) // TODO(wesm): Define API for regularizing DictionaryArray objects with // different dictionaries -// class DictionaryEncoder { -// public: -// virtual Encode(const Datum& data, Datum* out) = 0; -// }; - // // ARROW_EXPORT // Status DictionaryEncode(FunctionContext* context, const Datum& data, @@ -100,11 +99,6 @@ Status DictionaryEncode(FunctionContext* context, const Datum& data, Datum* out) // Status IsIn(FunctionContext* context, const Datum& values, const Datum& member_set, // Datum* out); -// ARROW_EXPORT -// Status CountValues(FunctionContext* context, const Datum& values, -// std::shared_ptr* out_uniques, -// std::shared_ptr* out_counts); - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index 3dde0beeb19..044d4e96624 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -473,7 +473,10 @@ class SmallScalarMemoTable { // Copy values starting from index `start` into `out_data` void CopyValues(int32_t start, Scalar* out_data) const { - memcpy(out_data, &index_to_value_[start], size() - start); + DCHECK_GE(start, 0); + DCHECK_LE(static_cast(start), index_to_value_.size()); + int64_t offset = start * static_cast(sizeof(Scalar)); + memcpy(out_data, index_to_value_.data() + offset, (size() - start) * sizeof(Scalar)); } void CopyValues(Scalar* out_data) const { CopyValues(0, out_data); }