Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,32 @@ class LargeBinaryLikeMatcher : public TypeMatcher {
std::string ToString() const override { return "large-binary-like"; }
};

// Type matcher that accepts every type whose id satisfies
// is_fixed_size_binary().
class FixedSizeBinaryLikeMatcher : public TypeMatcher {
 public:
  FixedSizeBinaryLikeMatcher() {}

  // Returns true when the id of `type` is classified as fixed-size binary.
  bool Matches(const DataType& type) const override {
    const auto type_id = type.id();
    return is_fixed_size_binary(type_id);
  }

  // The matcher holds no state, so it compares equal to itself and to any
  // other matcher that can be cast to this class.
  bool Equals(const TypeMatcher& other) const override {
    const bool same_object = (this == &other);
    return same_object ||
           dynamic_cast<const FixedSizeBinaryLikeMatcher*>(&other) != nullptr;
  }

  // Human-readable name of this matcher.
  std::string ToString() const override { return "fixed-size-binary-like"; }
};

// Factory returning a freshly allocated LargeBinaryLikeMatcher behind the
// TypeMatcher interface.
std::shared_ptr<TypeMatcher> LargeBinaryLike() {
  std::shared_ptr<TypeMatcher> matcher = std::make_shared<LargeBinaryLikeMatcher>();
  return matcher;
}

// Factory returning a freshly allocated FixedSizeBinaryLikeMatcher behind the
// TypeMatcher interface.
std::shared_ptr<TypeMatcher> FixedSizeBinaryLike() {
  std::shared_ptr<TypeMatcher> matcher =
      std::make_shared<FixedSizeBinaryLikeMatcher>();
  return matcher;
}

} // namespace match

// ----------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ ARROW_EXPORT std::shared_ptr<TypeMatcher> BinaryLike();
// Match types using 64-bit varbinary representation
ARROW_EXPORT std::shared_ptr<TypeMatcher> LargeBinaryLike();

// Match any fixed-size binary type
ARROW_EXPORT std::shared_ptr<TypeMatcher> FixedSizeBinaryLike();

// \brief Match any primitive type (boolean or any type representable as a C
// Type)
ARROW_EXPORT std::shared_ptr<TypeMatcher> Primitive();
Expand Down
260 changes: 188 additions & 72 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

namespace arrow {
namespace compute {
namespace aggregate {
namespace internal {

// ----------------------------------------------------------------------
// Sum implementation
Expand Down Expand Up @@ -83,6 +83,6 @@ void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func) {
AddMinMaxKernel(MinMaxInitAvx2, Type::INTERVAL_MONTHS, func, SimdLevel::AVX2);
}

} // namespace aggregate
} // namespace internal
} // namespace compute
} // namespace arrow
7 changes: 3 additions & 4 deletions cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

namespace arrow {
namespace compute {
namespace aggregate {
namespace internal {

// ----------------------------------------------------------------------
// Sum implementation
Expand Down Expand Up @@ -72,8 +72,7 @@ void AddSumAvx512AggKernels(ScalarAggregateFunction* func) {
}

void AddMeanAvx512AggKernels(ScalarAggregateFunction* func) {
aggregate::AddBasicAggKernels(MeanInitAvx512, NumericTypes(), float64(), func,
SimdLevel::AVX512);
AddBasicAggKernels(MeanInitAvx512, NumericTypes(), float64(), func, SimdLevel::AVX512);
}

void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) {
Expand All @@ -86,6 +85,6 @@ void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) {
AddMinMaxKernel(MinMaxInitAvx512, Type::INTERVAL_MONTHS, func, SimdLevel::AVX512);
}

} // namespace aggregate
} // namespace internal
} // namespace compute
} // namespace arrow
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

