Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ if(ARROW_COMPUTE)
compute/kernels/aggregate.cc
compute/kernels/boolean.cc
compute/kernels/cast.cc
compute/kernels/count.cc
compute/kernels/hash.cc
compute/kernels/mean.cc
compute/kernels/sum.cc
Expand Down
13 changes: 13 additions & 0 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ struct ARROW_EXPORT Datum {
Datum(const std::shared_ptr<T>& value) // NOLINT implicit conversion
: Datum(std::shared_ptr<Array>(value)) {}

// Convenience constructors
//
// Each overload wraps a C++ primitive in the Arrow scalar type of the same
// width and signedness (bool -> BooleanScalar, int64_t -> Int64Scalar, ...)
// and stores the resulting shared_ptr in `value`.
//
// All of them are `explicit`, so a primitive is never converted to a Datum
// implicitly. Inside each member initializer the argument named `value` is
// the constructor parameter, which shadows the member of the same name.
explicit Datum(bool value) : value(std::make_shared<BooleanScalar>(value)) {}
explicit Datum(int8_t value) : value(std::make_shared<Int8Scalar>(value)) {}
explicit Datum(uint8_t value) : value(std::make_shared<UInt8Scalar>(value)) {}
explicit Datum(int16_t value) : value(std::make_shared<Int16Scalar>(value)) {}
explicit Datum(uint16_t value) : value(std::make_shared<UInt16Scalar>(value)) {}
explicit Datum(int32_t value) : value(std::make_shared<Int32Scalar>(value)) {}
explicit Datum(uint32_t value) : value(std::make_shared<UInt32Scalar>(value)) {}
explicit Datum(int64_t value) : value(std::make_shared<Int64Scalar>(value)) {}
explicit Datum(uint64_t value) : value(std::make_shared<UInt64Scalar>(value)) {}
explicit Datum(float value) : value(std::make_shared<FloatScalar>(value)) {}
explicit Datum(double value) : value(std::make_shared<DoubleScalar>(value)) {}

// User-declared destructor with an empty body.
~Datum() {}

// Copy constructor: assigns `other.value` to this object's `value` member.
Datum(const Datum& other) noexcept { this->value = other.value; }
Expand Down
75 changes: 75 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "arrow/array.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/count.h"
#include "arrow/compute/kernels/mean.h"
#include "arrow/compute/kernels/sum-internal.h"
#include "arrow/compute/kernels/sum.h"
Expand All @@ -43,6 +44,10 @@ using std::vector;
namespace arrow {
namespace compute {

///
/// Sum
///

// Pair used by the Sum tests: `first` has the C type of the accumulator that
// FindAccumulatorType selects for ArrowType, `second` is a size_t
// (presumably the number of summed values -- TODO confirm against its users).
template <typename ArrowType>
using SumResult =
std::pair<typename FindAccumulatorType<ArrowType>::Type::c_type, size_t>;
Expand Down Expand Up @@ -162,6 +167,10 @@ TYPED_TEST(TestRandomNumericSumKernel, RandomSliceArraySum) {
}
}

///
/// Mean
///

template <typename ArrowType>
static Datum NaiveMean(const Array& array) {
using MeanScalarType = typename TypeTraits<DoubleType>::ScalarType;
Expand Down Expand Up @@ -237,5 +246,71 @@ TYPED_TEST(TestRandomNumericMeanKernel, RandomArrayMean) {
}
}

///
/// Count
///
//
// Expected counts for an array: `first` is the number of non-null values,
// `second` is the number of null values.
using CountPair = std::pair<int64_t, int64_t>;

// Reference implementation used to cross-check the Count kernel: derives
// (non-null count, null count) directly from the array's length and
// null_count().
static CountPair NaiveCount(const Array& array) {
  const int64_t total = array.length();
  const int64_t missing = array.null_count();
  return CountPair(total - missing, missing);
}

// Runs the Count kernel over `input` in both modes and checks the results:
// COUNT_ALL must yield `expected.first` (non-null values) and COUNT_NULL must
// yield `expected.second` (null values), each as an Int64 scalar datum.
void ValidateCount(FunctionContext* ctx, const Array& input, CountPair expected) {
  Datum actual;

  // Non-null values.
  const CountOptions count_valid(CountOptions::COUNT_ALL);
  ASSERT_OK(Count(ctx, count_valid, input, &actual));
  DatumEqual<Int64Type>::EnsureEqual(actual, Datum(expected.first));

  // Null values.
  const CountOptions count_missing(CountOptions::COUNT_NULL);
  ASSERT_OK(Count(ctx, count_missing, input, &actual));
  DatumEqual<Int64Type>::EnsureEqual(actual, Datum(expected.second));
}

// Overload taking the input as a JSON literal: builds an array of ArrowType
// from `json` and validates both counts against `expected`.
template <typename ArrowType>
void ValidateCount(FunctionContext* ctx, const char* json, CountPair expected) {
  const auto type = TypeTraits<ArrowType>::type_singleton();
  const auto parsed = ArrayFromJSON(type, json);
  ValidateCount(ctx, *parsed, expected);
}

// Overload without explicit expectations: the expected counts are computed
// from `input` itself with NaiveCount.
void ValidateCount(FunctionContext* ctx, const Array& input) {
  const CountPair expected = NaiveCount(input);
  ValidateCount(ctx, input, expected);
}

// Typed fixture for the hand-written Count kernel cases below; it adds no
// members of its own and only combines ComputeFixture and TestBase.
template <typename ArrowType>
class TestCountKernel : public ComputeFixture, public TestBase {};

TYPED_TEST_CASE(TestCountKernel, NumericArrowTypes);

// Checks the Count kernel against small hand-written arrays, covering the
// empty array, all-null arrays, a mix of nulls and values, and no nulls.
TYPED_TEST(TestCountKernel, SimpleCount) {
  struct Case {
    const char* json;
    CountPair expected;  // {non-null count, null count}
  };

  const Case cases[] = {
      {"[]", {0, 0}},
      {"[null]", {0, 1}},
      {"[1, null, 2]", {2, 1}},
      {"[null, null, null]", {0, 3}},
      {"[1, 2, 3, 4, 5, 6, 7, 8, 9]", {9, 0}},
  };

  for (const auto& test_case : cases) {
    ValidateCount<TypeParam>(&this->ctx_, test_case.json, test_case.expected);
  }
}

// Typed fixture for the randomized Count kernel test below; it adds no
// members of its own and only combines ComputeFixture and TestBase.
template <typename ArrowType>
class TestRandomNumericCountKernel : public ComputeFixture, public TestBase {};

TYPED_TEST_CASE(TestRandomNumericCountKernel, NumericArrowTypes);

// Checks the Count kernel on randomly generated numeric arrays.
//
// Lengths are taken around every power of two from 2^3 to 2^13 (the power
// itself and +/-1 and +/-2), combined with null probabilities ranging from
// "no nulls" to "all nulls". The expected counts come from NaiveCount.
TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) {
  auto rand = random::RandomArrayGenerator(0x1205643);
  for (int shift = 3; shift < 14; ++shift) {
    // Do the arithmetic in int64_t: adding a negative adjustment to an
    // unsigned `1UL << shift` would only work through unsigned wraparound.
    const int64_t base_length = int64_t{1} << shift;
    for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) {
      for (auto length_adjust : {-2, -1, 0, 1, 2}) {
        const int64_t length = base_length + length_adjust;
        auto array = rand.Numeric<TypeParam>(length, 0, 100, null_probability);
        ValidateCount(&this->ctx_, *array);
      }
    }
  }
}

} // namespace compute
} // namespace arrow
115 changes: 115 additions & 0 deletions cpp/src/arrow/compute/kernels/count.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/kernels/count.h"

