diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index fd554ba3d83..37cce33d99e 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -349,6 +349,37 @@ std::shared_ptr RunEndEncoded( std::move(value_type_matcher)); } +class NotMatcher : public TypeMatcher { + public: + explicit NotMatcher(std::shared_ptr base_matcher) + : base_matcher{std::move(base_matcher)} {} + + ~NotMatcher() override = default; + + bool Matches(const DataType& type) const override { + return !base_matcher->Matches(type); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + const auto* casted = dynamic_cast(&other); + return casted != nullptr && base_matcher->Equals(*casted->base_matcher); + } + + std::string ToString() const override { + return "not(" + base_matcher->ToString() + ")"; + }; + + private: + std::shared_ptr base_matcher; +}; + +std::shared_ptr Not(std::shared_ptr base_matcher) { + return std::make_shared(std::move(base_matcher)); +} + } // namespace match // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 1adb3e96c97..9a48e52fab2 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -172,6 +172,11 @@ ARROW_EXPORT std::shared_ptr RunEndEncoded( std::shared_ptr run_end_type_matcher, std::shared_ptr value_type_matcher); +/// \brief Match types that the base_matcher doesn't match +/// +/// @param[in] base_matcher the matcher whose result is negated +ARROW_EXPORT std::shared_ptr Not(std::shared_ptr base_matcher); + } // namespace match /// \brief An object used for type-checking arguments to be passed to a kernel diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 1fbcd6a2490..c4370f9385a 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ 
b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -482,6 +482,9 @@ void AddFirstOrLastAggKernel(ScalarAggregateFunction* func, // ---------------------------------------------------------------------- // MinMax implementation +using arrow::compute::match::Not; +using arrow::compute::match::SameTypeId; + Result> MinMaxInit(KernelContext* ctx, const KernelInitArgs& args) { ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, @@ -494,9 +497,10 @@ Result> MinMaxInit(KernelContext* ctx, // For "min" and "max" functions: override finalize and return the actual value template -void AddMinOrMaxAggKernel(ScalarAggregateFunction* func, - ScalarAggregateFunction* min_max_func) { - auto sig = KernelSignature::Make({InputType::Any()}, FirstType); +void AddMinOrMaxAggKernels(ScalarAggregateFunction* func, + ScalarAggregateFunction* min_max_func) { + std::shared_ptr sig = + KernelSignature::Make({InputType(Not(SameTypeId(Type::DICTIONARY)))}, FirstType); auto init = [min_max_func]( KernelContext* ctx, const KernelInitArgs& args) -> Result> { @@ -516,6 +520,9 @@ void AddMinOrMaxAggKernel(ScalarAggregateFunction* func, // Note SIMD level is always NONE, but the convenience kernel will // dispatch to an appropriate implementation + AddAggKernel(std::move(sig), init, finalize, func); + + sig = KernelSignature::Make({InputType(Type::DICTIONARY)}, DictionaryValueType); AddAggKernel(std::move(sig), std::move(init), std::move(finalize), func); } @@ -873,6 +880,15 @@ Result MinMaxType(KernelContext*, const std::vector& typ return struct_({field("min", ty), field("max", ty)}); } +Result DictionaryMinMaxType(KernelContext*, + const std::vector& types) { + // T -> struct{min: T.value_type, max: T.value_type} + auto ty = types.front(); + const DictionaryType& ty_dict = checked_cast(*ty); + return struct_( + {field("min", ty_dict.value_type()), field("max", ty_dict.value_type())}); +} + } // namespace Result FirstLastType(KernelContext*, const std::vector& types) { @@ -896,7 +912,12 @@ void AddFirstLastKernels(KernelInit init, void 
AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id, ScalarAggregateFunction* func, SimdLevel::type simd_level) { - auto sig = KernelSignature::Make({InputType(get_id.id)}, MinMaxType); + std::shared_ptr sig; + if (get_id.id == Type::DICTIONARY) { + sig = KernelSignature::Make({InputType(get_id.id)}, DictionaryMinMaxType); + } else { + sig = KernelSignature::Make({InputType(get_id.id)}, MinMaxType); + } AddAggKernel(std::move(sig), init, func, simd_level); } @@ -1118,6 +1139,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { AddMinMaxKernels(MinMaxInit, NumericTypes(), func.get()); AddMinMaxKernels(MinMaxInit, TemporalTypes(), func.get()); AddMinMaxKernels(MinMaxInit, BaseBinaryTypes(), func.get()); + AddMinMaxKernel(MinMaxInit, Type::DICTIONARY, func.get()); AddMinMaxKernel(MinMaxInit, Type::FIXED_SIZE_BINARY, func.get()); AddMinMaxKernel(MinMaxInit, Type::INTERVAL_MONTHS, func.get()); AddMinMaxKernel(MinMaxInit, Type::DECIMAL128, func.get()); @@ -1140,12 +1162,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { // Add min/max as convenience functions func = std::make_shared("min", Arity::Unary(), min_or_max_doc, &default_scalar_aggregate_options); - AddMinOrMaxAggKernel(func.get(), min_max_func); + AddMinOrMaxAggKernels(func.get(), min_max_func); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("max", Arity::Unary(), min_or_max_doc, &default_scalar_aggregate_options); - AddMinOrMaxAggKernel(func.get(), min_max_func); + AddMinOrMaxAggKernels(func.get(), min_max_func); DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("product", Arity::Unary(), product_doc, diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index f08e7aaa538..00bf666c31b 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -971,6 +971,11 
@@ struct FirstLastInitState { } }; +template +std::unique_ptr DictionaryMinMaxImplFunc(const DataType& in_type, + std::shared_ptr out_type, + ScalarAggregateOptions options); + template struct MinMaxInitState { std::unique_ptr state; @@ -1002,6 +1007,11 @@ struct MinMaxInitState { return Status::OK(); } + Status Visit(const DictionaryType&) { + state = DictionaryMinMaxImplFunc(in_type, out_type, options); + return Status::OK(); + } + template enable_if_physical_integer Visit(const Type&) { using PhysicalType = typename Type::PhysicalType; @@ -1033,4 +1043,112 @@ struct MinMaxInitState { } }; +template +struct DictionaryMinMaxImpl : public ScalarAggregator { + using ThisType = DictionaryMinMaxImpl; + + DictionaryMinMaxImpl(const DataType& in_type, std::shared_ptr out_type, + ScalarAggregateOptions options) + : options(std::move(options)), + out_type(std::move(out_type)), + has_nulls(false), + count(0), + value_type(checked_cast(in_type).value_type()), + value_state(nullptr) { + this->options.min_count = std::max(1, this->options.min_count); + } + + Status Consume(KernelContext* ctx, const ExecSpan& batch) override { + if (batch[0].is_scalar()) { + return Status::NotImplemented("No min/max implemented for DictionaryScalar"); + } + RETURN_NOT_OK(this->InitValueState()); + + // The minmax is computed from dictionary values, in case some values are not + // referenced by indices, a compaction needs to be executed here. 
+ DictionaryArray dict_arr(batch[0].array.ToArrayData()); + ARROW_ASSIGN_OR_RAISE(auto compacted_arr, dict_arr.Compact(ctx->memory_pool())); + const DictionaryArray& compacted_dict_arr = + checked_cast(*compacted_arr); + const int64_t null_count = compacted_dict_arr.ComputeLogicalNullCount(); + const int64_t non_null_count = compacted_dict_arr.length() - null_count; + + this->has_nulls |= null_count > 0; + this->count += non_null_count; + if ((this->has_nulls && !options.skip_nulls) || (non_null_count == 0)) { + return Status::OK(); + } + + const ArrayData& dict_data = + checked_cast(*compacted_dict_arr.dictionary()->data()); + RETURN_NOT_OK( + checked_cast(this->value_state.get()) + ->Consume(nullptr, ExecSpan(std::vector({ExecValue(dict_data)}), 1))); + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + auto&& other = checked_cast(src); + this->has_nulls |= other.has_nulls; + this->count += other.count; + if ((this->has_nulls && !options.skip_nulls) || other.value_state == nullptr) { + return Status::OK(); + } + + if (this->value_state == nullptr) { + this->value_state.reset(other.value_state.release()); + } else { + RETURN_NOT_OK(checked_cast(this->value_state.get()) + ->MergeFrom(nullptr, std::move(*other.value_state))); + } + return Status::OK(); + } + + Status Finalize(KernelContext*, Datum* out) override { + if ((this->has_nulls && !options.skip_nulls) || (this->count < options.min_count) || + this->value_state.get() == nullptr) { + const auto& struct_type = checked_cast(*out_type); + const auto& child_type = struct_type.field(0)->type(); + + std::shared_ptr null_scalar = MakeNullScalar(child_type); + std::vector> values = {null_scalar, null_scalar}; + out->value = std::make_shared(std::move(values), this->out_type); + } else { + Datum temp; + RETURN_NOT_OK(checked_cast(this->value_state.get()) + ->Finalize(nullptr, &temp)); + const auto& result = temp.scalar_as(); + DCHECK(result.is_valid); + out->value = 
result.GetSharedPtr(); + } + return Status::OK(); + } + + ScalarAggregateOptions options; + std::shared_ptr out_type; + bool has_nulls; + int64_t count; + std::shared_ptr value_type; + std::unique_ptr value_state; + + private: + inline Status InitValueState() { + if (this->value_state == nullptr) { + const DataType& value_type_ref = checked_cast(*this->value_type); + ScalarAggregateOptions options = ScalarAggregateOptions::Defaults(); + MinMaxInitState valueMinMaxInitState(nullptr, value_type_ref, + out_type, options); + ARROW_ASSIGN_OR_RAISE(this->value_state, valueMinMaxInitState.Create()); + } + return Status::OK(); + } +}; + +template +std::unique_ptr DictionaryMinMaxImplFunc(const DataType& in_type, + std::shared_ptr out_type, + ScalarAggregateOptions options) { + return std::make_unique>(in_type, out_type, options); +} + } // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 65439af2748..d855b3c2604 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -2175,6 +2175,257 @@ TEST(TestFixedSizeBinaryMinMaxKernel, Basics) { EXPECT_THAT(MinMax(ScalarFromJSON(ty, R"("aa")"), options), ResultWith(null)); } +void CheckDictionaryMinMax(const std::shared_ptr& value_type, Datum datum, + const std::string& expected_min, + const std::string& expected_max, + const ScalarAggregateOptions& options) { + std::shared_ptr result_type = + struct_({field("min", value_type), field("max", value_type)}); + EXPECT_THAT( + MinMax(datum, options), + ResultWith(ScalarFromJSON(result_type, "{\"min\": " + expected_min + + ", \"max\": " + expected_max + "}"))); +} + +void CheckDictionaryMinMax(const std::shared_ptr& index_type, + const std::shared_ptr& value_type, + const std::string& input_index_json, + const std::string& input_dictionary_json, + const std::string& expected_min, + const std::string& expected_max, + const 
ScalarAggregateOptions& options) { + auto dict_type = dictionary(index_type, value_type); + auto arr = DictArrayFromJSON(dict_type, input_index_json, input_dictionary_json); + CheckDictionaryMinMax(value_type, arr, expected_min, expected_max, options); +} + +TEST(TestDictionaryMinMaxKernel, IntegersValue) { + ScalarAggregateOptions options; + std::shared_ptr dict_ty; + + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + + for (const auto& value_ty : + {int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()}) { + dict_ty = dictionary(index_type, value_ty); + + auto chunk1 = DictArrayFromJSON(dict_ty, R"([null, 0])", R"([5])"); + auto chunk2 = DictArrayFromJSON(dict_ty, R"([0, 1, 1])", R"([3, 1])"); + ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2})); + + options = ScalarAggregateOptions(/*skip_nulls=*/true); + CheckDictionaryMinMax(value_ty, chunked, "1", "5", options); // chunked + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // normal + R"([5, 1, 2, 3, 4])", "1", "5", options); + CheckDictionaryMinMax(index_type, value_ty, // null in indices + R"([0, 1, 2, 3, null])", R"([5, 9, 2, 3])", "2", "9", + options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // null in values + R"([null, null, 2, 3, 4])", "2", "4", options); + CheckDictionaryMinMax(index_type, value_ty, // null in both indices and values + R"([0, 1, 2, 3, null])", R"([null, null, 2, 3])", "2", "3", + options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced values + R"([0, 1, 2, 3, 5])", R"([5, 1, 2, 3, 4, 100, 101, 102])", + "1", "100", options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced values + R"([0, 1, 2, 3, 5, 0, 3, 1])", R"([5, 1, 2, 3, 4, 100, 101])", + "1", "100", options); + + options = ScalarAggregateOptions(/*skip_nulls=*/true, /*min_count=*/4); + CheckDictionaryMinMax(value_ty, chunked, "1", 
"5", options); // chunked + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // normal + R"([5, 1, 2, 3, 4])", "1", "5", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2])", // too short + R"([5, 1, 2])", "null", "null", options); + CheckDictionaryMinMax(index_type, value_ty, // null in indices + R"([0, 1, 2, 3, null])", R"([5, 9, 2, 3])", "2", "9", + options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // null in values + R"([null, null, 2, 3, 4])", "null", "null", options); + CheckDictionaryMinMax(index_type, value_ty, // null in both indices and values + R"([0, 1, 2, 3, null])", R"([null, null, 2, 3])", "null", + "null", options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced values + R"([0, 1, 2, 3, 5])", R"([5, 1, 2, 3, 4, 100, 101, 102])", + "1", "100", options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced nulls + R"([0, 1, 2, 3, 5])", + R"([5, 1, 2, 3, null, 100, null, null])", "1", "100", + options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced values + R"([0, 1, 2, 3, 5, 0, 3, 1])", R"([5, 1, 2, 3, 4, 100])", "1", + "100", options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced nulls + R"([0, 1, 2, 3, 5, 0, 3, 1])", + R"([null, null, 2, null, 4, 100])", "null", "null", options); + + options = ScalarAggregateOptions(/*skip_nulls=*/false); + CheckDictionaryMinMax(value_ty, chunked, "null", "null", options); // chunked + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // normal + R"([5, 1, 2, 3, 4])", "1", "5", options); + CheckDictionaryMinMax(index_type, value_ty, // null in indices + R"([0, 1, 2, 3, null])", R"([5, 9, 2, 3])", "null", "null", + options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // null in values + R"([null, null, 2, 3, 4])", "null", "null", options); + CheckDictionaryMinMax(index_type, value_ty, // null in both indices and values + R"([0, 1, 2, 3, null])", 
R"([null, null, 2, 3])", "null", + "null", options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced values + R"([0, 1, 2, 3, 5])", R"([5, 1, 2, 3, 4, 100, 101, 102])", + "1", "100", options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced nulls + R"([0, 1, 2, 3, 5])", + R"([5, 1, 2, 3, null, 100, null, null])", "1", "100", + options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced values + R"([0, 1, 2, 3, 5, 0, 3, 1])", R"([5, 1, 2, 3, 4, 100])", "1", + "100", options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced nulls + R"([0, 1, 2, 3, 5, 0, 3, 1])", + R"([null, null, 2, null, 4, 100])", "null", "null", options); + + options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/4); + CheckDictionaryMinMax(value_ty, chunked, "null", "null", options); // chunked + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // normal + R"([5, 1, 2, 3, 4])", "1", "5", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2])", // too short + R"([5, 1, 2])", "null", "null", options); + CheckDictionaryMinMax(index_type, value_ty, // null in indices + R"([0, 1, 2, 3, null])", R"([5, 9, 2, 3])", "null", "null", + options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", // null in values + R"([null, null, 2, 3, 4])", "null", "null", options); + CheckDictionaryMinMax(index_type, value_ty, // null in both indices and values + R"([0, 1, 2, 3, null])", R"([null, null, 2, 3])", "null", + "null", options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced values + R"([0, 1, 2, 3, 5])", R"([5, 1, 2, 3, 4, 100, 101, 102])", + "1", "100", options); + CheckDictionaryMinMax(index_type, value_ty, // unreferenced nulls + R"([0, 1, 2, 3, 5])", + R"([5, 1, 2, 3, null, 100, null, null])", "1", "100", + options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced values + R"([0, 1, 2, 3, 5, 0, 3, 1])", R"([5, 1, 2, 3, 4, 100])", "1", + 
"100", options); + CheckDictionaryMinMax(index_type, value_ty, // multiply referenced nulls + R"([0, 1, 2, 3, 5, 0, 3, 1])", + R"([null, null, 2, null, 4, 100])", "null", "null", options); + } + } +} + +TEST(TestDictionaryMinMaxKernel, DecimalsValue) { + ScalarAggregateOptions options; + std::shared_ptr dict_ty; + std::shared_ptr ty; + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + + for (const auto& value_ty : {decimal128(5, 2), decimal256(5, 2)}) { + dict_ty = dictionary(index_type, value_ty); + ty = struct_({field("min", value_ty), field("max", value_ty)}); + + auto chunk1 = DictArrayFromJSON(dict_ty, R"([null, 0])", R"(["5.10"])"); + auto chunk2 = DictArrayFromJSON(dict_ty, R"([0, 1, 1])", R"(["3.10", "-1.23"])"); + ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2})); + + options = ScalarAggregateOptions(/*skip_nulls=*/true); + CheckDictionaryMinMax(value_ty, chunked, R"("-1.23")", R"("5.10")", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 1, 0])", + R"(["5.10", "-1.23"])", R"("-1.23")", R"("5.10")", options); + CheckDictionaryMinMax(index_type, value_ty, R"([3, 1, 1, 4, 0, 2, null])", + R"(["5.10", "-1.23", "2.00", "3.45", "4.56"])", R"("-1.23")", + R"("5.10")", options); + } + } +} + +TEST(TestDictionaryMinMaxKernel, BooleansValue) { + ScalarAggregateOptions options; + std::shared_ptr value_ty = boolean(); + std::shared_ptr dict_ty; + + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + dict_ty = dictionary(index_type, value_ty); + + auto chunk1 = DictArrayFromJSON(dict_ty, R"([null, 0])", R"([true])"); + auto chunk2 = DictArrayFromJSON(dict_ty, R"([0, 1, 1])", R"([false, true])"); + ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2})); + + CheckDictionaryMinMax(value_ty, chunked, "false", "true", options); + CheckDictionaryMinMax(index_type, 
value_ty, R"([0, 0, 1])", R"([false, true])", + "false", "true", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 0, 0])", R"([false])", "false", + "false", options); + } +} + +TEST(TestDictionaryMinMaxKernel, FloatsValue) { + ScalarAggregateOptions options; + std::shared_ptr dict_ty; + + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + + for (const auto& value_ty : {float32(), float64()}) { + dict_ty = dictionary(index_type, value_ty); + + auto chunk1 = DictArrayFromJSON(dict_ty, R"([null, 0])", R"([5])"); + auto chunk2 = DictArrayFromJSON(dict_ty, R"([0, 1, 1])", R"([-Inf, 1])"); + ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2})); + + CheckDictionaryMinMax(value_ty, chunked, "-Inf", "5", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", + R"([5, 1, 2, 3, 4])", "1", "5", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", + R"([5, -Inf, 2, 3, 4])", "-Inf", "5", options); + } + } +} + +TEST(TestDictionaryMinMaxKernel, BinarysValue) { + ScalarAggregateOptions options; + std::shared_ptr dict_ty; + + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + + for (const auto& value_ty : {fixed_size_binary(2), binary(), large_binary()}) { + dict_ty = dictionary(index_type, value_ty); + auto chunk1 = DictArrayFromJSON(dict_ty, R"([null, 0])", R"(["hz"])"); + auto chunk2 = DictArrayFromJSON(dict_ty, R"([0, 1, 1])", R"(["aa", "bb"])"); + ASSERT_OK_AND_ASSIGN(auto chunked, ChunkedArray::Make({chunk1, chunk2})); + + CheckDictionaryMinMax(value_ty, chunked, R"("aa")", R"("hz")", options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", + R"(["hz", "bb", "bf", "cc", "fa"])", R"("bb")", R"("hz")", + options); + CheckDictionaryMinMax(index_type, value_ty, R"([0, 1, 2, 3, 4])", + R"(["hz", "aa", "bf", "cc", "fa"])", 
R"("aa")", R"("hz")", + options); + } + } +} + +TEST(TestDictionaryMinMaxKernel, NullValue) { + ScalarAggregateOptions options; + std::shared_ptr value_ty = null(); + + for (const auto& index_type : all_dictionary_index_types()) { + ARROW_SCOPED_TRACE("index_type = ", index_type->ToString()); + + CheckDictionaryMinMax(index_type, value_ty, R"([null, null])", R"([])", "null", + "null", options); + CheckDictionaryMinMax(index_type, value_ty, R"([])", R"([])", "null", "null", + options); + } +} + template struct MinMaxResult { using T = typename ArrowType::c_type; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 00a833742f9..593024c27c0 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -61,6 +61,14 @@ Result ListValuesType(KernelContext*, const std::vector& return list_type.value_type().get(); } +Result DictionaryValueType(KernelContext*, + const std::vector& types) { + // T -> T.value_type + auto ty = types.front(); + const DictionaryType& ty_dict = checked_cast(*ty); + return ty_dict.value_type(); +} + void EnsureDictionaryDecoded(std::vector* types) { EnsureDictionaryDecoded(types->data(), types->size()); } diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 72b29057b82..ec005234ec7 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -461,6 +461,8 @@ static void VisitTwoArrayValuesInline(const ArraySpan& arr0, const ArraySpan& ar Result FirstType(KernelContext*, const std::vector& types); Result LastType(KernelContext*, const std::vector& types); Result ListValuesType(KernelContext*, const std::vector& types); +Result DictionaryValueType(KernelContext*, + const std::vector& types); // ---------------------------------------------------------------------- // Helpers for iterating over common 
DataType instances for adding kernels to