namespace arrow {
namespace compute {
namespace aggregate {
namespace internal {

void AddBasicAggKernels(KernelInit init,
const std::vector<std::shared_ptr<DataType>>& types,
Expand Down Expand Up @@ -83,7 +83,7 @@ struct SumImpl : public ScalarAggregator {
if (is_boolean_type<ArrowType>::value) {
this->sum += static_cast<SumCType>(BooleanArray(data).true_count());
} else {
this->sum += arrow::compute::detail::SumArray<CType, SumCType, SimdLevel>(*data);
this->sum += SumArray<CType, SumCType, SimdLevel>(*data);
}
} else {
const auto& data = *batch[0].scalar();
Expand Down Expand Up @@ -621,6 +621,6 @@ struct MinMaxInitState {
}
};

} // namespace aggregate
} // namespace internal
} // namespace compute
} // namespace arrow
12 changes: 6 additions & 6 deletions cpp/src/arrow/compute/kernels/aggregate_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

namespace arrow {
namespace compute {
namespace internal {

// Find the largest compatible primitive type for a primitive type.
template <typename I, typename Enable = void>
Expand Down Expand Up @@ -110,10 +111,6 @@ void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
ScalarAggregateFinalize finalize, ScalarAggregateFunction* func,
SimdLevel::type simd_level = SimdLevel::NONE);

namespace detail {

using arrow::internal::VisitSetBitRunsVoid;

// SumArray must be parameterized with the SIMD level since it's called both from
// translation units with and without vectorization. Normally it gets inlined but
// if not, without the parameter, we'll have multiple definitions of the same
Expand All @@ -125,6 +122,8 @@ template <typename ValueType, typename SumType, SimdLevel::type SimdLevel,
typename ValueFunc>
enable_if_t<std::is_floating_point<SumType>::value, SumType> SumArray(
const ArrayData& data, ValueFunc&& func) {
using arrow::internal::VisitSetBitRunsVoid;

const int64_t data_size = data.length - data.GetNullCount();
if (data_size == 0) {
return 0;
Expand Down Expand Up @@ -200,6 +199,8 @@ template <typename ValueType, typename SumType, SimdLevel::type SimdLevel,
typename ValueFunc>
enable_if_t<!std::is_floating_point<SumType>::value, SumType> SumArray(
const ArrayData& data, ValueFunc&& func) {
using arrow::internal::VisitSetBitRunsVoid;

SumType sum = 0;
const ValueType* values = data.GetValues<ValueType>(1);
VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
Expand All @@ -217,7 +218,6 @@ SumType SumArray(const ArrayData& data) {
data, [](ValueType v) { return static_cast<SumType>(v); });
}

} // namespace detail

} // namespace internal
} // namespace compute
} // namespace arrow
145 changes: 145 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <gtest/gtest.h>
Expand Down Expand Up @@ -53,6 +54,8 @@ using internal::checked_pointer_cast;

namespace compute {

using internal::FindAccumulatorType;

//
// Sum
//
Expand Down Expand Up @@ -873,6 +876,148 @@ TYPED_TEST(TestRandomNumericCountKernel, RandomArrayCount) {
}
}

//
// Count Distinct
//

// Fixture for the "count_distinct" aggregate. Its helpers invoke the function
// once per CountOptions mode (ONLY_VALID, ONLY_NULL, ALL) and compare the
// result against an expected Int64 scalar.
class TestCountDistinctKernel : public ::testing::Test {
 protected:
  // Wraps an expected count into a scalar Datum.
  Datum Expected(int64_t value) { return MakeScalar(value); }

  // Array/Datum check. `expected_all` is the distinct count under the ALL
  // mode; when `has_nulls` is true exactly one of those distinct entries is
  // taken to be the null, so ONLY_VALID expects one fewer and ONLY_NULL
  // expects one.
  void Check(Datum input, int64_t expected_all, bool has_nulls = true) {
    int64_t expected_valid = expected_all;
    int64_t expected_null = 0;
    if (has_nulls) {
      expected_valid = expected_all - 1;
      expected_null = 1;
    }
    CheckScalar("count_distinct", {input}, Expected(expected_valid), &only_valid);
    CheckScalar("count_distinct", {input}, Expected(expected_null), &only_null);
    CheckScalar("count_distinct", {input}, Expected(expected_all), &all);
  }