#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/aggregate.h"

namespace arrow {
namespace compute {

struct CountState {
CountState() : non_nulls(0), nulls(0) {}
CountState(int64_t non_nulls, int64_t nulls) : non_nulls(non_nulls), nulls(nulls) {}

CountState operator+(const CountState& rhs) const {
return CountState(this->non_nulls + rhs.non_nulls, this->nulls + rhs.nulls);
}

CountState& operator+=(const CountState& rhs) {
this->non_nulls += rhs.non_nulls;
this->nulls += rhs.nulls;
return *this;
}

std::shared_ptr<Scalar> NonNullsAsScalar() const {
using ScalarType = typename CTypeTraits<int64_t>::ScalarType;
return std::make_shared<ScalarType>(non_nulls);
}

std::shared_ptr<Scalar> NullsAsScalar() const {
using ScalarType = typename CTypeTraits<int64_t>::ScalarType;
return std::make_shared<ScalarType>(nulls);
}

int64_t non_nulls = 0;
int64_t nulls = 0;
};

/// Aggregate function behind the Count kernel.
///
/// The per-array state is a CountState; the CountOptions given at
/// construction decide which of its two counters Finalize() emits.
class CountAggregateFunction final : public AggregateFunctionStaticState<CountState> {
 public:
  explicit CountAggregateFunction(const CountOptions& options) : options_(options) {}

  /// The result is always an int64 value.
  std::shared_ptr<DataType> out_type() const override { return int64(); }

  /// Fills `state` with the counts of `input`. Any previous content of
  /// `state` is overwritten, not accumulated.
  Status Consume(const Array& input, CountState* state) const override {
    const int64_t missing = input.null_count();
    state->nulls = missing;
    state->non_nulls = input.length() - missing;
    return Status::OK();
  }

  /// Adds the counters of `src` into `dst`.
  Status Merge(const CountState& src, CountState* dst) const override {
    *dst = *dst + src;
    return Status::OK();
  }

