diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index c5e0e6fd6e9..b545d8bcc10 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -23,7 +23,9 @@ #include "arrow/util/cpu_info.h" #include "arrow/util/hashing.h" -#include +// Include templated definitions for aggregate kernels that must be compiled here +// with the SIMD level configured for this compilation unit in the build. +#include "arrow/compute/kernels/aggregate_basic.inc.cc" // NOLINT(build/include) namespace arrow { namespace compute { @@ -276,11 +278,6 @@ struct SumImplDefault : public SumImpl { using SumImpl::SumImpl; }; -template -struct MeanImplDefault : public MeanImpl { - using MeanImpl::MeanImpl; -}; - Result> SumInit(KernelContext* ctx, const KernelInitArgs& args) { SumLikeInit visitor( @@ -289,6 +286,14 @@ Result> SumInit(KernelContext* ctx, return visitor.Create(); } +// ---------------------------------------------------------------------- +// Mean implementation + +template +struct MeanImplDefault : public MeanImpl { + using MeanImpl::MeanImpl; +}; + Result> MeanInit(KernelContext* ctx, const KernelInitArgs& args) { MeanKernelInit visitor( @@ -482,8 +487,8 @@ void AddFirstOrLastAggKernel(ScalarAggregateFunction* func, // ---------------------------------------------------------------------- // MinMax implementation -Result> MinMaxInit(KernelContext* ctx, - const KernelInitArgs& args) { +Result> MinMaxInitDefault(KernelContext* ctx, + const KernelInitArgs& args) { ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, args.kernel->signature->out_type().Resolve(ctx, args.inputs)); MinMaxInitState visitor( @@ -1114,14 +1119,14 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) { // Add min max function func = std::make_shared("min_max", Arity::Unary(), min_max_doc, &default_scalar_aggregate_options); - AddMinMaxKernels(MinMaxInit, {null(), boolean()}, func.get()); - 
AddMinMaxKernels(MinMaxInit, NumericTypes(), func.get()); - AddMinMaxKernels(MinMaxInit, TemporalTypes(), func.get()); - AddMinMaxKernels(MinMaxInit, BaseBinaryTypes(), func.get()); - AddMinMaxKernel(MinMaxInit, Type::FIXED_SIZE_BINARY, func.get()); - AddMinMaxKernel(MinMaxInit, Type::INTERVAL_MONTHS, func.get()); - AddMinMaxKernel(MinMaxInit, Type::DECIMAL128, func.get()); - AddMinMaxKernel(MinMaxInit, Type::DECIMAL256, func.get()); + AddMinMaxKernels(MinMaxInitDefault, {null(), boolean()}, func.get()); + AddMinMaxKernels(MinMaxInitDefault, NumericTypes(), func.get()); + AddMinMaxKernels(MinMaxInitDefault, TemporalTypes(), func.get()); + AddMinMaxKernels(MinMaxInitDefault, BaseBinaryTypes(), func.get()); + AddMinMaxKernel(MinMaxInitDefault, Type::FIXED_SIZE_BINARY, func.get()); + AddMinMaxKernel(MinMaxInitDefault, Type::INTERVAL_MONTHS, func.get()); + AddMinMaxKernel(MinMaxInitDefault, Type::DECIMAL128, func.get()); + AddMinMaxKernel(MinMaxInitDefault, Type::DECIMAL256, func.get()); // Add the SIMD variants for min max #if defined(ARROW_HAVE_RUNTIME_AVX2) if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) { diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc new file mode 100644 index 00000000000..f2151e0a9e0 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc @@ -0,0 +1,1025 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// .inc.cc file to be included in compilation unit where kernels are meant to be +// compiled auto-vectorized by the compiler with different SIMD levels passed +// as compiler flags. +// +// It lists its own includes below; these may overlap with the includes of the +// compilation unit that includes this .inc.cc file. + +#include +#include +#include +#include +#include + +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/kernels/aggregate_internal.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/align_util.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/decimal.h" + +namespace arrow::compute::internal { +namespace { + +// ---------------------------------------------------------------------- +// Sum implementation + +template ::Type> +struct SumImpl : public ScalarAggregator { + using ThisType = SumImpl; + using CType = typename TypeTraits::CType; + using SumType = ResultType; + using SumCType = typename TypeTraits::CType; + using OutputType = typename TypeTraits::ScalarType; + + SumImpl(std::shared_ptr out_type, ScalarAggregateOptions options_) + : out_type(std::move(out_type)), options(std::move(options_)) {} + + Status Consume(KernelContext*, const ExecSpan& batch) override { + if (batch[0].is_array()) { + const ArraySpan& data = batch[0].array; + this->count += data.length - data.GetNullCount(); + this->nulls_observed = this->nulls_observed || data.GetNullCount(); + + if (!options.skip_nulls && 
this->nulls_observed) { + // Short-circuit + return Status::OK(); + } + + if (is_boolean_type::value) { + this->sum += GetTrueCount(data); + } else { + this->sum += SumArray(data); + } + } else { + const Scalar& data = *batch[0].scalar; + this->count += data.is_valid * batch.length; + this->nulls_observed = this->nulls_observed || !data.is_valid; + if (data.is_valid) { + this->sum += internal::UnboxScalar::Unbox(data) * batch.length; + } + } + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + const auto& other = checked_cast(src); + this->count += other.count; + this->sum += other.sum; + this->nulls_observed = this->nulls_observed || other.nulls_observed; + return Status::OK(); + } + + Status Finalize(KernelContext*, Datum* out) override { + if ((!options.skip_nulls && this->nulls_observed) || + (this->count < options.min_count)) { + out->value = std::make_shared(out_type); + } else { + out->value = std::make_shared(this->sum, out_type); + } + return Status::OK(); + } + + size_t count = 0; + bool nulls_observed = false; + SumCType sum = 0; + std::shared_ptr out_type; + ScalarAggregateOptions options; +}; + +template +struct NullImpl : public ScalarAggregator { + using ScalarType = typename TypeTraits::ScalarType; + + explicit NullImpl(const ScalarAggregateOptions& options_) : options(options_) {} + + Status Consume(KernelContext*, const ExecSpan& batch) override { + if (batch[0].is_scalar() || batch[0].array.GetNullCount() > 0) { + // If the batch is a scalar or an array with elements, set is_empty to false + is_empty = false; + } + return Status::OK(); + } + + Status MergeFrom(KernelContext*, KernelState&& src) override { + const auto& other = checked_cast(src); + this->is_empty &= other.is_empty; + return Status::OK(); + } + + Status Finalize(KernelContext*, Datum* out) override { + if ((options.skip_nulls || this->is_empty) && options.min_count == 0) { + // Return 0 if the remaining data is empty + out->value = 
output_empty(); + } else { + out->value = MakeNullScalar(TypeTraits::type_singleton()); + } + return Status::OK(); + } + + virtual std::shared_ptr output_empty() = 0; + + bool is_empty = true; + ScalarAggregateOptions options; +}; + +template +struct NullSumImpl : public NullImpl { + using ScalarType = typename TypeTraits::ScalarType; + + explicit NullSumImpl(const ScalarAggregateOptions& options_) + : NullImpl(options_) {} + + std::shared_ptr output_empty() override { + return std::make_shared(0); + } +}; + +template