  // Convenience overload: builds the array from JSON, then delegates to the
  // Datum overload.
  void Check(const std::shared_ptr<DataType>& type, util::string_view json,
             int64_t expected_all, bool has_nulls = true) {
    Check(ArrayFromJSON(type, json), expected_all, has_nulls);
  }

  // Scalar check: parses one scalar from JSON and verifies the counts for
  // that scalar, then for a null scalar of the same type.
  void Check(const std::shared_ptr<DataType>& type, util::string_view json) {
    auto expect_zero = ResultWith(Expected(0));
    auto expect_one = ResultWith(Expected(1));

    // A scalar parsed from `json`: counted by ONLY_VALID and ALL.
    auto valid_input = ScalarFromJSON(type, json);
    EXPECT_THAT(CallFunction("count_distinct", {valid_input}, &only_valid), expect_one);
    EXPECT_THAT(CallFunction("count_distinct", {valid_input}, &only_null), expect_zero);
    EXPECT_THAT(CallFunction("count_distinct", {valid_input}, &all), expect_one);

    // A null scalar of the same type: counted by ONLY_NULL and ALL.
    auto null_input = MakeNullScalar(valid_input->type);
    EXPECT_THAT(CallFunction("count_distinct", {null_input}, &only_valid), expect_zero);
    EXPECT_THAT(CallFunction("count_distinct", {null_input}, &only_null), expect_one);
    EXPECT_THAT(CallFunction("count_distinct", {null_input}, &all), expect_one);
  }