  /// Writes the counter selected by the options into `output` as a scalar.
  /// Returns Status::Invalid for a count mode that is not handled here.
  Status Finalize(const CountState& src, Datum* output) const override {
    const auto requested = options_.count_mode;
    if (requested == CountOptions::COUNT_ALL) {
      *output = src.NonNullsAsScalar();
      return Status::OK();
    }
    if (requested == CountOptions::COUNT_NULL) {
      *output = src.NullsAsScalar();
      return Status::OK();
    }
    return Status::Invalid("Unknown CountOptions encountered");
  }

 private:
  CountOptions options_;
};

/// \brief Build the aggregate function used by the Count kernel.
///
/// This is the definition of the `MakeCount` factory that count.h declares
/// and exports; without it that declaration has no matching symbol.
///
/// \param[in] context the FunctionContext (currently unused)
/// \param[in] options counting options, copied into the returned function
/// \return a newly allocated CountAggregateFunction
std::shared_ptr<AggregateFunction> MakeCount(FunctionContext* context,
                                             const CountOptions& options) {
  return std::make_shared<CountAggregateFunction>(options);
}

/// \brief Same as MakeCount(); kept so code using the longer name still builds.
std::shared_ptr<AggregateFunction> MakeCountAggregateFunction(
    FunctionContext* context, const CountOptions& options) {
  return MakeCount(context, options);
}

/// \brief Count non-null (or null) values of an array datum.
///
/// \param[in] context the FunctionContext, forwarded to the kernel
/// \param[in] options selects which count is produced, see CountOptions
/// \param[in] value datum to count; must hold an array
/// \param[out] out receives the result produced by the aggregate kernel
/// \return Status::Invalid if `value` is not an array, otherwise the status
///         returned by the kernel call
Status Count(FunctionContext* context, const CountOptions& options, const Datum& value,
             Datum* out) {
  // TODO: support evaluation on chunked arrays. This should be addressed
  // generally for aggregate kernels rather than with per-kernel boilerplate.
  if (!value.is_array()) {
    return Status::Invalid("Count is expecting an array datum.");
  }

  auto aggregate = MakeCountAggregateFunction(context, options);
  auto kernel = std::make_shared<AggregateUnaryKernel>(aggregate);

  return kernel->Call(context, value, out);
}

/// \brief Array overload of Count(): wraps the array's data in a Datum and
/// defers to the Datum overload.
Status Count(FunctionContext* context, const CountOptions& options, const Array& array,
             Datum* out) {
  const Datum input(array.data());
  return Count(context, options, input, out);
}

} // namespace compute
} // namespace arrow
88 changes: 88 additions & 0 deletions cpp/src/arrow/compute/kernels/count.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>
#include <type_traits>

#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/visibility.h"

namespace arrow {

class Array;
class DataType;

namespace compute {

struct Datum;
class FunctionContext;
class AggregateFunction;

/// \class CountOptions
///
/// The user controls the Count kernel behavior with this class. By default,
/// it will count all non-null values.
struct ARROW_EXPORT CountOptions {
  /// Selects which values the Count kernel counts.
  enum mode {
    // Count all non-null values.
    COUNT_ALL = 0,
    // Count all null values.
    COUNT_NULL,
  };

  /// \param[in] count_mode which values to count. Defaults to COUNT_ALL, the
  /// same value as the member's initializer, so a CountOptions can be
  /// constructed without arguments.
  explicit CountOptions(enum mode count_mode = COUNT_ALL) : count_mode(count_mode) {}

  /// The selected counting mode.
  enum mode count_mode = COUNT_ALL;
};

/// \brief Return Count function aggregate
///
/// \param[in] context the FunctionContext
/// \param[in] options counting options, see CountOptions for more information
/// \return the aggregate function, as a shared AggregateFunction
ARROW_EXPORT
std::shared_ptr<AggregateFunction> MakeCount(FunctionContext* context,
const CountOptions& options);

/// \brief Count non-null (or null) values in an array.
///
/// Which of the two counts is produced is selected by `options.count_mode`.
///
/// \param[in] context the FunctionContext
/// \param[in] options counting options, see CountOptions for more information
/// \param[in] datum to count; it must hold an array
/// \param[out] out resulting datum
/// \return an invalid Status if `datum` does not hold an array
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Count(FunctionContext* context, const CountOptions& options, const Datum& datum,
Datum* out);

/// \brief Count non-null (or null) values in an array.
///
/// Which of the two counts is produced is selected by `options.count_mode`.
/// This overload passes the array's data on to the Datum overload of Count().
///
/// \param[in] context the FunctionContext
/// \param[in] options counting options, see CountOptions for more information
/// \param[in] array to count
/// \param[out] out resulting datum
/// \return the Status of the Datum overload
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Count(FunctionContext* context, const CountOptions& options, const Array& array,
Datum* out);

} // namespace compute
} // namespace arrow