From 6cd375c69df677ac9e431ee2882c11a5e52929dc Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 6 Aug 2021 14:16:41 -0400
Subject: [PATCH 1/2] ARROW-13574: [C++] Add count_all and hash_count_all
kernels
---
c_glib/arrow-glib/compute.cpp | 145 +++++++++++++++++-
c_glib/arrow-glib/compute.h | 34 +++-
c_glib/arrow-glib/compute.hpp | 3 +
c_glib/test/test-count.rb | 18 ++-
cpp/src/arrow/compute/api_aggregate.cc | 28 +++-
cpp/src/arrow/compute/api_aggregate.h | 31 +++-
cpp/src/arrow/compute/function_test.cc | 2 +
.../arrow/compute/kernels/aggregate_basic.cc | 36 +++--
.../arrow/compute/kernels/aggregate_test.cc | 18 ++-
.../arrow/compute/kernels/hash_aggregate.cc | 57 ++++---
.../compute/kernels/hash_aggregate_test.cc | 30 ++--
docs/source/cpp/compute.rst | 23 +--
python/pyarrow/_compute.pyx | 17 ++
python/pyarrow/compute.py | 1 +
python/pyarrow/includes/libarrow.pxd | 10 ++
python/pyarrow/tests/test_compute.py | 8 +-
r/src/compute.cpp | 10 +-
17 files changed, 385 insertions(+), 86 deletions(-)
diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp
index e845b1d80cc..5a2cc127771 100644
--- a/c_glib/arrow-glib/compute.cpp
+++ b/c_glib/arrow-glib/compute.cpp
@@ -112,6 +112,9 @@ G_BEGIN_DECLS
* #GArrowScalarAggregateOptions is a class to customize the scalar
* aggregate functions such as `count` function and convenient
* functions of them such as garrow_array_count().
+
+ * #GArrowCountOptions is a class to customize the `count` function and
+ * garrow_array_count() family.
*
* #GArrowFilterOptions is a class to customize the `filter` function and
* garrow_array_filter() family.
@@ -767,6 +770,135 @@ garrow_scalar_aggregate_options_new(void)
}
+typedef struct GArrowCountOptionsPrivate_ {
+ arrow::compute::CountOptions options;
+} GArrowCountOptionsPrivate;
+
+enum {
+ PROP_MODE = 1,
+};
+
+static arrow::compute::FunctionOptions *
+garrow_count_options_get_raw_function_options(GArrowFunctionOptions *options)
+{
+ return garrow_count_options_get_raw(GARROW_COUNT_OPTIONS(options));
+}
+
+static void
+garrow_count_options_function_options_interface_init(
+ GArrowFunctionOptionsInterface *iface)
+{
+ iface->get_raw = garrow_count_options_get_raw_function_options;
+}
+
+G_DEFINE_TYPE_WITH_CODE(GArrowCountOptions,
+ garrow_count_options,
+ G_TYPE_OBJECT,
+ G_ADD_PRIVATE(GArrowCountOptions)
+ G_IMPLEMENT_INTERFACE(
+ GARROW_TYPE_FUNCTION_OPTIONS,
+ garrow_count_options_function_options_interface_init))
+
+#define GARROW_COUNT_OPTIONS_GET_PRIVATE(object) \
+ static_cast( \
+ garrow_count_options_get_instance_private( \
+ GARROW_COUNT_OPTIONS(object)))
+
+static void
+garrow_count_options_finalize(GObject *object)
+{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object);
+ priv->options.~CountOptions();
+ G_OBJECT_CLASS(garrow_count_options_parent_class)->finalize(object);
+}
+
+static void
+garrow_count_options_set_property(GObject *object,
+ guint prop_id,
+ const GValue *value,
+ GParamSpec *pspec)
+{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object);
+
+ switch (prop_id) {
+ case PROP_MODE:
+ priv->options.mode =
+ static_cast(g_value_get_enum(value));
+ break;
+ default:
+ G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
+ break;
+ }
+}
+
+static void
+garrow_count_options_get_property(GObject *object,
+ guint prop_id,
+ GValue *value,
+ GParamSpec *pspec)
+{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object);
+
+ switch (prop_id) {
+ case PROP_MODE:
+ g_value_set_enum(value, priv->options.mode);
+ break;
+ default:
+ G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec);
+ break;
+ }
+}
+
+static void
+garrow_count_options_init(GArrowCountOptions *object)
+{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(object);
+ new(&priv->options) arrow::compute::CountOptions;
+}
+
+static void
+garrow_count_options_class_init(GArrowCountOptionsClass *klass)
+{
+ auto gobject_class = G_OBJECT_CLASS(klass);
+
+ gobject_class->finalize = garrow_count_options_finalize;
+ gobject_class->set_property = garrow_count_options_set_property;
+ gobject_class->get_property = garrow_count_options_get_property;
+
+ arrow::compute::CountOptions default_options;
+
+ GParamSpec *spec;
+ /**
+ * GArrowCountOptions:null-selection-behavior:
+ *
+ * How to handle counted values.
+ *
+ * Since: 0.17.0
+ */
+ spec = g_param_spec_enum("mode",
+ "Count mode",
+ "Which values to count",
+ GARROW_TYPE_COUNT_MODE,
+ static_cast(default_options.mode),
+ static_cast(G_PARAM_READWRITE));
+ g_object_class_install_property(gobject_class, PROP_MODE, spec);
+}
+
+/**
+ * garrow_count_options_new:
+ *
+ * Returns: A newly created #GArrowCountOptions.
+ *
+ * Since: 6.0.0
+ */
+GArrowCountOptions *
+garrow_count_options_new(void)
+{
+ auto count_options = g_object_new(GARROW_TYPE_COUNT_OPTIONS, NULL);
+ return GARROW_COUNT_OPTIONS(count_options);
+}
+
+
typedef struct GArrowFilterOptionsPrivate_ {
arrow::compute::FilterOptions options;
} GArrowFilterOptionsPrivate;
@@ -1558,7 +1690,7 @@ garrow_array_dictionary_encode(GArrowArray *array,
/**
* garrow_array_count:
* @array: A #GArrowArray.
- * @options: (nullable): A #GArrowScalarAggregateOptions.
+ * @options: (nullable): A #GArrowCountOptions.
* @error: (nullable): Return location for a #GError or %NULL.
*
* Returns: The number of target values on success. If an error is occurred,
@@ -1568,14 +1700,14 @@ garrow_array_dictionary_encode(GArrowArray *array,
*/
gint64
garrow_array_count(GArrowArray *array,
- GArrowScalarAggregateOptions *options,
+ GArrowCountOptions *options,
GError **error)
{
auto arrow_array = garrow_array_get_raw(array);
auto arrow_array_raw = arrow_array.get();
arrow::Result arrow_counted_datum;
if (options) {
- auto arrow_options = garrow_scalar_aggregate_options_get_raw(options);
+ auto arrow_options = garrow_count_options_get_raw(options);
arrow_counted_datum =
arrow::compute::Count(*arrow_array_raw, *arrow_options);
} else {
@@ -2694,6 +2826,13 @@ garrow_scalar_aggregate_options_get_raw(
return &(priv->options);
}
+arrow::compute::CountOptions *
+garrow_count_options_get_raw(GArrowCountOptions *count_options)
+{
+ auto priv = GARROW_COUNT_OPTIONS_GET_PRIVATE(count_options);
+ return &(priv->options);
+}
+
arrow::compute::FilterOptions *
garrow_filter_options_get_raw(GArrowFilterOptions *filter_options)
{
diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h
index 1163983644c..c74d28d17e1 100644
--- a/c_glib/arrow-glib/compute.h
+++ b/c_glib/arrow-glib/compute.h
@@ -98,6 +98,38 @@ GARROW_AVAILABLE_IN_5_0
GArrowScalarAggregateOptions *
garrow_scalar_aggregate_options_new(void);
+/**
+ * GArrowCountMode:
+ * @GARROW_COUNT_MODE_NON_NULL:
+ * Only non-null values will be counted.
+ * @GARROW_COUNT_MODE_NULLS:
+ * Only null values will be counted.
+ * @GARROW_COUNT_MODE_ALL:
+ * All will be counted.
+ *
+ * They correspond to the values of `arrow::compute::CountOptions::CountMode`.
+ */
+typedef enum {
+ GARROW_COUNT_MODE_NON_NULL,
+ GARROW_COUNT_MODE_NULLS,
+ GARROW_COUNT_MODE_ALL,
+} GArrowCountMode;
+
+#define GARROW_TYPE_COUNT_OPTIONS (garrow_count_options_get_type())
+G_DECLARE_DERIVABLE_TYPE(GArrowCountOptions,
+ garrow_count_options,
+ GARROW,
+ COUNT_OPTIONS,
+ GObject)
+struct _GArrowCountOptionsClass
+{
+ GObjectClass parent_class;
+};
+
+GARROW_AVAILABLE_IN_6_0
+GArrowCountOptions *
+garrow_count_options_new(void);
+
/**
* GArrowFilterNullSelectionBehavior:
@@ -242,7 +274,7 @@ GArrowDictionaryArray *garrow_array_dictionary_encode(GArrowArray *array,
GError **error);
GARROW_AVAILABLE_IN_0_13
gint64 garrow_array_count(GArrowArray *array,
- GArrowScalarAggregateOptions *options,
+ GArrowCountOptions *options,
GError **error);
GARROW_AVAILABLE_IN_0_13
GArrowStructArray *garrow_array_count_values(GArrowArray *array,
diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp
index 8089a1d3364..ba56bf0753a 100644
--- a/c_glib/arrow-glib/compute.hpp
+++ b/c_glib/arrow-glib/compute.hpp
@@ -53,6 +53,9 @@ arrow::compute::ScalarAggregateOptions *
garrow_scalar_aggregate_options_get_raw(
GArrowScalarAggregateOptions *scalar_aggregate_options);
+arrow::compute::CountOptions *
+garrow_count_options_get_raw(GArrowCountOptions *count_options);
+
arrow::compute::FilterOptions *
garrow_filter_options_get_raw(GArrowFilterOptions *filter_options);
diff --git a/c_glib/test/test-count.rb b/c_glib/test/test-count.rb
index 39b6f06c4e6..55804a1165b 100644
--- a/c_glib/test/test-count.rb
+++ b/c_glib/test/test-count.rb
@@ -19,15 +19,25 @@ class TestCount < Test::Unit::TestCase
include Helper::Buildable
include Helper::Omittable
- sub_test_case("skip_nulls") do
+ sub_test_case("mode") do
def test_default
assert_equal(2, build_int32_array([1, nil, 3]).count)
+
+ options = Arrow::CountOptions.new
+ options.mode = Arrow::CountMode::NON_NULL
+ assert_equal(2, build_int32_array([1, nil, 3]).count(options))
end
- def test_false
- options = Arrow::ScalarAggregateOptions.new
- options.skip_nulls = false
+ def test_nulls
+ options = Arrow::CountOptions.new
+ options.mode = Arrow::CountMode::NULLS
assert_equal(1, build_int32_array([1, nil, 3]).count(options))
end
+
+ def test_all
+ options = Arrow::CountOptions.new
+ options.mode = Arrow::CountMode::ALL
+ assert_equal(3, build_int32_array([1, nil, 3]).count(options))
+ end
end
end
diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc
index 2261333a880..a165f966e5a 100644
--- a/cpp/src/arrow/compute/api_aggregate.cc
+++ b/cpp/src/arrow/compute/api_aggregate.cc
@@ -27,6 +27,24 @@
namespace arrow {
namespace internal {
+template <>
+struct EnumTraits
+ : BasicEnumTraits {
+ static std::string name() { return "CountOptions::CountMode"; }
+ static std::string value_name(compute::CountOptions::CountMode value) {
+ switch (value) {
+ case compute::CountOptions::NON_NULL:
+ return "NON_NULL";
+ case compute::CountOptions::NULLS:
+ return "NULLS";
+ case compute::CountOptions::ALL:
+ return "ALL";
+ }
+ return "";
+ }
+};
+
template <>
struct EnumTraits
: BasicEnumTraits(
DataMember("skip_nulls", &ScalarAggregateOptions::skip_nulls),
DataMember("min_count", &ScalarAggregateOptions::min_count));
+static auto kCountOptionsType =
+ GetFunctionOptionsType(DataMember("mode", &CountOptions::mode));
static auto kModeOptionsType =
GetFunctionOptionsType(DataMember("n", &ModeOptions::n));
static auto kVarianceOptionsType =
@@ -86,6 +106,10 @@ ScalarAggregateOptions::ScalarAggregateOptions(bool skip_nulls, uint32_t min_cou
min_count(min_count) {}
constexpr char ScalarAggregateOptions::kTypeName[];
+CountOptions::CountOptions(CountMode mode)
+ : FunctionOptions(internal::kCountOptionsType), mode(mode) {}
+constexpr char CountOptions::kTypeName[];
+
ModeOptions::ModeOptions(int64_t n) : FunctionOptions(internal::kModeOptionsType), n(n) {}
constexpr char ModeOptions::kTypeName[];
@@ -124,6 +148,7 @@ constexpr char IndexOptions::kTypeName[];
namespace internal {
void RegisterAggregateOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kScalarAggregateOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kCountOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kModeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kVarianceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kQuantileOptionsType));
@@ -135,8 +160,7 @@ void RegisterAggregateOptions(FunctionRegistry* registry) {
// ----------------------------------------------------------------------
// Scalar aggregates
-Result Count(const Datum& value, const ScalarAggregateOptions& options,
- ExecContext* ctx) {
+Result Count(const Datum& value, const CountOptions& options, ExecContext* ctx) {
return CallFunction("count", {value}, &options, ctx);
}
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index d73be1af97a..667b057a1f1 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -53,6 +53,26 @@ class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions {
uint32_t min_count;
};
+/// \brief Control count aggregate kernel behavior.
+///
+/// By default, only non-null values are counted.
+class ARROW_EXPORT CountOptions : public FunctionOptions {
+ public:
+ enum CountMode {
+ /// Count only non-null values.
+ NON_NULL = 0,
+ /// Count only null values.
+ NULLS,
+ /// Count both non-null and null values.
+ ALL,
+ };
+ explicit CountOptions(CountMode mode = CountMode::NON_NULL);
+ constexpr static char const kTypeName[] = "CountOptions";
+ static CountOptions Defaults() { return CountOptions{}; }
+
+ CountMode mode;
+};
+
/// \brief Control Mode kernel behavior
///
/// Returns top-n common values and counts.
@@ -139,9 +159,9 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions {
/// @}
-/// \brief Count non-null (or null) values in an array.
+/// \brief Count values in an array.
///
-/// \param[in] options counting options, see ScalarAggregateOptions for more information
+/// \param[in] options counting options, see CountOptions for more information
/// \param[in] datum to count
/// \param[in] ctx the function execution context, optional
/// \return out resulting datum
@@ -149,10 +169,9 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions {
/// \since 1.0.0
/// \note API not yet finalized
ARROW_EXPORT
-Result Count(
- const Datum& datum,
- const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(),
- ExecContext* ctx = NULLPTR);
+Result Count(const Datum& datum,
+ const CountOptions& options = CountOptions::Defaults(),
+ ExecContext* ctx = NULLPTR);
/// \brief Compute the mean of a numeric array.
///
diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc
index 7aca10ef0fa..16d3affe720 100644
--- a/cpp/src/arrow/compute/function_test.cc
+++ b/cpp/src/arrow/compute/function_test.cc
@@ -40,6 +40,8 @@ TEST(FunctionOptions, Equality) {
std::vector> options;
options.emplace_back(new ScalarAggregateOptions());
options.emplace_back(new ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/1));
+ options.emplace_back(new CountOptions());
+ options.emplace_back(new CountOptions(CountOptions::ALL));
options.emplace_back(new ModeOptions());
options.emplace_back(new ModeOptions(/*n=*/2));
options.emplace_back(new VarianceOptions());
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index bc92c203687..9bfb3cdde1a 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -57,7 +57,7 @@ namespace aggregate {
// Count implementation
struct CountImpl : public ScalarAggregator {
- explicit CountImpl(ScalarAggregateOptions options) : options(std::move(options)) {}
+ explicit CountImpl(CountOptions options) : options(std::move(options)) {}
Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_array()) {
@@ -82,15 +82,23 @@ struct CountImpl : public ScalarAggregator {
Status Finalize(KernelContext* ctx, Datum* out) override {
const auto& state = checked_cast(*ctx->state());
- if (state.options.skip_nulls) {
- *out = Datum(state.non_nulls);
- } else {
- *out = Datum(state.nulls);
+ switch (state.options.mode) {
+ case CountOptions::NON_NULL:
+ *out = Datum(state.non_nulls);
+ break;
+ case CountOptions::NULLS:
+ *out = Datum(state.nulls);
+ break;
+ case CountOptions::ALL:
+ *out = Datum(state.non_nulls + state.nulls);
+ break;
+ default:
+ DCHECK(false) << "unreachable";
}
return Status::OK();
}
- ScalarAggregateOptions options;
+ CountOptions options;
int64_t non_nulls = 0;
int64_t nulls = 0;
};
@@ -98,7 +106,7 @@ struct CountImpl : public ScalarAggregator {
Result> CountInit(KernelContext*,
const KernelInitArgs& args) {
return ::arrow::internal::make_unique(
- static_cast(*args.options));
+ static_cast(*args.options));
}
// ----------------------------------------------------------------------
@@ -562,7 +570,7 @@ const FunctionDoc count_doc{"Count the number of null / non-null values",
("By default, only non-null values are counted.\n"
"This can be changed through ScalarAggregateOptions."),
{"array"},
- "ScalarAggregateOptions"};
+ "CountOptions"};
const FunctionDoc sum_doc{
"Compute the sum of a numeric array",
@@ -623,17 +631,15 @@ const FunctionDoc index_doc{"Find the index of the first occurrence of a given v
void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
+ static auto default_count_options = CountOptions::Defaults();
auto func = std::make_shared(
- "count", Arity::Unary(), &count_doc, &default_scalar_aggregate_options);
+ "count", Arity::Unary(), &count_doc, &default_count_options);
- // Takes any array input, outputs int64 scalar
- InputType any_array(ValueDescr::ARRAY);
- AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())),
+ // Takes any input, outputs int64 scalar
+ InputType any_input;
+ AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())),
aggregate::CountInit, func.get());
- AddAggKernel(
- KernelSignature::Make({InputType(ValueDescr::SCALAR)}, ValueDescr::Scalar(int64())),
- aggregate::CountInit, func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));
func = std::make_shared("sum", Arity::Unary(), &sum_doc,
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index db96bc3cad8..023651f695a 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -577,14 +577,18 @@ static CountPair NaiveCount(const Array& array) {
}
void ValidateCount(const Array& input, CountPair expected) {
- ScalarAggregateOptions all = ScalarAggregateOptions(/*skip_nulls=*/true);
- ScalarAggregateOptions nulls = ScalarAggregateOptions(/*skip_nulls=*/false);
+ CountOptions non_null;
+ CountOptions nulls(CountOptions::NULLS);
+ CountOptions all(CountOptions::ALL);
- ASSERT_OK_AND_ASSIGN(Datum result, Count(input, all));
+ ASSERT_OK_AND_ASSIGN(Datum result, Count(input, non_null));
AssertDatumsEqual(result, Datum(expected.first));
ASSERT_OK_AND_ASSIGN(result, Count(input, nulls));
AssertDatumsEqual(result, Datum(expected.second));
+
+ ASSERT_OK_AND_ASSIGN(result, Count(input, all));
+ AssertDatumsEqual(result, Datum(expected.first + expected.second));
}
template
@@ -608,11 +612,15 @@ TYPED_TEST(TestCountKernel, SimpleCount) {
auto ty = TypeTraits::type_singleton();
EXPECT_THAT(Count(MakeNullScalar(ty)), ResultWith(Datum(int64_t(0))));
- EXPECT_THAT(Count(MakeNullScalar(ty), ScalarAggregateOptions(/*skip_nulls=*/false)),
+ EXPECT_THAT(Count(MakeNullScalar(ty), CountOptions(CountOptions::NULLS)),
ResultWith(Datum(int64_t(1))));
EXPECT_THAT(Count(*MakeScalar(ty, 1)), ResultWith(Datum(int64_t(1))));
- EXPECT_THAT(Count(*MakeScalar(ty, 1), ScalarAggregateOptions(/*skip_nulls=*/false)),
+ EXPECT_THAT(Count(*MakeScalar(ty, 1), CountOptions(CountOptions::NULLS)),
ResultWith(Datum(int64_t(0))));
+
+ CountOptions all(CountOptions::ALL);
+ EXPECT_THAT(Count(MakeNullScalar(ty), all), ResultWith(Datum(int64_t(1))));
+ EXPECT_THAT(Count(*MakeScalar(ty, 1), all), ResultWith(Datum(int64_t(1))));
}
template
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index b3d602a89ac..472fe986a6e 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -829,7 +829,7 @@ Status AddHashAggKernels(
struct GroupedCountImpl : public GroupedAggregator {
Status Init(ExecContext* ctx, const FunctionOptions* options) override {
- options_ = checked_cast(*options);
+ options_ = checked_cast(*options);
counts_ = BufferBuilder(ctx->memory_pool());
return Status::OK();
}
@@ -859,25 +859,36 @@ struct GroupedCountImpl : public GroupedAggregator {
const auto& input = batch[0].array();
- if (options_.skip_nulls) {
- auto g_begin =
- reinterpret_cast(batch[1].array()->buffers[1]->data());
-
- arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
- input->length,
- [&](int64_t offset, int64_t length) {
- auto g = g_begin + offset;
- for (int64_t i = 0; i < length; ++i, ++g) {
- counts[*g] += 1;
- }
- });
- } else if (input->MayHaveNulls()) {
- auto g = batch[1].array()->GetValues(1);
-
- auto end = input->offset + input->length;
- for (int64_t i = input->offset; i < end; ++i, ++g) {
- counts[*g] += !BitUtil::GetBit(input->buffers[0]->data(), i);
+ auto g_begin = batch[1].array()->GetValues(1);
+ switch (options_.mode) {
+ case CountOptions::NON_NULL: {
+ arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
+ input->length,
+ [&](int64_t offset, int64_t length) {
+ auto g = g_begin + offset;
+ for (int64_t i = 0; i < length; ++i, ++g) {
+ counts[*g] += 1;
+ }
+ });
+ break;
+ }
+ case CountOptions::NULLS: {
+ if (input->MayHaveNulls()) {
+ auto end = input->offset + input->length;
+ for (int64_t i = input->offset; i < end; ++i, ++g_begin) {
+ counts[*g_begin] += !BitUtil::GetBit(input->buffers[0]->data(), i);
+ }
+ }
+ break;
}
+ case CountOptions::ALL: {
+ for (int64_t i = 0; i < batch.length; ++i, ++g_begin) {
+ counts[*g_begin] += 1;
+ }
+ break;
+ }
+ default:
+ DCHECK(false) << "unreachable";
}
return Status::OK();
}
@@ -890,7 +901,7 @@ struct GroupedCountImpl : public GroupedAggregator {
std::shared_ptr out_type() const override { return int64(); }
int64_t num_groups_ = 0;
- ScalarAggregateOptions options_;
+ CountOptions options_;
BufferBuilder counts_;
};
@@ -2082,7 +2093,7 @@ const FunctionDoc hash_count_doc{"Count the number of null / non-null values",
("By default, non-null values are counted.\n"
"This can be changed through ScalarAggregateOptions."),
{"array", "group_id_array"},
- "ScalarAggregateOptions"};
+ "CountOptions"};
const FunctionDoc hash_sum_doc{"Sum values of a numeric array",
("Null values are ignored."),
@@ -2140,9 +2151,9 @@ const FunctionDoc hash_all_doc{"Test whether all elements evaluate to true",
void RegisterHashAggregateBasic(FunctionRegistry* registry) {
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
{
+ static auto default_count_options = CountOptions::Defaults();
auto func = std::make_shared(
- "hash_count", Arity::Binary(), &hash_count_doc,
- &default_scalar_aggregate_options);
+ "hash_count", Arity::Binary(), &hash_count_doc, &default_count_options);
DCHECK_OK(func->AddKernel(
MakeKernel(ValueDescr::ARRAY, HashAggregateInit)));
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index cb9d05e0d35..1e04752fb3c 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1077,7 +1077,9 @@ TEST(GroupBy, CountAndSum) {
[null, 3]
])");
- ScalarAggregateOptions count_options;
+ CountOptions count_options;
+ CountOptions count_nulls(CountOptions::NULLS);
+ CountOptions count_all(CountOptions::ALL);
ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(
Datum aggregated_and_grouped,
@@ -1087,6 +1089,8 @@ TEST(GroupBy, CountAndSum) {
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
+ batch->GetColumnByName("argument"),
batch->GetColumnByName("key"),
},
{
@@ -1094,6 +1098,8 @@ TEST(GroupBy, CountAndSum) {
},
{
{"hash_count", &count_options},
+ {"hash_count", &count_nulls},
+ {"hash_count", &count_all},
{"hash_sum", nullptr},
{"hash_sum", &min_count},
{"hash_sum", nullptr},
@@ -1101,6 +1107,8 @@ TEST(GroupBy, CountAndSum) {
AssertDatumsEqual(
ArrayFromJSON(struct_({
+ field("hash_count", int64()),
+ field("hash_count", int64()),
field("hash_count", int64()),
// NB: summing a float32 array results in float64 sums
field("hash_sum", float64()),
@@ -1109,10 +1117,10 @@ TEST(GroupBy, CountAndSum) {
field("key_0", int64()),
}),
R"([
- [2, 4.25, null, 3, 1],
- [3, -0.125, -0.125, 6, 2],
- [0, null, null, 6, 3],
- [2, 4.75, null, null, null]
+ [2, 1, 3, 4.25, null, 3, 1],
+ [3, 0, 3, -0.125, -0.125, 6, 2],
+ [0, 2, 2, null, null, 6, 3],
+ [2, 0, 2, 4.75, null, null, null]
])"),
aggregated_and_grouped,
/*verbose=*/true);
@@ -1248,13 +1256,14 @@ TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
])");
ScalarAggregateOptions keepna{false, 1};
- ScalarAggregateOptions skipna{true, 1};
+ CountOptions nulls(CountOptions::NULLS);
+ CountOptions non_null(CountOptions::NON_NULL);
using internal::Aggregate;
for (auto agg : {
Aggregate{"hash_sum", nullptr},
- Aggregate{"hash_count", &skipna},
- Aggregate{"hash_count", &keepna},
+ Aggregate{"hash_count", &non_null},
+ Aggregate{"hash_count", &nulls},
Aggregate{"hash_min_max", nullptr},
Aggregate{"hash_min_max", &keepna},
}) {
@@ -1273,7 +1282,7 @@ TEST(GroupBy, CountNull) {
[3.0, "gama"]
])");
- ScalarAggregateOptions keepna{false}, skipna{true};
+ CountOptions keepna{CountOptions::NULLS}, skipna{CountOptions::NON_NULL};
using internal::Aggregate;
for (auto agg : {
@@ -1323,7 +1332,6 @@ TEST(GroupBy, WithChunkedArray) {
{"argument": 0.75, "key": null},
{"argument": null, "key": 3}
])"});
- ScalarAggregateOptions count_options;
ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped,
internal::GroupBy(
{
@@ -1335,7 +1343,7 @@ TEST(GroupBy, WithChunkedArray) {
table->GetColumnByName("key"),
},
{
- {"hash_count", &count_options},
+ {"hash_count", nullptr},
{"hash_sum", nullptr},
{"hash_min_max", nullptr},
}));
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 58321457ab0..5d4817bf079 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -190,23 +190,23 @@ Aggregations
+---------------+-------+-------------+----------------+----------------------------------+-------+
| any | Unary | Boolean | Scalar Boolean | :struct:`ScalarAggregateOptions` | \(1) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
-| count | Unary | Any | Scalar Int64 | :struct:`ScalarAggregateOptions` | |
+| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | |
+---------------+-------+-------------+----------------+----------------------------------+-------+
| mean | Unary | Numeric | Scalar Float64 | :struct:`ScalarAggregateOptions` | |
+---------------+-------+-------------+----------------+----------------------------------+-------+
-| min_max | Unary | Numeric | Scalar Struct | :struct:`ScalarAggregateOptions` | \(2) |
+| min_max | Unary | Numeric | Scalar Struct | :struct:`ScalarAggregateOptions` | \(3) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
-| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(3) |
+| mode | Unary | Numeric | Struct | :struct:`ModeOptions` | \(4) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
-| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(4) |
+| product | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
-| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(5) |
+| quantile | Unary | Numeric | Scalar Numeric | :struct:`QuantileOptions` | \(6) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
| stddev | Unary | Numeric | Scalar Float64 | :struct:`VarianceOptions` | |
+---------------+-------+-------------+----------------+----------------------------------+-------+
-| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(4) |
+| sum | Unary | Numeric | Scalar Numeric | :struct:`ScalarAggregateOptions` | \(5) |
+---------------+-------+-------------+----------------+----------------------------------+-------+
| tdigest | Unary | Numeric | Scalar Float64 | :struct:`TDigestOptions` | |
+---------------+-------+-------------+----------------+----------------------------------+-------+
@@ -218,18 +218,21 @@ Notes:
* \(1) If null values are taken into account by setting ScalarAggregateOptions
parameter skip_nulls = false then `Kleene logic`_ logic is applied.
-* \(2) Output is a ``{"min": input type, "max": input type}`` Struct.
+* \(2) CountMode controls whether only non-null values are counted (the
+ default), only null values are counted, or all values are counted.
-* \(3) Output is an array of ``{"mode": input type, "count": Int64}`` Struct.
+* \(3) Output is a ``{"min": input type, "max": input type}`` Struct.
+
+* \(4) Output is an array of ``{"mode": input type, "count": Int64}`` Struct.
It contains the *N* most common elements in the input, in descending
order, where *N* is given in :member:`ModeOptions::n`.
If two values have the same count, the smallest one comes first.
Note that the output can have less than *N* elements if the input has
less than *N* distinct values.
-* \(4) Output is Int64, UInt64 or Float64, depending on the input type.
+* \(5) Output is Int64, UInt64 or Float64, depending on the input type.
-* \(5) Output is Float64 or input type, depending on QuantileOptions.
+* \(6) Output is Float64 or input type, depending on QuantileOptions.
Element-wise ("scalar") functions
---------------------------------
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 46cfdc4e2ef..530bccbe720 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -887,6 +887,23 @@ class ScalarAggregateOptions(_ScalarAggregateOptions):
self._set_options(skip_nulls, min_count)
+cdef class _CountOptions(FunctionOptions):
+ def _set_options(self, mode):
+ if mode == 'non_null':
+ self.wrapped.reset(new CCountOptions(CCountMode_NON_NULL))
+ elif mode == 'nulls':
+ self.wrapped.reset(new CCountOptions(CCountMode_NULLS))
+ elif mode == 'all':
+ self.wrapped.reset(new CCountOptions(CCountMode_ALL))
+ else:
+ raise ValueError(f'"{mode}" is not a valid mode')
+
+
+class CountOptions(_CountOptions):
+ def __init__(self, mode='non_null'):
+ self._set_options(mode)
+
+
cdef class _IndexOptions(FunctionOptions):
def _set_options(self, Scalar scalar):
self.wrapped.reset(new CIndexOptions(pyarrow_unwrap_scalar(scalar)))
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 85f637fce5a..10880c2974c 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -31,6 +31,7 @@
# Option classes
ArraySortOptions,
CastOptions,
+ CountOptions,
DictionaryEncodeOptions,
ElementWiseAggregateOptions,
ExtractRegexOptions,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 34d81bce04b..6e9f0f5a7b1 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1959,6 +1959,16 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
c_bool skip_nulls
uint32_t min_count
+ enum CCountMode "arrow::compute::CountOptions::CountMode":
+ CCountMode_NON_NULL "arrow::compute::CountOptions::NON_NULL"
+ CCountMode_NULLS "arrow::compute::CountOptions::NULLS"
+ CCountMode_ALL "arrow::compute::CountOptions::ALL"
+
+ cdef cppclass CCountOptions \
+ "arrow::compute::CountOptions"(CFunctionOptions):
+ CCountOptions(CCountMode mode)
+ CCountMode mode
+
cdef cppclass CModeOptions \
"arrow::compute::ModeOptions"(CFunctionOptions):
CModeOptions(int64_t n)
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index b0baa76e50a..5ceb7337341 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -1439,11 +1439,9 @@ def test_extract_datetime_components():
def test_count():
arr = pa.array([1, 2, 3, None, None])
assert pc.count(arr).as_py() == 3
- assert pc.count(arr, skip_nulls=True).as_py() == 3
- assert pc.count(arr, skip_nulls=False).as_py() == 2
-
- with pytest.raises(TypeError, match="an integer is required"):
- pc.count(arr, min_count='zzz')
+ assert pc.count(arr, mode='non_null').as_py() == 3
+ assert pc.count(arr, mode='nulls').as_py() == 2
+ assert pc.count(arr, mode='all').as_py() == 5
def test_index():
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 30821137383..2a0f7907ca6 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -172,7 +172,7 @@ std::shared_ptr make_compute_options(
}
if (func_name == "min_max" || func_name == "sum" || func_name == "mean" ||
- func_name == "count" || func_name == "any" || func_name == "all") {
+ func_name == "any" || func_name == "all") {
using Options = arrow::compute::ScalarAggregateOptions;
auto out = std::make_shared(Options::Defaults());
out->min_count = cpp11::as_cpp(options["na.min_count"]);
@@ -180,6 +180,14 @@ std::shared_ptr make_compute_options(
return out;
}
+ if (func_name == "count") {
+ using Options = arrow::compute::CountOptions;
+ auto out = std::make_shared(Options::Defaults());
+ out->mode =
+ cpp11::as_cpp(options["na.rm"]) ? Options::NON_NULL : Options::NULLS;
+ return out;
+ }
+
if (func_name == "min_element_wise" || func_name == "max_element_wise") {
using Options = arrow::compute::ElementWiseAggregateOptions;
bool skip_nulls = true;
From 84a48538e7fd747849f2f4e7e85112dbc350c800 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 10 Aug 2021 14:51:19 -0400
Subject: [PATCH 2/2] ARROW-13574: [C++] Rename CountMode enum values
---
c_glib/arrow-glib/compute.h | 8 ++++----
c_glib/test/test-count.rb | 4 ++--
cpp/src/arrow/compute/api_aggregate.cc | 8 ++++----
cpp/src/arrow/compute/api_aggregate.h | 6 +++---
cpp/src/arrow/compute/kernels/aggregate_basic.cc | 16 +++++++++-------
cpp/src/arrow/compute/kernels/aggregate_test.cc | 6 +++---
cpp/src/arrow/compute/kernels/hash_aggregate.cc | 4 ++--
.../arrow/compute/kernels/hash_aggregate_test.cc | 8 ++++----
python/pyarrow/_compute.pyx | 10 +++++-----
python/pyarrow/includes/libarrow.pxd | 4 ++--
python/pyarrow/tests/test_compute.py | 4 ++--
r/src/compute.cpp | 2 +-
12 files changed, 41 insertions(+), 39 deletions(-)
diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h
index c74d28d17e1..fe2e60181ce 100644
--- a/c_glib/arrow-glib/compute.h
+++ b/c_glib/arrow-glib/compute.h
@@ -100,9 +100,9 @@ garrow_scalar_aggregate_options_new(void);
/**
* GArrowCountMode:
- * @GARROW_COUNT_MODE_NON_NULL:
+ * @GARROW_COUNT_MODE_ONLY_VALID:
* Only non-null values will be counted.
- * @GARROW_COUNT_MODE_NULLS:
+ * @GARROW_COUNT_MODE_ONLY_NULL:
* Only null values will be counted.
* @GARROW_COUNT_MODE_ALL:
* All will be counted.
@@ -110,8 +110,8 @@ garrow_scalar_aggregate_options_new(void);
* They correspond to the values of `arrow::compute::CountOptions::CountMode`.
*/
typedef enum {
- GARROW_COUNT_MODE_NON_NULL,
- GARROW_COUNT_MODE_NULLS,
+ GARROW_COUNT_MODE_ONLY_VALID,
+ GARROW_COUNT_MODE_ONLY_NULL,
GARROW_COUNT_MODE_ALL,
} GArrowCountMode;
diff --git a/c_glib/test/test-count.rb b/c_glib/test/test-count.rb
index 55804a1165b..6e94219143b 100644
--- a/c_glib/test/test-count.rb
+++ b/c_glib/test/test-count.rb
@@ -24,13 +24,13 @@ def test_default
assert_equal(2, build_int32_array([1, nil, 3]).count)
options = Arrow::CountOptions.new
- options.mode = Arrow::CountMode::NON_NULL
+ options.mode = Arrow::CountMode::ONLY_VALID
assert_equal(2, build_int32_array([1, nil, 3]).count(options))
end
def test_nulls
options = Arrow::CountOptions.new
- options.mode = Arrow::CountMode::NULLS
+ options.mode = Arrow::CountMode::ONLY_NULL
assert_equal(1, build_int32_array([1, nil, 3]).count(options))
end
diff --git a/cpp/src/arrow/compute/api_aggregate.cc b/cpp/src/arrow/compute/api_aggregate.cc
index a165f966e5a..af7aec865fc 100644
--- a/cpp/src/arrow/compute/api_aggregate.cc
+++ b/cpp/src/arrow/compute/api_aggregate.cc
@@ -29,14 +29,14 @@ namespace arrow {
namespace internal {
template <>
struct EnumTraits
- : BasicEnumTraits {
+ : BasicEnumTraits {
static std::string name() { return "CountOptions::CountMode"; }
static std::string value_name(compute::CountOptions::CountMode value) {
switch (value) {
- case compute::CountOptions::NON_NULL:
+ case compute::CountOptions::ONLY_VALID:
return "NON_NULL";
- case compute::CountOptions::NULLS:
+ case compute::CountOptions::ONLY_NULL:
return "NULLS";
case compute::CountOptions::ALL:
return "ALL";
diff --git a/cpp/src/arrow/compute/api_aggregate.h b/cpp/src/arrow/compute/api_aggregate.h
index 667b057a1f1..880424e97f8 100644
--- a/cpp/src/arrow/compute/api_aggregate.h
+++ b/cpp/src/arrow/compute/api_aggregate.h
@@ -60,13 +60,13 @@ class ARROW_EXPORT CountOptions : public FunctionOptions {
public:
enum CountMode {
/// Count only non-null values.
- NON_NULL = 0,
+ ONLY_VALID = 0,
/// Count only null values.
- NULLS,
+ ONLY_NULL,
/// Count both non-null and null values.
ALL,
};
- explicit CountOptions(CountMode mode = CountMode::NON_NULL);
+ explicit CountOptions(CountMode mode = CountMode::ONLY_VALID);
constexpr static char const kTypeName[] = "CountOptions";
static CountOptions Defaults() { return CountOptions{}; }
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 9bfb3cdde1a..d65e0ef506e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -60,7 +60,9 @@ struct CountImpl : public ScalarAggregator {
explicit CountImpl(CountOptions options) : options(std::move(options)) {}
Status Consume(KernelContext*, const ExecBatch& batch) override {
- if (batch[0].is_array()) {
+ if (options.mode == CountOptions::ALL) {
+ this->non_nulls += batch.length;
+ } else if (batch[0].is_array()) {
const ArrayData& input = *batch[0].array();
const int64_t nulls = input.GetNullCount();
this->nulls += nulls;
@@ -83,15 +85,15 @@ struct CountImpl : public ScalarAggregator {
Status Finalize(KernelContext* ctx, Datum* out) override {
const auto& state = checked_cast(*ctx->state());
switch (state.options.mode) {
- case CountOptions::NON_NULL:
+ case CountOptions::ONLY_VALID:
+ case CountOptions::ALL:
+ // ALL is equivalent since we don't count the null/non-null
+ // separately to avoid potentially computing null count
*out = Datum(state.non_nulls);
break;
- case CountOptions::NULLS:
+ case CountOptions::ONLY_NULL:
*out = Datum(state.nulls);
break;
- case CountOptions::ALL:
- *out = Datum(state.non_nulls + state.nulls);
- break;
default:
DCHECK(false) << "unreachable";
}
@@ -568,7 +570,7 @@ namespace {
const FunctionDoc count_doc{"Count the number of null / non-null values",
("By default, only non-null values are counted.\n"
- "This can be changed through ScalarAggregateOptions."),
+ "This can be changed through CountOptions."),
{"array"},
"CountOptions"};
diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc
index 023651f695a..8aee3049970 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc
@@ -578,7 +578,7 @@ static CountPair NaiveCount(const Array& array) {
void ValidateCount(const Array& input, CountPair expected) {
CountOptions non_null;
- CountOptions nulls(CountOptions::NULLS);
+ CountOptions nulls(CountOptions::ONLY_NULL);
CountOptions all(CountOptions::ALL);
ASSERT_OK_AND_ASSIGN(Datum result, Count(input, non_null));
@@ -612,10 +612,10 @@ TYPED_TEST(TestCountKernel, SimpleCount) {
auto ty = TypeTraits::type_singleton();
EXPECT_THAT(Count(MakeNullScalar(ty)), ResultWith(Datum(int64_t(0))));
- EXPECT_THAT(Count(MakeNullScalar(ty), CountOptions(CountOptions::NULLS)),
+ EXPECT_THAT(Count(MakeNullScalar(ty), CountOptions(CountOptions::ONLY_NULL)),
ResultWith(Datum(int64_t(1))));
EXPECT_THAT(Count(*MakeScalar(ty, 1)), ResultWith(Datum(int64_t(1))));
- EXPECT_THAT(Count(*MakeScalar(ty, 1), CountOptions(CountOptions::NULLS)),
+ EXPECT_THAT(Count(*MakeScalar(ty, 1), CountOptions(CountOptions::ONLY_NULL)),
ResultWith(Datum(int64_t(0))));
CountOptions all(CountOptions::ALL);
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
index 472fe986a6e..5774f93750c 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc
@@ -861,7 +861,7 @@ struct GroupedCountImpl : public GroupedAggregator {
auto g_begin = batch[1].array()->GetValues(1);
switch (options_.mode) {
- case CountOptions::NON_NULL: {
+ case CountOptions::ONLY_VALID: {
arrow::internal::VisitSetBitRunsVoid(input->buffers[0], input->offset,
input->length,
[&](int64_t offset, int64_t length) {
@@ -872,7 +872,7 @@ struct GroupedCountImpl : public GroupedAggregator {
});
break;
}
- case CountOptions::NULLS: {
+ case CountOptions::ONLY_NULL: {
if (input->MayHaveNulls()) {
auto end = input->offset + input->length;
for (int64_t i = input->offset; i < end; ++i, ++g_begin) {
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index 1e04752fb3c..ca9a8241305 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -1078,7 +1078,7 @@ TEST(GroupBy, CountAndSum) {
])");
CountOptions count_options;
- CountOptions count_nulls(CountOptions::NULLS);
+ CountOptions count_nulls(CountOptions::ONLY_NULL);
CountOptions count_all(CountOptions::ALL);
ScalarAggregateOptions min_count(/*skip_nulls=*/true, /*min_count=*/3);
ASSERT_OK_AND_ASSIGN(
@@ -1256,8 +1256,8 @@ TEST(GroupBy, ConcreteCaseWithValidateGroupBy) {
])");
ScalarAggregateOptions keepna{false, 1};
- CountOptions nulls(CountOptions::NULLS);
- CountOptions non_null(CountOptions::NON_NULL);
+ CountOptions nulls(CountOptions::ONLY_NULL);
+ CountOptions non_null(CountOptions::ONLY_VALID);
using internal::Aggregate;
for (auto agg : {
@@ -1282,7 +1282,7 @@ TEST(GroupBy, CountNull) {
[3.0, "gama"]
])");
- CountOptions keepna{CountOptions::NULLS}, skipna{CountOptions::NON_NULL};
+ CountOptions keepna{CountOptions::ONLY_NULL}, skipna{CountOptions::ONLY_VALID};
using internal::Aggregate;
for (auto agg : {
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 530bccbe720..2d82928f377 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -889,10 +889,10 @@ class ScalarAggregateOptions(_ScalarAggregateOptions):
cdef class _CountOptions(FunctionOptions):
def _set_options(self, mode):
- if mode == 'non_null':
- self.wrapped.reset(new CCountOptions(CCountMode_NON_NULL))
- elif mode == 'nulls':
- self.wrapped.reset(new CCountOptions(CCountMode_NULLS))
+ if mode == 'only_valid':
+ self.wrapped.reset(new CCountOptions(CCountMode_ONLY_VALID))
+ elif mode == 'only_null':
+ self.wrapped.reset(new CCountOptions(CCountMode_ONLY_NULL))
elif mode == 'all':
self.wrapped.reset(new CCountOptions(CCountMode_ALL))
else:
@@ -900,7 +900,7 @@ cdef class _CountOptions(FunctionOptions):
class CountOptions(_CountOptions):
- def __init__(self, mode='non_null'):
+ def __init__(self, mode='only_valid'):
self._set_options(mode)
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 6e9f0f5a7b1..7dcde652a95 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -1960,8 +1960,8 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
uint32_t min_count
enum CCountMode "arrow::compute::CountOptions::CountMode":
- CCountMode_NON_NULL "arrow::compute::CountOptions::NON_NULL"
- CCountMode_NULLS "arrow::compute::CountOptions::NULLS"
+ CCountMode_ONLY_VALID "arrow::compute::CountOptions::ONLY_VALID"
+ CCountMode_ONLY_NULL "arrow::compute::CountOptions::ONLY_NULL"
CCountMode_ALL "arrow::compute::CountOptions::ALL"
cdef cppclass CCountOptions \
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 5ceb7337341..60a2f60f942 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -1439,8 +1439,8 @@ def test_extract_datetime_components():
def test_count():
arr = pa.array([1, 2, 3, None, None])
assert pc.count(arr).as_py() == 3
- assert pc.count(arr, mode='non_null').as_py() == 3
- assert pc.count(arr, mode='nulls').as_py() == 2
+ assert pc.count(arr, mode='only_valid').as_py() == 3
+ assert pc.count(arr, mode='only_null').as_py() == 2
assert pc.count(arr, mode='all').as_py() == 5
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 2a0f7907ca6..0695e2525f7 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -184,7 +184,7 @@ std::shared_ptr make_compute_options(
using Options = arrow::compute::CountOptions;
auto out = std::make_shared(Options::Defaults());
out->mode =
- cpp11::as_cpp(options["na.rm"]) ? Options::NON_NULL : Options::NULLS;
+ cpp11::as_cpp(options["na.rm"]) ? Options::ONLY_VALID : Options::ONLY_NULL;
return out;
}