  // One options object per counting mode, shared by all checks.
  CountOptions only_valid{CountOptions::ONLY_VALID};
  CountOptions only_null{CountOptions::ONLY_NULL};
  CountOptions all{CountOptions::ALL};
};

// Runs count_distinct over small JSON arrays for each supported type family.
// The integer passed to Check() is the distinct count under the ALL mode
// (nulls collapse into a single distinct entry); Check() derives the
// ONLY_VALID / ONLY_NULL expectations from it.
TEST_F(TestCountDistinctKernel, AllArrayTypesWithNulls) {
  // Boolean
  Check(boolean(), "[]", 0, /*has_nulls=*/false);
  Check(boolean(), "[true, null, false, null, false, true]", 3);
  // Number
  // Bind by const reference: the elements are shared_ptrs, and copying one
  // per iteration costs an atomic refcount round-trip for no benefit.
  for (const auto& ty : NumericTypes()) {
    Check(ty, "[1, 1, null, 2, 5, 8, 9, 9, null, 10, 6, 6]", 8);
    Check(ty, "[1, 1, 8, 2, 5, 8, 9, 9, 10, 10, 6, 6]", 7, /*has_nulls=*/false);
  }
  // Date
  Check(date32(), "[0, 11016, 0, null, 14241, 14241, null]", 4);
  Check(date64(), "[0, null, 0, null, 0, 0, 1262217600000]", 3);
  // Time
  Check(time32(TimeUnit::SECOND), "[0, 11, 0, null, 14, 14, null]", 4);
  Check(time32(TimeUnit::MILLI), "[0, 11000, 0, null, 11000, 11000]", 3);
  Check(time64(TimeUnit::MICRO), "[84203999999, 0, null, 84203999999, 0]", 3);
  Check(time64(TimeUnit::NANO), "[11715003000000, 0, null, 0, 0]", 3);
  // Timestamp & Duration
  for (auto u : TimeUnit::values()) {
    Check(duration(u), "[123456789, null, 987654321, 123456789, null]", 3);
    Check(duration(u), "[123456789, 987654321, 123456789, 123456789]", 2,
          /*has_nulls=*/false);
    auto ts = R"(["2009-12-31T04:20:20", "2020-01-01", null, "2009-12-31T04:20:20"])";
    Check(timestamp(u), ts, 3);
    Check(timestamp(u, "Pacific/Marquesas"), ts, 3);
  }
  // Interval
  Check(month_interval(), "[9012, 5678, null, 9012, 5678, null, 9012]", 3);
  Check(day_time_interval(), "[[0, 1], [0, 1], null, [0, 1], [1234, 5678]]", 3);
  Check(month_day_nano_interval(), "[[0, 1, 2], [0, 1, 2], null, [0, 1, 2]]", 2);
  // Binary & String & Fixed binary
  auto samples = R"([null, "abc", null, "abc", "abc", "cba", "bca", "cba", null])";
  Check(binary(), samples, 4);
  Check(large_binary(), samples, 4);
  Check(utf8(), samples, 4);
  Check(large_utf8(), samples, 4);
  Check(fixed_size_binary(3), samples, 4);
  // Decimal
  samples = R"(["12345.679", "98765.421", null, "12345.679", "98765.421"])";
  Check(decimal128(21, 3), samples, 3);
  Check(decimal256(13, 3), samples, 3);
}

// Runs count_distinct on a single scalar of each supported type family. The
// two-argument Check() overload verifies both the parsed scalar and a null
// scalar of the same type under every CountOptions mode.
TEST_F(TestCountDistinctKernel, AllScalarTypesWithNulls) {
  // Boolean
  Check(boolean(), "true");
  // Number
  // Bind by const reference: the elements are shared_ptrs, and copying one
  // per iteration costs an atomic refcount round-trip for no benefit.
  for (const auto& ty : NumericTypes()) {
    Check(ty, "91");
  }
  // Date
  Check(date32(), "11016");
  Check(date64(), "1262217600000");
  // Time
  Check(time32(TimeUnit::SECOND), "14");
  Check(time32(TimeUnit::MILLI), "11000");
  Check(time64(TimeUnit::MICRO), "84203999999");
  Check(time64(TimeUnit::NANO), "11715003000000");
  // Timestamp & Duration
  for (auto u : TimeUnit::values()) {
    Check(duration(u), "987654321");
    Check(duration(u), "123456789");
    auto ts = R"("2009-12-31T04:20:20")";
    Check(timestamp(u), ts);
    Check(timestamp(u, "Pacific/Marquesas"), ts);
  }
  // Interval
  Check(month_interval(), "5678");
  Check(day_time_interval(), "[1234, 5678]");
  Check(month_day_nano_interval(), "[0, 1, 2]");
  // Binary & String & Fixed binary
  auto sample = R"("cba")";
  Check(binary(), sample);
  Check(large_binary(), sample);
  Check(utf8(), sample);
  Check(large_utf8(), sample);
  Check(fixed_size_binary(3), sample);
  // Decimal
  sample = R"("98765.421")";
  Check(decimal128(21, 3), sample);
  Check(decimal256(13, 3), sample);
}

// Cross-checks count_distinct against std::unordered_set on generated data:
// every value of a generated UInt32 array is inserted into `memo`, the
// first-seen values are collected into a new array, and the kernel's answer
// for that array must equal memo.size().
TEST_F(TestCountDistinctKernel, Random) {
  UInt32Builder builder;
  std::unordered_set<uint32_t> memo;
  auto visit_null = []() { return Status::OK(); };
  // Append a value to the builder only the first time it is seen.
  auto visit_value = [&](uint32_t arg) {
    const bool inserted = memo.insert(arg).second;
    if (inserted) {
      return builder.Append(arg);
    }
    return Status::OK();
  };
  auto rand = random::RandomArrayGenerator(0x1205643);
  // NOTE(review): arguments presumably are (length, min, max, null
  // probability) -- confirm against RandomArrayGenerator::Numeric.
  auto arr = rand.Numeric<UInt32Type>(1024, 0, 100, 0.0)->data();
  // Fail the test if visiting (and therefore any builder.Append) reports an
  // error, instead of dropping the returned Status on the floor.
  ASSERT_OK(VisitArrayDataInline<UInt32Type>(*arr, visit_value, visit_null));
  // Report a Finish() failure as a test failure rather than aborting the
  // whole test binary via ValueOrDie().
  ASSERT_OK_AND_ASSIGN(auto input, builder.Finish());
  Check(input, static_cast<int64_t>(memo.size()), /*has_nulls=*/false);
}

//
// Mean
//
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/arrow/compute/kernels/aggregate_var_std.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@ struct VarStdState {

using SumType =
typename std::conditional<is_floating_type<T>::value, double, int128_t>::type;
SumType sum =
arrow::compute::detail::SumArray<CType, SumType, SimdLevel::NONE>(*array.data());
SumType sum = SumArray<CType, SumType, SimdLevel::NONE>(*array.data());

const double mean = static_cast<double>(sum) / count;
const double m2 = arrow::compute::detail::SumArray<CType, double, SimdLevel::NONE>(
*array.data(), [mean](CType value) {
const double m2 =
SumArray<CType, double, SimdLevel::NONE>(*array.data(), [mean](CType value) {
const double v = static_cast<double>(value);
return (v - mean) * (v - mean);
});
Expand Down
2 changes: 2 additions & 0 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ the input to a single output value.
+--------------------+-------+------------------+------------------------+----------------------------------+-------+
| count | Unary | Any | Scalar Int64 | :struct:`CountOptions` | \(2) |
+--------------------+-------+------------------+------------------------+----------------------------------+-------+
| count_distinct | Unary | Non-nested types | Scalar Int64 | :struct:`CountOptions` | \(2) |
+--------------------+-------+------------------+------------------------+----------------------------------+-------+
| index | Unary | Any | Scalar Int64 | :struct:`IndexOptions` | |
+--------------------+-------+------------------+------------------------+----------------------------------+-------+
| max | Unary | Non-nested types | Scalar Input type | :struct:`ScalarAggregateOptions` | |
Expand Down
1 change: 1 addition & 0 deletions docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Aggregations
any
approximate_median
count
count_distinct
index
max
mean
Expand Down
17 changes: 17 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2220,3 +2220,20 @@ def test_list_element():
result = pa.compute.list_element(lists, index)
expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type)
assert result.equals(expected)


def test_count_distinct():
    """count_distinct over 100 distinct timestamps returns Int64 scalar 100."""
    # Use a fixed calendar date instead of datetime.now(): replacing the year
    # of "today" raises ValueError when the test runs on Feb 29 (most target
    # years are not leap years), and a wall-clock seed makes the input differ
    # on every run.
    samples = [datetime(y, 1, 1) for y in range(1992, 2092)]
    arr = pa.array(samples, pa.timestamp("ns"))
    result = pa.compute.count_distinct(arr)
    expected = pa.scalar(len(samples), type=pa.int64())
    assert result.equals(expected)


def test_count_distinct_options():
    """count_distinct honours the ``mode`` option; the default call yields 3."""
    arr = pa.array([1, 2, 3, None, None])

    # Call without an explicit mode.
    assert pc.count_distinct(arr).as_py() == 3

    # Explicit modes, checked in the same order as before.
    expected_by_mode = [
        ('only_valid', 3),
        ('only_null', 1),
        ('all', 4),
    ]
    for mode, expected in expected_by_mode:
        assert pc.count_distinct(arr, mode=mode).as_py() == expected
2 changes: 1 addition & 1 deletion r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return out;
}

if (func_name == "hash_count_distinct") {
if (func_name == "count_distinct" || func_name == "hash_count_distinct") {
using Options = arrow::compute::CountOptions;
auto out = std::make_shared<Options>(Options::Defaults());
out->mode =
Expand Down
Loading