diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index e845b1d80cc..5a2cc127771 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -112,6 +112,9 @@ G_BEGIN_DECLS * #GArrowScalarAggregateOptions is a class to customize the scalar * aggregate functions such as `count` function and convenient * functions of them such as garrow_array_count(). + * + * #GArrowCountOptions is a class to customize the `count` function and + * garrow_array_count() family. * * #GArrowFilterOptions is a class to customize the `filter` function and * garrow_array_filter() family. @@ -767,6 +770,135 @@ garrow_scalar_aggregate_options_new(void) } +typedef struct GArrowCountOptionsPrivate_ { + arrow::compute::CountOptions options; +} GArrowCountOptionsPrivate; + +enum { + PROP_MODE = 1, +}; + +static arrow::compute::FunctionOptions * +garrow_count_options_get_raw_function_options(GArrowFunctionOptions *options) +{ + return garrow_count_options_get_raw(GARROW_COUNT_OPTIONS(options)); +} + +static void +garrow_count_options_function_options_interface_init( + GArrowFunctionOptionsInterface *iface) +{ + iface->get_raw = garrow_count_options_get_raw_function_options; +} + +G_DEFINE_TYPE_WITH_CODE(GArrowCountOptions, + garrow_count_options, + G_TYPE_OBJECT, + G_ADD_PRIVATE(GArrowCountOptions) + G_IMPLEMENT_INTERFACE( + GARROW_TYPE_FUNCTION_OPTIONS, + garrow_count_options_function_options_interface_init)) + +#define GARROW_COUNT_OPTIONS_GET_PRIVATE(object) \ + static_cast( \ + garrow_count_options_get_instance_private( \ + GARROW_COUNT_OPTIONS(object))) + +static void +garrow_count_options_finalize(GObject *object) +{ + auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object); + priv->options.~CountOptions(); + G_OBJECT_CLASS(garrow_count_options_parent_class)->finalize(object); +} + +static void +garrow_count_options_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = 
GARROW_COUNT_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_MODE: + priv->options.mode = + static_cast(g_value_get_enum(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_count_options_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_MODE: + g_value_set_enum(value, priv->options.mode); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +garrow_count_options_init(GArrowCountOptions *object) +{ + auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object); + new(&priv->options) arrow::compute::CountOptions; +} + +static void +garrow_count_options_class_init(GArrowCountOptionsClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->finalize = garrow_count_options_finalize; + gobject_class->set_property = garrow_count_options_set_property; + gobject_class->get_property = garrow_count_options_get_property; + + arrow::compute::CountOptions default_options; + + GParamSpec *spec; + /** + * GArrowCountOptions:mode: + * + * Which values to count (only valid, only null, or all). + * + * Since: 6.0.0 + */ + spec = g_param_spec_enum("mode", + "Count mode", + "Which values to count", + GARROW_TYPE_COUNT_MODE, + static_cast(default_options.mode), + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_MODE, spec); +} + +/** + * garrow_count_options_new: + * + * Returns: A newly created #GArrowCountOptions. 
+ * + * Since: 6.0.0 + */ +GArrowCountOptions * +garrow_count_options_new(void) +{ + auto count_options = g_object_new(GARROW_TYPE_COUNT_OPTIONS, NULL); + return GARROW_COUNT_OPTIONS(count_options); +} + + typedef struct GArrowFilterOptionsPrivate_ { arrow::compute::FilterOptions options; } GArrowFilterOptionsPrivate; @@ -1558,7 +1690,7 @@ garrow_array_dictionary_encode(GArrowArray *array, /** * garrow_array_count: * @array: A #GArrowArray. - * @options: (nullable): A #GArrowScalarAggregateOptions. + * @options: (nullable): A #GArrowCountOptions. * @error: (nullable): Return location for a #GError or %NULL. * * Returns: The number of target values on success. If an error is occurred, @@ -1568,14 +1700,14 @@ garrow_array_dictionary_encode(GArrowArray *array, */ gint64 garrow_array_count(GArrowArray *array, - GArrowScalarAggregateOptions *options, + GArrowCountOptions *options, GError **error) { auto arrow_array = garrow_array_get_raw(array); auto arrow_array_raw = arrow_array.get(); arrow::Result arrow_counted_datum; if (options) { - auto arrow_options = garrow_scalar_aggregate_options_get_raw(options); + auto arrow_options = garrow_count_options_get_raw(options); arrow_counted_datum = arrow::compute::Count(*arrow_array_raw, *arrow_options); } else { @@ -2694,6 +2826,13 @@ garrow_scalar_aggregate_options_get_raw( return &(priv->options); } +arrow::compute::CountOptions * +garrow_count_options_get_raw(GArrowCountOptions *count_options) +{ + auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(count_options); + return &(priv->options); +} + arrow::compute::FilterOptions * garrow_filter_options_get_raw(GArrowFilterOptions *filter_options) { diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index 1163983644c..fe2e60181ce 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -98,6 +98,38 @@ GARROW_AVAILABLE_IN_5_0 GArrowScalarAggregateOptions * garrow_scalar_aggregate_options_new(void); +/** + * GArrowCountMode: + * 
@GARROW_COUNT_MODE_ONLY_VALID: + * Only non-null values will be counted. + * @GARROW_COUNT_MODE_ONLY_NULL: + * Only null values will be counted. + * @GARROW_COUNT_MODE_ALL: + * All will be counted. + * + * They correspond to the values of `arrow::compute::CountOptions::CountMode`. + */ +typedef enum { + GARROW_COUNT_MODE_ONLY_VALID, + GARROW_COUNT_MODE_ONLY_NULL, + GARROW_COUNT_MODE_ALL, +} GArrowCountMode; + +#define GARROW_TYPE_COUNT_OPTIONS (garrow_count_options_get_type()) +G_DECLARE_DERIVABLE_TYPE(GArrowCountOptions, + garrow_count_options, + GARROW, + COUNT_OPTIONS, + GObject) +struct _GArrowCountOptionsClass +{ + GObjectClass parent_class; +}; + +GARROW_AVAILABLE_IN_6_0 +GArrowCountOptions * +garrow_count_options_new(void); + /** * GArrowFilterNullSelectionBehavior: @@ -242,7 +274,7 @@ GArrowDictionaryArray *garrow_array_dictionary_encode(GArrowArray *array, GError **error); GARROW_AVAILABLE_IN_0_13 gint64 garrow_array_count(GArrowArray *array, - GArrowScalarAggregateOptions *options, + GArrowCountOptions *options, GError **error); GARROW_AVAILABLE_IN_0_13 GArrowStructArray *garrow_array_count_values(GArrowArray *array, diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp index 8089a1d3364..ba56bf0753a 100644 --- a/c_glib/arrow-glib/compute.hpp +++ b/c_glib/arrow-glib/compute.hpp @@ -53,6 +53,9 @@ arrow::compute::ScalarAggregateOptions * garrow_scalar_aggregate_options_get_raw( GArrowScalarAggregateOptions *scalar_aggregate_options); +arrow::compute::CountOptions * +garrow_count_options_get_raw(GArrowCountOptions *count_options); + arrow::compute::FilterOptions * garrow_filter_options_get_raw(GArrowFilterOptions *filter_options); diff --git a/c_glib/test/test-count.rb b/c_glib/test/test-count.rb index 39b6f06c4e6..6e94219143b 100644 --- a/c_glib/test/test-count.rb +++ b/c_glib/test/test-count.rb @@ -19,15 +19,25 @@ class TestCount < Test::Unit::TestCase include Helper::Buildable include Helper::Omittable - sub_test_case("skip_nulls") do 
+ sub_test_case("mode") do def test_default assert_equal(2, build_int32_array([1, nil, 3]).count) + + options = Arrow::CountOptions.new + options.mode = Arrow::CountMode::ONLY_VALID + assert_equal(2, build_int32_array([1, nil, 3]).count(options)) end - def test_false - options = Arrow::ScalarAggregateOptions.new - options.skip_nulls = false + def test_nulls + options = Arrow::CountOptions.new + options.mode = Arrow::CountMode::ONLY_NULL assert_equal(1, build_int32_array([1, nil, 3]).count(options)) end + + def test_all + options = Arrow::CountOptions.new + options.mode = Arrow::CountMode::ALL + assert_equal(3, build_int32_array([1, nil, 3]).count(options)) + end end end diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc index 2261333a880..af7aec865fc 100644 --- a/cpp/src/arrow/compute/api_aggregate.cc +++ b/cpp/src/arrow/compute/api_aggregate.cc @@ -27,6 +27,24 @@ namespace arrow { namespace internal { +template <> +struct EnumTraits + : BasicEnumTraits { + static std::string name() { return "CountOptions::CountMode"; } + static std::string value_name(compute::CountOptions::CountMode value) { + switch (value) { + case compute::CountOptions::ONLY_VALID: + return "NON_NULL"; + case compute::CountOptions::ONLY_NULL: + return "NULLS"; + case compute::CountOptions::ALL: + return "ALL"; + } + return ""; + } +}; + template <> struct EnumTraits : BasicEnumTraits( DataMember("skip_nulls", &ScalarAggregateOptions::skip_nulls), DataMember("min_count", &ScalarAggregateOptions::min_count)); +static auto kCountOptionsType = + GetFunctionOptionsType(DataMember("mode", &CountOptions::mode)); static auto kModeOptionsType = GetFunctionOptionsType(DataMember("n", &ModeOptions::n)); static auto kVarianceOptionsType = @@ -86,6 +106,10 @@ ScalarAggregateOptions::ScalarAggregateOptions(bool skip_nulls, uint32_t min_cou min_count(min_count) {} constexpr char ScalarAggregateOptions::kTypeName[]; +CountOptions::CountOptions(CountMode mode) + : 
FunctionOptions(internal::kCountOptionsType), mode(mode) {} +constexpr char CountOptions::kTypeName[]; + ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {} constexpr char ModeOptions::kTypeName[]; @@ -124,6 +148,7 @@ constexpr char IndexOptions::kTypeName[]; namespace internal { void RegisterAggregateOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kScalarAggregateOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kCountOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kModeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kVarianceOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kQuantileOptionsType)); @@ -135,8 +160,7 @@ void RegisterAggregateOptions(FunctionRegistry* registry) { // ---------------------------------------------------------------------- // Scalar aggregates -Result Count(const Datum& value, const ScalarAggregateOptions& options, - ExecContext* ctx) { +Result Count(const Datum& value, const CountOptions& options, ExecContext* ctx) { return CallFunction("count", {value}, &options, ctx); } diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h index d73be1af97a..880424e97f8 100644 --- a/cpp/src/arrow/compute/api_aggregate.h +++ b/cpp/src/arrow/compute/api_aggregate.h @@ -53,6 +53,26 @@ class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { uint32_t min_count; }; +/// \brief Control count aggregate kernel behavior. +/// +/// By default, only non-null values are counted. +class ARROW_EXPORT CountOptions : public FunctionOptions { + public: + enum CountMode { + /// Count only non-null values. + ONLY_VALID = 0, + /// Count only null values. + ONLY_NULL, + /// Count both non-null and null values. 
+ ALL, + }; + explicit CountOptions(CountMode mode = CountMode::ONLY_VALID); + constexpr static char const kTypeName[] = "CountOptions"; + static CountOptions Defaults() { return CountOptions{}; } + + CountMode mode; +}; + /// \brief Control Mode kernel behavior /// /// Returns top-n common values and counts. @@ -139,9 +159,9 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions { /// @} -/// \brief Count non-null (or null) values in an array. +/// \brief Count values in an array. /// -/// \param[in] options counting options, see ScalarAggregateOptions for more information +/// \param[in] options counting options, see CountOptions for more information /// \param[in] datum to count /// \param[in] ctx the function execution context, optional /// \return out resulting datum @@ -149,10 +169,9 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions { /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result Count( - const Datum& datum, - const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), - ExecContext* ctx = NULLPTR); +Result Count(const Datum& datum, + const CountOptions& options = CountOptions::Defaults(), + ExecContext* ctx = NULLPTR); /// \brief Compute the mean of a numeric array. 
/// diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index 7aca10ef0fa..16d3affe720 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -40,6 +40,8 @@ TEST(FunctionOptions, Equality) { std::vector> options; options.emplace_back(new ScalarAggregateOptions()); options.emplace_back(new ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1)); + options.emplace_back(new CountOptions()); + options.emplace_back(new CountOptions(CountOptions::ALL)); options.emplace_back(new ModeOptions()); options.emplace_back(new ModeOptions(/*n=*/2)); options.emplace_back(new VarianceOptions()); diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index bc92c203687..d65e0ef506e 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -57,10 +57,12 @@ namespace aggregate { // Count implementation struct CountImpl : public ScalarAggregator { - explicit CountImpl(ScalarAggregateOptions options) : options(std::move(options)) {} + explicit CountImpl(CountOptions options) : options(std::move(options)) {} Status Consume(KernelContext*, const ExecBatch& batch) override { - if (batch[0].is_array()) { + if (options.mode == CountOptions::ALL) { + this->non_nulls += batch.length; + } else if (batch[0].is_array()) { const ArrayData& input = *batch[0].array(); const int64_t nulls = input.GetNullCount(); this->nulls += nulls; @@ -82,15 +84,23 @@ struct CountImpl : public ScalarAggregator { Status Finalize(KernelContext* ctx, Datum* out) override { const auto& state = checked_cast(*ctx->state()); - if (state.options.skip_nulls) { - *out = Datum(state.non_nulls); - } else { - *out = Datum(state.nulls); + switch (state.options.mode) { + case CountOptions::ONLY_VALID: + case CountOptions::ALL: + // ALL is equivalent since we don't count the null/non-null + // separately to avoid 
potentially computing null count + *out = Datum(state.non_nulls); + break; + case CountOptions::ONLY_NULL: + *out = Datum(state.nulls); + break; + default: + DCHECK(false) << "unreachable"; } return Status::OK(); } - ScalarAggregateOptions options; + CountOptions options; int64_t non_nulls = 0; int64_t nulls = 0; }; @@ -98,7 +108,7 @@ struct CountImpl : public ScalarAggregator { Result> CountInit(KernelContext*, const KernelInitArgs& args) { return ::arrow::internal::make_unique( - static_cast(*args.options)); + static_cast(*args.options)); } // ---------------------------------------------------------------------- @@ -560,9 +570,9 @@ namespace { const FunctionDoc count_doc{"Count the number of null / non-null values", ("By default, only non-null values are counted.\n" - "This can be changed through ScalarAggregateOptions."), + "This can be changed through CountOptions."), {"array"}, - "ScalarAggregateOptions"}; + "CountOptions"}; const FunctionDoc sum_doc{ "Compute the sum of a numeric array", @@ -623,17 +633,15 @@ const FunctionDoc index_doc{"Find the index of the first occurrence of a given v void RegisterScalarAggregateBasic(FunctionRegistry* registry) { static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults(); + static auto default_count_options = CountOptions::Defaults(); auto func = std::make_shared( - "count", Arity::Unary(), &count_doc, &default_scalar_aggregate_options); + "count", Arity::Unary(), &count_doc, &default_count_options); - // Takes any array input, outputs int64 scalar - InputType any_array(ValueDescr::ARRAY); - AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())), + // Takes any input, outputs int64 scalar + InputType any_input; + AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), aggregate::CountInit, func.get()); - AddAggKernel( - KernelSignature::Make({InputType(ValueDescr::SCALAR)}, ValueDescr::Scalar(int64())), - aggregate::CountInit, func.get()); 
DCHECK_OK(registry->AddFunction(std::move(func))); func = std::make_shared("sum", Arity::Unary(), &sum_doc, diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index db96bc3cad8..8aee3049970 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -577,14 +577,18 @@ static CountPair NaiveCount(const Array& array) { } void ValidateCount(const Array& input, CountPair expected) { - ScalarAggregateOptions all = ScalarAggregateOptions(/*skip_nulls=*/true); - ScalarAggregateOptions nulls = ScalarAggregateOptions(/*skip_nulls=*/false); + CountOptions non_null; + CountOptions nulls(CountOptions::ONLY_NULL); + CountOptions all(CountOptions::ALL); - ASSERT_OK_AND_ASSIGN(Datum result, Count(input, all)); + ASSERT_OK_AND_ASSIGN(Datum result, Count(input, non_null)); AssertDatumsEqual(result, Datum(expected.first)); ASSERT_OK_AND_ASSIGN(result, Count(input, nulls)); AssertDatumsEqual(result, Datum(expected.second)); + + ASSERT_OK_AND_ASSIGN(result, Count(input, all)); + AssertDatumsEqual(result, Datum(expected.first + expected.second)); } template @@ -608,11 +612,15 @@ TYPED_TEST(TestCountKernel, SimpleCount) { auto ty = TypeTraits::type_singleton(); EXPECT_THAT(Count(MakeNullScalar(ty)), ResultWith(Datum(int64_t(0)))); - EXPECT_THAT(Count(MakeNullScalar(ty), ScalarAggregateOptions(/*skip_nulls=*/false)), + EXPECT_THAT(Count(MakeNullScalar(ty), CountOptions(CountOptions::ONLY_NULL)), ResultWith(Datum(int64_t(1)))); EXPECT_THAT(Count(*MakeScalar(ty, 1)), ResultWith(Datum(int64_t(1)))); - EXPECT_THAT(Count(*MakeScalar(ty, 1), ScalarAggregateOptions(/*skip_nulls=*/false)), + EXPECT_THAT(Count(*MakeScalar(ty, 1), CountOptions(CountOptions::ONLY_NULL)), ResultWith(Datum(int64_t(0)))); + + CountOptions all(CountOptions::ALL); + EXPECT_THAT(Count(MakeNullScalar(ty), all), ResultWith(Datum(int64_t(1)))); + EXPECT_THAT(Count(*MakeScalar(ty, 1), all), 
ResultWith(Datum(int64_t(1)))); } template diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index b3d602a89ac..5774f93750c 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -829,7 +829,7 @@ Status AddHashAggKernels( struct GroupedCountImpl : public GroupedAggregator { Status Init(ExecContext* ctx, const FunctionOptions* options) override { - options_ = checked_cast(*options); + options_ = checked_cast(*options); counts_ = BufferBuilder(ctx->memory_pool()); return Status::OK(); } @@ -859,25 +859,36 @@ struct GroupedCountImpl : public GroupedAggregator { const auto& input = batch[0].array(); - if (options_.skip_nulls) { - auto g_begin = - reinterpret_cast(batch[1].array()->buffers[1]->data()); - - arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset, - input->length, - [&](int64_t offset, int64_t length) { - auto g = g_begin + offset; - for (int64_t i = 0; i < length; ++i, ++g) { - counts[*g] += 1; - } - }); - } else if (input->MayHaveNulls()) { - auto g = batch[1].array()->GetValues(1); - - auto end = input->offset + input->length; - for (int64_t i = input->offset; i < end; ++i, ++g) { - counts[*g] += !BitUtil::GetBit(input->buffers[0]->data(), i); + auto g_begin = batch[1].array()->GetValues(1); + switch (options_.mode) { + case CountOptions::ONLY_VALID: { + arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset, + input->length, + [&](int64_t offset, int64_t length) { + auto g = g_begin + offset; + for (int64_t i = 0; i < length; ++i, ++g) { + counts[*g] += 1; + } + }); + break; + } + case CountOptions::ONLY_NULL: { + if (input->MayHaveNulls()) { + auto end = input->offset + input->length; + for (int64_t i = input->offset; i < end; ++i, ++g_begin) { + counts[*g_begin] += !BitUtil::GetBit(input->buffers[0]->data(), i); + } + } + break; } + case CountOptions::ALL: { + for (int64_t i = 0; i < batch.length; 
++i, ++g_begin) { + counts[*g_begin] += 1; + } + break; + } + default: + DCHECK(false) << "unreachable"; } return Status::OK(); } @@ -890,7 +901,7 @@ struct GroupedCountImpl : public GroupedAggregator { std::shared_ptr out_type() const override { return int64(); } int64_t num_groups_ = 0; - ScalarAggregateOptions options_; + CountOptions options_; BufferBuilder counts_; }; @@ -2082,7 +2093,7 @@ const FunctionDoc hash_count_doc{"Count the number of null / non-null values", ("By default, non-null values are counted.\n" "This can be changed through ScalarAggregateOptions."), {"array", "group_id_array"}, - "ScalarAggregateOptions"}; + "CountOptions"}; const FunctionDoc hash_sum_doc{"Sum values of a numeric array", ("Null values are ignored."), @@ -2140,9 +2151,9 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true", void RegisterHashAggregateBasic(FunctionRegistry* registry) { static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults(); { + static auto default_count_options = CountOptions::Defaults(); auto func = std::make_shared( - "hash_count", Arity::Binary(), &hash_count_doc, - &default_scalar_aggregate_options); + "hash_count", Arity::Binary(), &hash_count_doc, &default_count_options); DCHECK_OK(func->AddKernel( MakeKernel(ValueDescr::ARRAY, HashAggregateInit))); diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index cb9d05e0d35..ca9a8241305 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -1077,7 +1077,9 @@ TEST(GroupBy, CountAndSum) { [null, 3] ])"); - ScalarAggregateOptions count_options; + CountOptions count_options; + CountOptions count_nulls(CountOptions::ONLY_NULL); + CountOptions count_all(CountOptions::ALL); ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3); ASSERT_OK_AND_ASSIGN( Datum aggregated_and_grouped, @@ -1087,6 +1089,8 @@ 
TEST(GroupBy, CountAndSum) { batch->GetColumnByName("argument"), batch->GetColumnByName("argument"), batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), + batch->GetColumnByName("argument"), batch->GetColumnByName("key"), }, { @@ -1094,6 +1098,8 @@ TEST(GroupBy, CountAndSum) { }, { {"hash_count", &count_options}, + {"hash_count", &count_nulls}, + {"hash_count", &count_all}, {"hash_sum", nullptr}, {"hash_sum", &min_count}, {"hash_sum", nullptr}, @@ -1101,6 +1107,8 @@ TEST(GroupBy, CountAndSum) { AssertDatumsEqual( ArrayFromJSON(struct_({ + field("hash_count", int64()), + field("hash_count", int64()), field("hash_count", int64()), // NB: summing a float32 array results in float64 sums field("hash_sum", float64()), @@ -1109,10 +1117,10 @@ TEST(GroupBy, CountAndSum) { field("key_0", int64()), }), R"([ - [2, 4.25, null, 3, 1], - [3, -0.125, -0.125, 6, 2], - [0, null, null, 6, 3], - [2, 4.75, null, null, null] + [2, 1, 3, 4.25, null, 3, 1], + [3, 0, 3, -0.125, -0.125, 6, 2], + [0, 2, 2, null, null, 6, 3], + [2, 0, 2, 4.75, null, null, null] ])"), aggregated_and_grouped, /*verbose=*/true); @@ -1248,13 +1256,14 @@ TEST(GroupBy, ConcreteCaseWithValidateGroupBy) { ])"); ScalarAggregateOptions keepna{false, 1}; - ScalarAggregateOptions skipna{true, 1}; + CountOptions nulls(CountOptions::ONLY_NULL); + CountOptions non_null(CountOptions::ONLY_VALID); using internal::Aggregate; for (auto agg : { Aggregate{"hash_sum", nullptr}, - Aggregate{"hash_count", &skipna}, - Aggregate{"hash_count", &keepna}, + Aggregate{"hash_count", &non_null}, + Aggregate{"hash_count", &nulls}, Aggregate{"hash_min_max", nullptr}, Aggregate{"hash_min_max", &keepna}, }) { @@ -1273,7 +1282,7 @@ TEST(GroupBy, CountNull) { [3.0, "gama"] ])"); - ScalarAggregateOptions keepna{false}, skipna{true}; + CountOptions keepna{CountOptions::ONLY_NULL}, skipna{CountOptions::ONLY_VALID}; using internal::Aggregate; for (auto agg : { @@ -1323,7 +1332,6 @@ TEST(GroupBy, WithChunkedArray) { {"argument": 
0.75, "key": null}, {"argument": null, "key": 3} ])"}); - ScalarAggregateOptions count_options; ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, internal::GroupBy( { @@ -1335,7 +1343,7 @@ TEST(GroupBy, WithChunkedArray) { table->GetColumnByName("key"), }, { - {"hash_count", &count_options}, + {"hash_count", nullptr}, {"hash_sum", nullptr}, {"hash_min_max", nullptr}, })); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 58321457ab0..5d4817bf079 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -190,23 +190,23 @@ Aggregations +---------------+-------+-------------+----------------+----------------------------------+-------+ | any | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) | +---------------+-------+-------------+----------------+----------------------------------+-------+ -| count | Unary | Any | Scalar Int64 | :struct:`ScalarAggregateOptions` | | +| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) | +---------------+-------+-------------+----------------+----------------------------------+-------+ | index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | | +---------------+-------+-------------+----------------+----------------------------------+-------+ | mean | Unary | Numeric | Scalar Float64 | :struct:`ScalarAggregateOptions` | | +---------------+-------+-------------+----------------+----------------------------------+-------+ -| min_max | Unary | Numeric | Scalar Struct | :struct:`ScalarAggregateOptions` | \(2) | +| min_max | Unary | Numeric | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) | +---------------+-------+-------------+----------------+----------------------------------+-------+ -| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(3) | +| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) | +---------------+-------+-------------+----------------+----------------------------------+-------+ -| product | 
Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(4) | +| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) | +---------------+-------+-------------+----------------+----------------------------------+-------+ -| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(5) | +| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(6) | +---------------+-------+-------------+----------------+----------------------------------+-------+ | stddev | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | | +---------------+-------+-------------+----------------+----------------------------------+-------+ -| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(4) | +| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) | +---------------+-------+-------------+----------------+----------------------------------+-------+ | tdigest | Unary | Numeric | Scalar Float64 | :struct:`TDigestOptions` | | +---------------+-------+-------------+----------------+----------------------------------+-------+ @@ -218,18 +218,21 @@ Notes: * \(1) If null values are taken into account by setting ScalarAggregateOptions parameter skip_nulls = false then `Kleene logic`_ logic is applied. -* \(2) Output is a ``{"min": input type, "max": input type}`` Struct. +* \(2) CountMode controls whether only non-null values are counted (the + default), only null values are counted, or all values are counted. -* \(3) Output is an array of ``{"mode": input type, "count": Int64}`` Struct. +* \(3) Output is a ``{"min": input type, "max": input type}`` Struct. + +* \(4) Output is an array of ``{"mode": input type, "count": Int64}`` Struct. It contains the *N* most common elements in the input, in descending order, where *N* is given in :member:`ModeOptions::n`. If two values have the same count, the smallest one comes first. 
Note that the output can have less than *N* elements if the input has less than *N* distinct values. -* \(4) Output is Int64, UInt64 or Float64, depending on the input type. +* \(5) Output is Int64, UInt64 or Float64, depending on the input type. -* \(5) Output is Float64 or input type, depending on QuantileOptions. +* \(6) Output is Float64 or input type, depending on QuantileOptions. Element-wise ("scalar") functions --------------------------------- diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 46cfdc4e2ef..2d82928f377 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -887,6 +887,23 @@ class ScalarAggregateOptions(_ScalarAggregateOptions): self._set_options(skip_nulls, min_count) +cdef class _CountOptions(FunctionOptions): + def _set_options(self, mode): + if mode == 'only_valid': + self.wrapped.reset(new CCountOptions(CCountMode_ONLY_VALID)) + elif mode == 'only_null': + self.wrapped.reset(new CCountOptions(CCountMode_ONLY_NULL)) + elif mode == 'all': + self.wrapped.reset(new CCountOptions(CCountMode_ALL)) + else: + raise ValueError(f'"{mode}" is not a valid mode') + + +class CountOptions(_CountOptions): + def __init__(self, mode='only_valid'): + self._set_options(mode) + + cdef class _IndexOptions(FunctionOptions): def _set_options(self, Scalar scalar): self.wrapped.reset(new CIndexOptions(pyarrow_unwrap_scalar(scalar))) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 85f637fce5a..10880c2974c 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -31,6 +31,7 @@ # Option classes ArraySortOptions, CastOptions, + CountOptions, DictionaryEncodeOptions, ElementWiseAggregateOptions, ExtractRegexOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 34d81bce04b..7dcde652a95 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1959,6 +1959,16 @@ cdef extern from 
"arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool skip_nulls uint32_t min_count + enum CCountMode "arrow::compute::CountOptions::CountMode": + CCountMode_ONLY_VALID "arrow::compute::CountOptions::ONLY_VALID" + CCountMode_ONLY_NULL "arrow::compute::CountOptions::ONLY_NULL" + CCountMode_ALL "arrow::compute::CountOptions::ALL" + + cdef cppclass CCountOptions \ + "arrow::compute::CountOptions"(CFunctionOptions): + CCountOptions(CCountMode mode) + CCountMode mode + cdef cppclass CModeOptions \ "arrow::compute::ModeOptions"(CFunctionOptions): CModeOptions(int64_t n) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index b0baa76e50a..60a2f60f942 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -1439,11 +1439,9 @@ def test_extract_datetime_components(): def test_count(): arr = pa.array([1, 2, 3, None, None]) assert pc.count(arr).as_py() == 3 - assert pc.count(arr, skip_nulls=True).as_py() == 3 - assert pc.count(arr, skip_nulls=False).as_py() == 2 - - with pytest.raises(TypeError, match="an integer is required"): - pc.count(arr, min_count='zzz') + assert pc.count(arr, mode='only_valid').as_py() == 3 + assert pc.count(arr, mode='only_null').as_py() == 2 + assert pc.count(arr, mode='all').as_py() == 5 def test_index(): diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 30821137383..0695e2525f7 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -172,7 +172,7 @@ std::shared_ptr make_compute_options( } if (func_name == "min_max" || func_name == "sum" || func_name == "mean" || - func_name == "count" || func_name == "any" || func_name == "all") { + func_name == "any" || func_name == "all") { using Options = arrow::compute::ScalarAggregateOptions; auto out = std::make_shared(Options::Defaults()); out->min_count = cpp11::as_cpp(options["na.min_count"]); @@ -180,6 +180,14 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "count") { + using Options = 
arrow::compute::CountOptions; + auto out = std::make_shared(Options::Defaults()); + out->mode = + cpp11::as_cpp(options["na.rm"]) ? Options::ONLY_VALID : Options::ONLY_NULL; + return out; + } + if (func_name == "min_element_wise" || func_name == "max_element_wise") { using Options = arrow::compute::ElementWiseAggregateOptions; bool skip_nulls = true;