diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
index fd705ff973b..8a091f2355d 100644
--- a/cpp/src/arrow/compute/cast.cc
+++ b/cpp/src/arrow/compute/cast.cc
@@ -124,26 +124,15 @@ void RegisterScalarCast(FunctionRegistry* registry) {
 }

 }  // namespace internal

-struct CastFunction::CastFunctionImpl {
-  Type::type out_type;
-  std::unordered_set<int> in_types;
-};
-
-CastFunction::CastFunction(std::string name, Type::type out_type)
-    : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr) {
-  impl_.reset(new CastFunctionImpl());
-  impl_->out_type = out_type;
-}
-
-CastFunction::~CastFunction() = default;
-
-Type::type CastFunction::out_type_id() const { return impl_->out_type; }
+CastFunction::CastFunction(std::string name, Type::type out_type_id)
+    : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr),
+      out_type_id_(out_type_id) {}

 Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) {
   // We use the same KernelInit for every cast
   kernel.init = internal::CastState::Init;
   RETURN_NOT_OK(ScalarFunction::AddKernel(kernel));
-  impl_->in_types.insert(static_cast<int>(in_type_id));
+  in_type_ids_.push_back(in_type_id);
   return Status::OK();
 }
@@ -159,19 +148,10 @@ Status CastFunction::AddKernel(Type::type in_type_id, std::vector<InputType> in_
   return AddKernel(in_type_id, std::move(kernel));
 }

-bool CastFunction::CanCastTo(const DataType& out_type) const {
-  return impl_->in_types.find(static_cast<int>(out_type.id())) != impl_->in_types.end();
-}
-
 Result<const Kernel*> CastFunction::DispatchExact(
     const std::vector<ValueDescr>& values) const {
-  const int passed_num_args = static_cast<int>(values.size());
+  RETURN_NOT_OK(CheckArity(values));

-  // Validate arity
-  if (passed_num_args != 1) {
-    return Status::Invalid("Cast functions accept 1 argument but passed ",
-                           passed_num_args);
-  }
   std::vector<const ScalarKernel*> candidate_kernels;
   for (const auto& kernel : kernels_) {
     if (kernel.signature->MatchesInputs(values)) {
@@ -181,25 +161,28 @@ Result<const Kernel*> CastFunction::DispatchExact(
   if (candidate_kernels.size() == 0) {
     return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(),
-                                  " to ", ToTypeName(impl_->out_type), " using function ",
+                                  " to ", ToTypeName(out_type_id_), " using function ",
                                   this->name());
-  } else if (candidate_kernels.size() == 1) {
+  }
+
+  if (candidate_kernels.size() == 1) {
     // One match, return it
     return candidate_kernels[0];
-  } else {
-    // Now we are in a casting scenario where we may have both a EXACT_TYPE and
-    // a SAME_TYPE_ID. So we will see if there is an exact match among the
-    // candidate kernels and if not we will just return the first one
-    for (auto kernel : candidate_kernels) {
-      const InputType& arg0 = kernel->signature->in_types()[0];
-      if (arg0.kind() == InputType::EXACT_TYPE) {
-        // Bingo. Return it
-        return kernel;
-      }
+  }
+
+  // Now we are in a casting scenario where we may have both a EXACT_TYPE and
+  // a SAME_TYPE_ID. So we will see if there is an exact match among the
+  // candidate kernels and if not we will just return the first one
+  for (auto kernel : candidate_kernels) {
+    const InputType& arg0 = kernel->signature->in_types()[0];
+    if (arg0.kind() == InputType::EXACT_TYPE) {
+      // Bingo. Return it
+      return kernel;
     }
-    // We didn't find an exact match. So just return some kernel that matches
-    return candidate_kernels[0];
   }
+
+  // We didn't find an exact match.
So just return some kernel that matches + return candidate_kernels[0]; } Result Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { @@ -225,13 +208,37 @@ Result> GetCastFunction( } bool CanCast(const DataType& from_type, const DataType& to_type) { - // TODO internal::EnsureInitCastTable(); - auto it = internal::g_cast_table.find(static_cast(from_type.id())); + auto it = internal::g_cast_table.find(static_cast(to_type.id())); if (it == internal::g_cast_table.end()) { return false; } - return it->second->CanCastTo(to_type); + + const CastFunction* function = it->second.get(); + DCHECK_EQ(function->out_type_id(), to_type.id()); + + for (auto from_id : function->in_type_ids()) { + // XXX should probably check the output type as well + if (from_type.id() == from_id) return true; + } + + return false; +} + +Result> Cast(std::vector datums, std::vector descrs, + ExecContext* ctx) { + for (size_t i = 0; i != datums.size(); ++i) { + if (descrs[i] != datums[i].descr()) { + if (descrs[i].shape != datums[i].shape()) { + return Status::NotImplemented("casting between Datum shapes"); + } + + ARROW_ASSIGN_OR_RAISE(datums[i], + Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx)); + } + } + + return datums; } } // namespace compute diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 0b9d9caf882..818f2ef9182 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -82,10 +82,10 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions { // the same execution machinery class CastFunction : public ScalarFunction { public: - CastFunction(std::string name, Type::type out_type); - ~CastFunction() override; + CastFunction(std::string name, Type::type out_type_id); - Type::type out_type_id() const; + Type::type out_type_id() const { return out_type_id_; } + const std::vector& in_type_ids() const { return in_type_ids_; } Status AddKernel(Type::type in_type_id, std::vector in_types, OutputType out_type, ArrayKernelExec exec, @@ -96,14 +96,12 @@ class CastFunction : public ScalarFunction { // function to CastInit Status AddKernel(Type::type in_type_id, ScalarKernel kernel); - bool CanCastTo(const DataType& out_type) const; - Result DispatchExact( const std::vector& values) const override; private: - struct CastFunctionImpl; - std::unique_ptr impl_; + std::vector in_type_ids_; + const Type::type out_type_id_; }; ARROW_EXPORT @@ -157,5 +155,17 @@ Result Cast(const Datum& value, std::shared_ptr to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); +/// \brief Cast several values simultaneously. Safe cast options are used. 
+/// \param[in] values datums to cast +/// \param[in] descrs ValueDescrs to cast to +/// \param[in] ctx the function execution context, optional +/// \return the resulting datums +/// +/// \since 4.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> Cast(std::vector values, std::vector descrs, + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 7a868853db7..70d7d998e9c 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -21,62 +21,66 @@ #include #include +#include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec_internal.h" +#include "arrow/compute/kernels/common.h" #include "arrow/datum.h" #include "arrow/util/cpu_info.h" namespace arrow { + +using internal::checked_cast; + namespace compute { static const FunctionDoc kEmptyFunctionDoc{}; const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; } -Status Function::CheckArity(int passed_num_args) const { - if (arity_.is_varargs && passed_num_args < arity_.num_args) { - return Status::Invalid("VarArgs function needs at least ", arity_.num_args, - " arguments but kernel accepts only ", passed_num_args); - } else if (!arity_.is_varargs && passed_num_args != arity_.num_args) { - return Status::Invalid("Function accepts ", arity_.num_args, - " arguments but kernel accepts ", passed_num_args); +static Status CheckArityImpl(const Function* function, int passed_num_args, + const char* passed_num_args_label) { + if (function->arity().is_varargs && passed_num_args < function->arity().num_args) { + return Status::Invalid("VarArgs function ", function->name(), " needs at least ", + function->arity().num_args, " arguments but ", + passed_num_args_label, " only ", passed_num_args); + } + + if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) { + return Status::Invalid("Function ", function->name(), " accepts ", + function->arity().num_args, " arguments but ", + passed_num_args_label, " ", passed_num_args); } + return Status::OK(); } -template -std::string FormatArgTypes(const std::vector& descrs) { - std::stringstream ss; - ss << "("; - for (size_t i = 0; i < descrs.size(); ++i) { - if (i > 0) { - ss << ", "; - } - ss << descrs[i].ToString(); - } - ss << ")"; - return ss.str(); +Status Function::CheckArity(const std::vector& in_types) const { + return CheckArityImpl(this, static_cast(in_types.size()), "kernel accepts"); +} + +Status Function::CheckArity(const std::vector& descrs) const { + return CheckArityImpl(this, static_cast(descrs.size()), + "attempted to look up kernel(s) with"); +} + +namespace detail { + +Status NoMatchingKernel(const Function* func, const std::vector& descrs) { + return Status::NotImplemented("Function ", func->name(), + " has no kernel matching input types ", + ValueDescr::ToString(descrs)); } -template -Result DispatchExactImpl(const Function& func, - const std::vector& kernels, - const std::vector& values) { - const int passed_num_args = static_cast(values.size()); - const KernelType* kernel_matches[SimdLevel::MAX] = {NULL}; +template +const KernelType* DispatchExactImpl(const std::vector& kernels, + const std::vector& values) { + const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr}; // Validate arity - const Arity arity = func.arity(); - if (arity.is_varargs && passed_num_args < arity.num_args) { - return Status::Invalid("VarArgs function needs at least ", arity.num_args, - " 
arguments but passed only ", passed_num_args); - } else if (!arity.is_varargs && passed_num_args != arity.num_args) { - return Status::Invalid("Function accepts ", arity.num_args, " arguments but passed ", - passed_num_args); - } for (const auto& kernel : kernels) { - if (kernel.signature->MatchesInputs(values)) { - kernel_matches[kernel.simd_level] = &kernel; + if (kernel->signature->MatchesInputs(values)) { + kernel_matches[kernel->simd_level] = kernel; } } @@ -102,9 +106,47 @@ Result DispatchExactImpl(const Function& func, return kernel_matches[SimdLevel::NONE]; } - return Status::NotImplemented("Function ", func.name(), - " has no kernel matching input types ", - FormatArgTypes(values)); + return nullptr; +} + +const Kernel* DispatchExactImpl(const Function* func, + const std::vector& values) { + if (func->kind() == Function::SCALAR) { + return DispatchExactImpl(checked_cast(func)->kernels(), + values); + } + + if (func->kind() == Function::VECTOR) { + return DispatchExactImpl(checked_cast(func)->kernels(), + values); + } + + if (func->kind() == Function::SCALAR_AGGREGATE) { + return DispatchExactImpl( + checked_cast(func)->kernels(), values); + } + + return nullptr; +} + +} // namespace detail + +Result Function::DispatchExact( + const std::vector& values) const { + if (kind_ == Function::META) { + return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); + } + RETURN_NOT_OK(CheckArity(values)); + + if (auto kernel = detail::DispatchExactImpl(this, values)) { + return kernel; + } + return detail::NoMatchingKernel(this, values); +} + +Result Function::DispatchBest(std::vector* values) const { + // TODO(ARROW-11508) permit generic conversions here + return DispatchExact(*values); } Result Function::Execute(const std::vector& args, @@ -116,6 +158,7 @@ Result Function::Execute(const std::vector& args, ExecContext default_ctx; return Execute(args, options, &default_ctx); } + // type-check Datum arguments here. 
Really we'd like to avoid this as much as // possible RETURN_NOT_OK(detail::CheckAllValues(args)); @@ -124,7 +167,9 @@ Result Function::Execute(const std::vector& args, inputs[i] = args[i].descr(); } - ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchExact(inputs)); + ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs)); + ARROW_ASSIGN_OR_RAISE(auto implicitly_cast_args, Cast(args, inputs, ctx)); + std::unique_ptr state; KernelContext kernel_ctx{ctx}; @@ -145,8 +190,8 @@ Result Function::Execute(const std::vector& args, RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); auto listener = std::make_shared(); - RETURN_NOT_OK(executor->Execute(args, listener.get())); - return executor->WrapResults(args, listener->values()); + RETURN_NOT_OK(executor->Execute(implicitly_cast_args, listener.get())); + return executor->WrapResults(implicitly_cast_args, listener->values()); } Status Function::Validate() const { @@ -167,7 +212,7 @@ Status Function::Validate() const { Status ScalarFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(static_cast(in_types.size()))); + RETURN_NOT_OK(CheckArity(in_types)); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -179,7 +224,7 @@ Status ScalarFunction::AddKernel(std::vector in_types, OutputType out } Status ScalarFunction::AddKernel(ScalarKernel kernel) { - RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -187,14 +232,9 @@ Status ScalarFunction::AddKernel(ScalarKernel kernel) { return Status::OK(); } -Result ScalarFunction::DispatchExact( - const std::vector& values) const { - return DispatchExactImpl(*this, kernels_, values); -} - Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(static_cast(in_types.size()))); + RETURN_NOT_OK(CheckArity(in_types)); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -206,7 +246,7 @@ Status VectorFunction::AddKernel(std::vector in_types, OutputType out } Status VectorFunction::AddKernel(VectorKernel kernel) { - RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -214,13 +254,8 @@ Status VectorFunction::AddKernel(VectorKernel kernel) { return Status::OK(); } -Result VectorFunction::DispatchExact( - const std::vector& values) const { - return DispatchExactImpl(*this, kernels_, values); -} - Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { - RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -228,15 +263,12 @@ Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { return Status::OK(); } -Result ScalarAggregateFunction::DispatchExact( - const std::vector& 
values) const { - return DispatchExactImpl(*this, kernels_, values); -} - Result MetaFunction::Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { - RETURN_NOT_OK(CheckArity(static_cast(args.size()))); + RETURN_NOT_OK( + CheckArityImpl(this, static_cast(args.size()), "attempted to Execute with")); + if (options == nullptr) { options = default_options(); } diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index e8e732027c9..af5d81a30ec 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -162,7 +162,15 @@ class ARROW_EXPORT Function { /// /// NB: This function is overridden in CastFunction. virtual Result DispatchExact( - const std::vector& values) const = 0; + const std::vector& values) const; + + /// \brief Return a best-match kernel that can execute the function given the argument + /// types, after implicit casts are applied. + /// + /// \param[in,out] values Argument types. An element may be modified to indicate that + /// the returned kernel only approximately matches the input value descriptors; callers + /// are responsible for casting inputs to the type and shape required by the kernel. + virtual Result DispatchBest(std::vector* values) const; /// \brief Execute the function eagerly with the passed input arguments with /// kernel dispatch, batch iteration, and memory allocation details taken @@ -191,7 +199,8 @@ class ARROW_EXPORT Function { doc_(doc ? doc : &FunctionDoc::Empty()), default_options_(default_options) {} - Status CheckArity(int passed_num_args) const; + Status CheckArity(const std::vector&) const; + Status CheckArity(const std::vector&) const; std::string name_; Function::Kind kind_; @@ -224,6 +233,14 @@ class FunctionImpl : public Function { std::vector kernels_; }; +/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. +ARROW_EXPORT +const Kernel* DispatchExactImpl(const Function* func, const std::vector&); + +/// \brief Return an error message if no Kernel is found. +ARROW_EXPORT +Status NoMatchingKernel(const Function* func, const std::vector&); + } // namespace detail /// \brief A function that executes elementwise operations on arrays or @@ -249,9 +266,6 @@ class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl { /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(ScalarKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; }; /// \brief A function that executes general array operations that may yield @@ -276,9 +290,6 @@ class ARROW_EXPORT VectorFunction : public detail::FunctionImpl { /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(VectorKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; }; class ARROW_EXPORT ScalarAggregateFunction @@ -294,9 +305,6 @@ class ARROW_EXPORT ScalarAggregateFunction /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(ScalarAggregateKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; }; /// \brief A function that dispatches to other functions. 
Must implement @@ -311,10 +319,6 @@ class ARROW_EXPORT MetaFunction : public Function { Result Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const override; - Result DispatchExact(const std::vector&) const override { - return Status::NotImplemented("DispatchExact for a MetaFunction's Kernels"); - } - protected: virtual Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 67cb5df7908..c8f9cacfb34 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -566,7 +566,7 @@ struct Kernel { /// output array values (as opposed to scalar values in the case of aggregate /// functions). struct ArrayKernel : public Kernel { - ArrayKernel() {} + ArrayKernel() = default; ArrayKernel(std::shared_ptr sig, ArrayKernelExec exec, KernelInit init = NULLPTR) @@ -614,7 +614,7 @@ using VectorFinalize = std::function*)>; /// (which have different defaults from ScalarKernel), and some other /// execution-related options. struct VectorKernel : public ArrayKernel { - VectorKernel() {} + VectorKernel() = default; VectorKernel(std::shared_ptr sig, ArrayKernelExec exec) : ArrayKernel(std::move(sig), exec) {} diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index a5941ea2200..b321ff3fc8b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -179,6 +179,140 @@ Result FirstType(KernelContext*, const std::vector& desc return descrs[0]; } +void EnsureDictionaryDecoded(std::vector* descrs) { + for (ValueDescr& descr : *descrs) { + if (descr.type->id() == Type::DICTIONARY) { + descr.type = checked_cast(*descr.type).value_type(); + } + } +} + +void ReplaceNullWithOtherType(std::vector* descrs) { + DCHECK_EQ(descrs->size(), 2); + + if (descrs->at(0).type->id() == Type::NA) { + descrs->at(0).type = descrs->at(1).type; + return; + } + + if (descrs->at(1).type->id() == Type::NA) { + descrs->at(1).type = descrs->at(0).type; + return; + } +} + +void ReplaceTypes(const std::shared_ptr& type, + std::vector* descrs) { + for (auto& descr : *descrs) { + descr.type = type; + } +} + +std::shared_ptr CommonNumeric(const std::vector& descrs) { + DCHECK(!descrs.empty()) << "tried to find CommonNumeric type of an empty set"; + + for (const auto& descr : descrs) { + auto id = descr.type->id(); + if (!is_floating(id) && !is_integer(id)) { + // a common numeric type is only possible if all types are numeric + return nullptr; + } + if (id == Type::HALF_FLOAT) { + // float16 arithmetic is not currently supported + return nullptr; + } + } + + for (const auto& descr : descrs) { + if (descr.type->id() == Type::DOUBLE) return float64(); + } + + for (const auto& descr : descrs) { + if (descr.type->id() == Type::FLOAT) return float32(); + } + + int max_width_signed = 0, max_width_unsigned = 0; + + for (const auto& descr : descrs) { + auto id = descr.type->id(); + auto max_width = is_signed_integer(id) ? 
&max_width_signed : &max_width_unsigned; + *max_width = std::max(bit_width(id), *max_width); + } + + if (max_width_signed == 0) { + if (max_width_unsigned >= 64) return uint64(); + if (max_width_unsigned == 32) return uint32(); + if (max_width_unsigned == 16) return uint16(); + DCHECK_EQ(max_width_unsigned, 8); + return int8(); + } + + if (max_width_signed <= max_width_unsigned) { + max_width_signed = static_cast(BitUtil::NextPower2(max_width_unsigned + 1)); + } + + if (max_width_signed >= 64) return int64(); + if (max_width_signed == 32) return int32(); + if (max_width_signed == 16) return int16(); + DCHECK_EQ(max_width_signed, 8); + return int8(); +} + +std::shared_ptr CommonTimestamp(const std::vector& descrs) { + TimeUnit::type finest_unit = TimeUnit::SECOND; + + for (const auto& descr : descrs) { + auto id = descr.type->id(); + // a common timestamp is only possible if all types are timestamp like + switch (id) { + case Type::DATE32: + case Type::DATE64: + continue; + case Type::TIMESTAMP: + finest_unit = + std::max(finest_unit, checked_cast(*descr.type).unit()); + continue; + default: + return nullptr; + } + } + + return timestamp(finest_unit); +} + +std::shared_ptr CommonBinary(const std::vector& descrs) { + bool all_utf8 = true, all_offset32 = true; + + for (const auto& descr : descrs) { + auto id = descr.type->id(); + // a common varbinary type is only possible if all types are binary like + switch (id) { + case Type::STRING: + continue; + case Type::BINARY: + all_utf8 = false; + continue; + case Type::LARGE_STRING: + all_offset32 = false; + continue; + case Type::LARGE_BINARY: + all_offset32 = false; + all_utf8 = false; + continue; + default: + return nullptr; + } + } + + if (all_utf8) { + if (all_offset32) return utf8(); + return large_utf8(); + } + + if (all_offset32) return binary(); + return large_binary(); +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index c3a6b4b9772..f39ffdcca11 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1186,6 +1186,24 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) { // END of kernel generator-dispatchers // ---------------------------------------------------------------------- +ARROW_EXPORT +void EnsureDictionaryDecoded(std::vector* descrs); + +ARROW_EXPORT +void ReplaceNullWithOtherType(std::vector* descrs); + +ARROW_EXPORT +void ReplaceTypes(const std::shared_ptr&, std::vector* descrs); + +ARROW_EXPORT +std::shared_ptr CommonNumeric(const std::vector& descrs); + +ARROW_EXPORT +std::shared_ptr CommonTimestamp(const std::vector& descrs); + +ARROW_EXPORT +std::shared_ptr CommonBinary(const std::vector& descrs); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index fc18da7cf13..7abaa1c1a59 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -264,10 +264,31 @@ ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) { } } +struct ArithmeticFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) 
return kernel; + + EnsureDictionaryDecoded(values); + ReplaceNullWithOtherType(values); + + if (auto type = CommonNumeric(*values)) { + ReplaceTypes(type, values); + } + + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + template std::shared_ptr MakeArithmeticFunction(std::string name, const FunctionDoc* doc) { - auto func = std::make_shared(name, Arity::Binary(), doc); + auto func = std::make_shared(name, Arity::Binary(), doc); for (const auto& ty : NumericTypes()) { auto exec = NumericEqualTypesBinary(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); @@ -280,7 +301,7 @@ std::shared_ptr MakeArithmeticFunction(std::string name, template std::shared_ptr MakeArithmeticFunctionNotNull(std::string name, const FunctionDoc* doc) { - auto func = std::make_shared(name, Arity::Binary(), doc); + auto func = std::make_shared(name, Arity::Binary(), doc); for (const auto& ty : NumericTypes()) { auto exec = NumericEqualTypesBinary(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index a19abe82873..4d4f14e1154 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -67,7 +67,7 @@ class TestBinaryArithmetic : public TestBase { using BinaryFunction = std::function(const Datum&, const Datum&, ArithmeticOptions, ExecContext*)>; - void SetUp() { options_.check_overflow = false; } + void SetUp() override { options_.check_overflow = false; } std::shared_ptr MakeNullScalar() { return arrow::MakeNullScalar(type_singleton()); @@ -637,5 +637,77 @@ TYPED_TEST(TestBinaryArithmeticFloating, Mul) { this->AssertBinop(Multiply, "[null, 2.0]", this->MakeNullScalar(), "[null, null]"); } +TEST(TestBinaryArithmetic, DispatchBest) { + for (std::string name : {"add", "subtract", "multiply", "divide"}) { + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); + CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); + + CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()}); + + CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()}); + CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()}); + + CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()}); + CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()}); + + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); + + CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, + {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, + {float64(), float64()}); + } + } +} + +TEST(TestBinaryArithmetic, AddWithImplicitCasts) { + CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + ArrayFromJSON(float64(), "[0.25, 0.5, 0.75, 1.0]"), + 
ArrayFromJSON(float64(), "[0.25, 1.5, 2.75, null]")); + + CheckScalarBinary("add", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(int64(), "[-13, 4, 21, null]")); + + CheckScalarBinary("add", + ArrayFromJSON(dictionary(int32(), int32()), "[8, 6, 3, null, 2]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7, 0]"), + ArrayFromJSON(int64(), "[11, 10, 8, null, 2]")); + + CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + std::make_shared(4), + ArrayFromJSON(int32(), "[null, null, null, null]")); + + CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(int64(), "[3, 5, 7, null]")); +} + +TEST(TestBinaryArithmetic, AddWithImplicitCastsUint64EdgeCase) { + // int64 is as wide as we can promote + CheckDispatchBest("add", {int8(), uint64()}, {int64(), int64()}); + + // this works sometimes + CheckScalarBinary("add", ArrayFromJSON(int8(), "[-1]"), ArrayFromJSON(uint64(), "[0]"), + ArrayFromJSON(int64(), "[-1]")); + + // ... but it can result in impossible implicit casts in the presence of uint64, since + // some uint64 values cannot be cast to int64: + ASSERT_RAISES(Invalid, + CallFunction("add", {ArrayFromJSON(int64(), "[-1]"), + ArrayFromJSON(uint64(), "[18446744073709551615]")})); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc b/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc index 07026db83be..e529d3791aa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc @@ -50,6 +50,7 @@ struct ParseBooleanString { std::vector> GetBooleanCasts() { auto func = std::make_shared("cast_boolean", Type::BOOL); AddCommonCasts(Type::BOOL, boolean(), func.get()); + AddZeroCopyCast(Type::BOOL, boolean(), boolean(), func.get()); for (const auto& ty : NumericTypes()) { ArrayKernelExec exec = diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index f8dde20e3aa..7221722d53a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -149,17 +149,13 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Dat // ---------------------------------------------------------------------- void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (out->is_scalar()) { - KERNEL_ASSIGN_OR_RAISE(*out, ctx, - batch[0].scalar_as().GetEncodedValue()); - return; - } + DCHECK(out->is_array()); DictionaryArray dict_arr(batch[0].array()); const CastOptions& options = checked_cast(*ctx->state()).options; const auto& dict_type = *dict_arr.dictionary()->type(); - if (!dict_type.Equals(options.to_type)) { + if (!dict_type.Equals(options.to_type) && !CanCast(dict_type, *options.to_type)) { ctx->SetStatus(Status::Invalid("Cast type ", options.to_type->ToString(), " incompatible with dictionary type ", dict_type.ToString())); @@ -169,6 +165,10 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { KERNEL_ASSIGN_OR_RAISE(*out, ctx, Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), TakeOptions::Defaults(), ctx->exec_context())); + + if (!dict_type.Equals(options.to_type)) { + KERNEL_ASSIGN_OR_RAISE(*out, ctx, Cast(*out, options)); + } } void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -224,28 
+224,23 @@ Result ResolveOutputFromOptions(KernelContext* ctx, OutputType kOutputTargetType(ResolveOutputFromOptions); void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { - // Make a copy of the buffers into a destination array without carrying - // the type - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); - output->length = input.length; - output->SetNullCount(input.null_count); - output->buffers = input.buffers; - output->offset = input.offset; - output->child_data = input.child_data; - } else { - ctx->SetStatus( - Status::NotImplemented("This cast not yet implemented for " - "scalar input")); - } + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + // Make a copy of the buffers into a destination array without carrying + // the type + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); + output->length = input.length; + output->SetNullCount(input.null_count); + output->buffers = input.buffers; + output->offset = input.offset; + output->child_data = input.child_data; } void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, CastFunction* func) { auto sig = KernelSignature::Make({in_type}, out_type); ScalarKernel kernel; - kernel.exec = ZeroCopyCastExec; + kernel.exec = TrivialScalarUnaryAsArraysExec(ZeroCopyCastExec); kernel.signature = sig; kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; @@ -268,7 +263,8 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun // XXX: Uses Take and does its own memory allocation for the moment. We can // fix this later. DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty, - UnpackDictionary, NullHandling::COMPUTED_NO_PREALLOCATE, + TrivialScalarUnaryAsArraysExec(UnpackDictionary), + NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 6e550fb12c0..4520230f2ae 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -62,7 +62,7 @@ Status CheckFloatTruncation(const Datum& input, const Datum& output) { return is_valid && static_cast(out_val) != in_val; }; auto GetErrorMessage = [&](InT val) { - return Status::Invalid("Float value ", val, " was truncated converting to", + return Status::Invalid("Float value ", val, " was truncated converting to ", *output.type()); }; diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 7d502f046fc..b339018072e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -215,41 +215,41 @@ void AddBinaryToBinaryCast(CastFunction* func) { auto out_ty = TypeTraits::type_singleton(); DCHECK_OK(func->AddKernel( - OutType::type_id, {in_ty}, out_ty, + InType::type_id, {in_ty}, out_ty, TrivialScalarUnaryAsArraysExec(BinaryToBinaryCastFunctor::Exec), NullHandling::COMPUTED_NO_PREALLOCATE)); } +template +void AddBinaryToBinaryCast(CastFunction* func) { + AddBinaryToBinaryCast(func); + AddBinaryToBinaryCast(func); + AddBinaryToBinaryCast(func); + AddBinaryToBinaryCast(func); +} + } // namespace std::vector> GetBinaryLikeCasts() { auto cast_binary = std::make_shared("cast_binary", Type::BINARY); 
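The template parameter lists in the scalar_cast_string.cc hunk did not survive extraction, so the new convenience overload of AddBinaryToBinaryCast and its call sites read as if they took no template arguments. Judging from the per-pair helper kept above (now registered under InType::type_id rather than the output's type id) and the single-call registrations just below, the consolidating overload plausibly fans one output type out across the four base binary input types. A sketch only; the parameter names are assumptions, not part of the patch:

    template <typename OutType>
    void AddBinaryToBinaryCast(CastFunction* func) {
      // Reuse the existing two-parameter helper for each base binary input type.
      AddBinaryToBinaryCast<StringType, OutType>(func);
      AddBinaryToBinaryCast<BinaryType, OutType>(func);
      AddBinaryToBinaryCast<LargeStringType, OutType>(func);
      AddBinaryToBinaryCast<LargeBinaryType, OutType>(func);
    }

Each registration in GetBinaryLikeCasts() below then collapses from three calls to a single one, such as AddBinaryToBinaryCast<BinaryType>(cast_binary.get()).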
AddCommonCasts(Type::BINARY, binary(), cast_binary.get()); - AddBinaryToBinaryCast(cast_binary.get()); - AddBinaryToBinaryCast(cast_binary.get()); - AddBinaryToBinaryCast(cast_binary.get()); + AddBinaryToBinaryCast(cast_binary.get()); auto cast_large_binary = std::make_shared("cast_large_binary", Type::LARGE_BINARY); AddCommonCasts(Type::LARGE_BINARY, large_binary(), cast_large_binary.get()); - AddBinaryToBinaryCast(cast_large_binary.get()); - AddBinaryToBinaryCast(cast_large_binary.get()); - AddBinaryToBinaryCast(cast_large_binary.get()); + AddBinaryToBinaryCast(cast_large_binary.get()); auto cast_string = std::make_shared("cast_string", Type::STRING); AddCommonCasts(Type::STRING, utf8(), cast_string.get()); AddNumberToStringCasts(cast_string.get()); - AddBinaryToBinaryCast(cast_string.get()); - AddBinaryToBinaryCast(cast_string.get()); - AddBinaryToBinaryCast(cast_string.get()); + AddBinaryToBinaryCast(cast_string.get()); auto cast_large_string = std::make_shared("cast_large_string", Type::LARGE_STRING); AddCommonCasts(Type::LARGE_STRING, large_utf8(), cast_large_string.get()); AddNumberToStringCasts(cast_large_string.get()); - AddBinaryToBinaryCast(cast_large_string.get()); - AddBinaryToBinaryCast(cast_large_string.get()); - AddBinaryToBinaryCast(cast_large_string.get()); + AddBinaryToBinaryCast(cast_large_string.get()); auto cast_fsb = std::make_shared("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index e470f9f90de..d7d1faf7ae5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -322,7 +322,7 @@ struct CastFunctor::value>> template void AddCrossUnitCast(CastFunction* func) { ScalarKernel kernel; - kernel.exec = CastFunctor::Exec; + kernel.exec = TrivialScalarUnaryAsArraysExec(CastFunctor::Exec); kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType); DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 350728793e6..2a0f44187b2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -75,6 +75,9 @@ static std::vector> kNumericTypes = { uint8(), int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}; +static std::vector> kBaseBinaryTypes = { + binary(), utf8(), large_binary(), large_utf8()}; + static void AssertBufferSame(const Array& left, const Array& right, int buffer_index) { ASSERT_EQ(left.data()->buffers[buffer_index].get(), right.data()->buffers[buffer_index].get()); @@ -403,6 +406,75 @@ class TestCast : public TestBase { } }; +TEST_F(TestCast, CanCast) { + auto ExpectCanCast = [](std::shared_ptr from, + std::vector> to_set, + bool expected = true) { + for (auto to : to_set) { + EXPECT_EQ(CanCast(*from, *to), expected) << " from: " << from->ToString() << "\n" + << " to: " << to->ToString(); + } + }; + + auto ExpectCannotCast = [ExpectCanCast](std::shared_ptr from, + std::vector> to_set) { + ExpectCanCast(from, to_set, /*expected=*/false); + }; + + ExpectCanCast(null(), {boolean()}); + ExpectCanCast(null(), kNumericTypes); + ExpectCanCast(null(), kBaseBinaryTypes); + ExpectCanCast( + null(), {date32(), date64(), time32(TimeUnit::MILLI), timestamp(TimeUnit::SECOND)}); + ExpectCanCast(dictionary(uint16(), null()), 
{null()}); + + ExpectCanCast(boolean(), {boolean()}); + ExpectCanCast(boolean(), kNumericTypes); + ExpectCanCast(boolean(), {utf8(), large_utf8()}); + ExpectCanCast(dictionary(int32(), boolean()), {boolean()}); + + ExpectCannotCast(boolean(), {null()}); + ExpectCannotCast(boolean(), {binary(), large_binary()}); + ExpectCannotCast(boolean(), {date32(), date64(), time32(TimeUnit::MILLI), + timestamp(TimeUnit::SECOND)}); + + for (auto from_numeric : kNumericTypes) { + ExpectCanCast(from_numeric, {boolean()}); + ExpectCanCast(from_numeric, kNumericTypes); + ExpectCanCast(from_numeric, {utf8(), large_utf8()}); + ExpectCanCast(dictionary(int32(), from_numeric), {from_numeric}); + + ExpectCannotCast(from_numeric, {null()}); + } + + for (auto from_base_binary : kBaseBinaryTypes) { + ExpectCanCast(from_base_binary, {boolean()}); + ExpectCanCast(from_base_binary, kNumericTypes); + ExpectCanCast(from_base_binary, kBaseBinaryTypes); + ExpectCanCast(dictionary(int64(), from_base_binary), {from_base_binary}); + + // any cast which is valid for the dictionary is valid for the DictionaryArray + ExpectCanCast(dictionary(uint32(), from_base_binary), kBaseBinaryTypes); + ExpectCanCast(dictionary(int16(), from_base_binary), kNumericTypes); + + ExpectCannotCast(from_base_binary, {null()}); + } + + ExpectCanCast(utf8(), {timestamp(TimeUnit::MILLI)}); + ExpectCanCast(large_utf8(), {timestamp(TimeUnit::NANO)}); + ExpectCannotCast(timestamp(TimeUnit::MICRO), + kBaseBinaryTypes); // no formatting supported + + ExpectCannotCast(fixed_size_binary(3), + {fixed_size_binary(3)}); // FIXME missing identity cast + + ExtensionTypeGuard smallint_guard(smallint()); + ExpectCanCast(smallint(), {int16()}); // cast storage + ExpectCanCast(smallint(), + kNumericTypes); // any cast which is valid for storage is supported + ExpectCannotCast(null(), {smallint()}); // FIXME missing common cast from null +} + TEST_F(TestCast, SameTypeZeroCopy) { std::shared_ptr arr = ArrayFromJSON(int32(), "[0, null, 2, 3, 4]"); ASSERT_OK_AND_ASSIGN(std::shared_ptr result, Cast(*arr, int32())); @@ -1855,7 +1927,7 @@ std::shared_ptr SmallintArrayFromJSON(const std::string& json_data) { TEST_F(TestCast, ExtensionTypeToIntDowncast) { auto smallint = std::make_shared(); - ASSERT_OK(RegisterExtensionType(smallint)); + ExtensionTypeGuard smallint_guard(smallint); CastOptions options; options.allow_int_overflow = false; @@ -1891,8 +1963,6 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { // disallow overflow options.allow_int_overflow = false; ASSERT_RAISES(Invalid, Cast(*v3, uint8(), options)); - - ASSERT_OK(UnregisterExtensionType("smallint")); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index cf32c888e8e..58d3e6fc781 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -72,10 +72,35 @@ void AddGenericCompare(const std::shared_ptr& ty, ScalarFunction* func applicator::ScalarBinaryEqualTypes::Exec)); } +struct CompareFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + EnsureDictionaryDecoded(values); + ReplaceNullWithOtherType(values); + + if (auto type = CommonNumeric(*values)) { + ReplaceTypes(type, values); + } else if (auto type = CommonTimestamp(*values)) { + 
ReplaceTypes(type, values); + } else if (auto type = CommonBinary(*values)) { + ReplaceTypes(type, values); + } + + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + template std::shared_ptr MakeCompareFunction(std::string name, const FunctionDoc* doc) { - auto func = std::make_shared(name, Arity::Binary(), doc); + auto func = std::make_shared(name, Arity::Binary(), doc); DCHECK_OK(func->AddKernel( {boolean(), boolean()}, boolean(), @@ -136,7 +161,7 @@ std::shared_ptr MakeCompareFunction(std::string name, std::shared_ptr MakeFlippedFunction(std::string name, const ScalarFunction& func, const FunctionDoc* doc) { - auto flipped_func = std::make_shared(name, Arity::Binary(), doc); + auto flipped_func = std::make_shared(name, Arity::Binary(), doc); for (const ScalarKernel* kernel : func.kernels()) { ScalarKernel flipped_kernel = *kernel; flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 6f742fe7bfd..7b0906395d7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -451,6 +451,96 @@ TEST(TestCompareTimestamps, Basics) { CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]"); } +TEST(TestCompareKernel, DispatchBest) { + for (std::string name : + {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); + CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); + + CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()}); + + CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()}); + CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()}); + + CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()}); + CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()}); + + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); + + CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, + {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, + {float64(), float64()}); + + CheckDispatchBest(name, {timestamp(TimeUnit::MICRO), date64()}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); + + CheckDispatchBest(name, {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO)}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); + + CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()}); + CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()}); + } +} + +TEST(TestCompareKernel, GreaterWithImplicitCasts) { + CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + ArrayFromJSON(float64(), "[0.5, 1.0, 1.5, 2.0]"), + ArrayFromJSON(boolean(), "[false, false, true, null]")); + + 
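For the mixed-type comparisons exercised below, CompareFunction::DispatchBest (added above) promotes both arguments to a common type and Function::Execute then casts them before running the kernel. The int8-versus-uint32 check that follows is therefore equivalent to casting both sides to int64 explicitly. An illustrative fragment, assuming the same headers and test helpers (ArrayFromJSON, ASSERT_OK_AND_ASSIGN) as the surrounding test file:

    // What the implicit-cast path does for the int8 vs. uint32 case below.
    auto lhs = ArrayFromJSON(int8(), "[-16, 0, 16, null]");
    auto rhs = ArrayFromJSON(uint32(), "[3, 4, 5, 7]");

    // {int8, uint32} promotes to int64 under the CommonNumeric rules added in
    // this patch, so "greater" ultimately sees two int64 arrays:
    ASSERT_OK_AND_ASSIGN(Datum lhs64, Cast(lhs, int64()));
    ASSERT_OK_AND_ASSIGN(Datum rhs64, Cast(rhs, int64()));
    ASSERT_OK_AND_ASSIGN(Datum explicit_result, CallFunction("greater", {lhs64, rhs64}));
    ASSERT_OK_AND_ASSIGN(Datum implicit_result, CallFunction("greater", {lhs, rhs}));
    // Both results equal [false, false, true, null], matching the expectation below.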
CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(boolean(), "[false, false, true, null]")); + + CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), + ArrayFromJSON(uint8(), "[255, 254, 1, 0]"), + ArrayFromJSON(boolean(), "[false, false, true, null]")); + + CheckScalarBinary("greater", + ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(boolean(), "[false, false, false, null]")); + + CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + std::make_shared(4), + ArrayFromJSON(boolean(), "[null, null, null, null]")); + + CheckScalarBinary("greater", + ArrayFromJSON(timestamp(TimeUnit::SECOND), + R"(["1970-01-01","2000-02-29","1900-02-28"])"), + ArrayFromJSON(date64(), "[86400000, 0, 86400000]"), + ArrayFromJSON(boolean(), "[false, true, false]")); + + CheckScalarBinary("greater", + ArrayFromJSON(dictionary(int32(), int8()), "[3, -3, -28, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(boolean(), "[false, false, false, null]")); +} + +TEST(TestCompareKernel, GreaterWithImplicitCastsUint64EdgeCase) { + // int64 is as wide as we can promote + CheckDispatchBest("greater", {int8(), uint64()}, {int64(), int64()}); + + // this works sometimes + CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-1]"), + ArrayFromJSON(uint64(), "[0]"), ArrayFromJSON(boolean(), "[false]")); + + // ... but it can result in impossible implicit casts in the presence of uint64, since + // some uint64 values cannot be cast to int64: + ASSERT_RAISES( + Invalid, + CallFunction("greater", {ArrayFromJSON(int64(), "[-1]"), + ArrayFromJSON(uint64(), "[18446744073709551615]")})); +} + class TestStringCompareKernel : public ::testing::Test {}; TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { @@ -459,85 +549,74 @@ TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { CompareOptions eq(CompareOperator::EQUAL); ValidateCompare(eq, "[]", one, "[]"); ValidateCompare(eq, "[null]", one, "[null]"); - ValidateCompare(eq, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare( - eq, "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", one, "[0,1,0,0,0,0]"); - ValidateCompare( - eq, "[\"five\",\"four\",\"three\",\"two\",\"one\",\"zero\"]", one, "[0,0,0,0,1,0]"); - ValidateCompare(eq, "[null,\"zero\",\"one\",\"one\"]", one, "[null,0,1,1]"); + ValidateCompare(eq, R"(["zero","zero","one","one","two","two"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(eq, R"(["zero","one","two","three","four","five"])", one, + "[0,1,0,0,0,0]"); + ValidateCompare(eq, R"(["five","four","three","two","one","zero"])", one, + "[0,0,0,0,1,0]"); + ValidateCompare(eq, R"([null,"zero","one","one"])", one, "[null,0,1,1]"); Datum na(std::make_shared()); - ValidateCompare(eq, "[null,\"zero\",\"one\",\"one\"]", na, + ValidateCompare(eq, R"([null,"zero","one","one"])", na, "[null,null,null,null]"); - ValidateCompare(eq, na, "[null,\"zero\",\"one\",\"one\"]", + ValidateCompare(eq, na, R"([null,"zero","one","one"])", "[null,null,null,null]"); CompareOptions neq(CompareOperator::NOT_EQUAL); ValidateCompare(neq, "[]", one, "[]"); ValidateCompare(neq, "[null]", one, "[null]"); - ValidateCompare(neq, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare(neq, - "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", - one, "[1,0,1,1,1,1]"); - 
ValidateCompare(neq, - "[\"five\",\"four\",\"three\",\"two\",\"one\",\"zero\"]", - one, "[1,1,1,1,0,1]"); - ValidateCompare(neq, "[null,\"zero\",\"one\",\"one\"]", one, - "[null,1,0,0]"); + ValidateCompare(neq, R"(["zero","zero","one","one","two","two"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(neq, R"(["zero","one","two","three","four","five"])", one, + "[1,0,1,1,1,1]"); + ValidateCompare(neq, R"(["five","four","three","two","one","zero"])", one, + "[1,1,1,1,0,1]"); + ValidateCompare(neq, R"([null,"zero","one","one"])", one, "[null,1,0,0]"); CompareOptions gt(CompareOperator::GREATER); ValidateCompare(gt, "[]", one, "[]"); ValidateCompare(gt, "[null]", one, "[null]"); - ValidateCompare(gt, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare( - gt, "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", one, "[1,0,1,1,0,0]"); - ValidateCompare(gt, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare(gt, "[null,\"zero\",\"one\",\"one\"]", one, "[null,1,0,0]"); + ValidateCompare(gt, R"(["zero","zero","one","one","two","two"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(gt, R"(["zero","one","two","three","four","five"])", one, + "[1,0,1,1,0,0]"); + ValidateCompare(gt, R"(["four","five","six","seven","eight","nine"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(gt, R"([null,"zero","one","one"])", one, "[null,1,0,0]"); CompareOptions gte(CompareOperator::GREATER_EQUAL); ValidateCompare(gte, "[]", one, "[]"); ValidateCompare(gte, "[null]", one, "[null]"); - ValidateCompare(gte, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[1,1,1,1,1,1]"); - ValidateCompare(gte, - "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", - one, "[1,1,1,1,0,0]"); - ValidateCompare(gte, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare(gte, "[null,\"zero\",\"one\",\"one\"]", one, - "[null,1,1,1]"); + ValidateCompare(gte, R"(["zero","zero","one","one","two","two"])", one, + "[1,1,1,1,1,1]"); + ValidateCompare(gte, R"(["zero","one","two","three","four","five"])", one, + "[1,1,1,1,0,0]"); + ValidateCompare(gte, R"(["four","five","six","seven","eight","nine"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(gte, R"([null,"zero","one","one"])", one, "[null,1,1,1]"); CompareOptions lt(CompareOperator::LESS); ValidateCompare(lt, "[]", one, "[]"); ValidateCompare(lt, "[null]", one, "[null]"); - ValidateCompare(lt, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[0,0,0,0,0,0]"); - ValidateCompare( - lt, "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", one, "[0,0,0,0,1,1]"); - ValidateCompare(lt, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare(lt, "[null,\"zero\",\"one\",\"one\"]", one, "[null,0,0,0]"); + ValidateCompare(lt, R"(["zero","zero","one","one","two","two"])", one, + "[0,0,0,0,0,0]"); + ValidateCompare(lt, R"(["zero","one","two","three","four","five"])", one, + "[0,0,0,0,1,1]"); + ValidateCompare(lt, R"(["four","five","six","seven","eight","nine"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(lt, R"([null,"zero","one","one"])", one, "[null,0,0,0]"); CompareOptions lte(CompareOperator::LESS_EQUAL); ValidateCompare(lte, "[]", one, "[]"); ValidateCompare(lte, "[null]", one, "[null]"); - ValidateCompare(lte, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare(lte, - "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", 
- one, "[0,1,0,0,1,1]"); - ValidateCompare(lte, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare(lte, "[null,\"zero\",\"one\",\"one\"]", one, - "[null,0,1,1]"); + ValidateCompare(lte, R"(["zero","zero","one","one","two","two"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(lte, R"(["zero","one","two","three","four","five"])", one, + "[0,1,0,0,1,1]"); + ValidateCompare(lte, R"(["four","five","six","seven","eight","nine"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(lte, R"([null,"zero","one","one"])", one, "[null,0,1,1]"); } TEST_F(TestStringCompareKernel, RandomCompareArrayScalar) { @@ -563,7 +642,7 @@ TEST_F(TestStringCompareKernel, RandomCompareArrayArray) { for (size_t i = 3; i < 5; i++) { for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - const int64_t length = static_cast(1ULL << i); + auto length = static_cast(1ULL << i); auto lhs = Datum(rand.String(length << i, 0, 16, null_probability)); auto rhs = Datum(rand.String(length << i, 0, 16, null_probability)); auto options = CompareOptions(op); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index 93fa34c1694..ffc1e11a7be 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -18,6 +18,7 @@ #include "arrow/array/array_base.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/util/bit_util.h" @@ -36,10 +37,9 @@ namespace { template struct SetLookupState : public KernelState { - explicit SetLookupState(const SetLookupOptions& options, MemoryPool* pool) - : options(options), lookup_table(pool, 0) {} + explicit SetLookupState(MemoryPool* pool) : lookup_table(pool, 0) {} - Status Init() { + Status Init(const SetLookupOptions& options) { if (options.value_set.kind() == Datum::ARRAY) { RETURN_NOT_OK(AddArrayValueSet(*options.value_set.array())); } else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) { @@ -53,7 +53,9 @@ struct SetLookupState : public KernelState { if (lookup_table.size() != options.value_set.length()) { return Status::NotImplemented("duplicate values in value_set"); } - value_set_has_null = (lookup_table.GetNull() >= 0); + if (!options.skip_nulls) { + null_index = lookup_table.GetNull(); + } return Status::OK(); } @@ -72,22 +74,19 @@ struct SetLookupState : public KernelState { } using MemoTable = typename HashTraits::MemoTableType; - const SetLookupOptions& options; MemoTable lookup_table; - bool value_set_has_null; + int32_t null_index = -1; }; template <> struct SetLookupState : public KernelState { - explicit SetLookupState(const SetLookupOptions& options, MemoryPool*) - : options(options) {} + explicit SetLookupState(MemoryPool*) {} - Status Init() { - this->value_set_has_null = (options.value_set.length() > 0); + Status Init(const SetLookupOptions& options) { + value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls; return Status::OK(); } - const SetLookupOptions& options; bool value_set_has_null; }; @@ -118,21 +117,20 @@ struct UnsignedIntType<8> { // Constructing the type requires a type parameter struct InitStateVisitor { KernelContext* ctx; - const SetLookupOptions* options; + SetLookupOptions options; + const std::shared_ptr& arg_type; std::unique_ptr result; 
- InitStateVisitor(KernelContext* ctx, const SetLookupOptions* options) - : ctx(ctx), options(options) {} + InitStateVisitor(KernelContext* ctx, const KernelInitArgs& args) + : ctx(ctx), + options(*checked_cast(args.options)), + arg_type(args.inputs[0].type) {} template Status Init() { - if (options == nullptr) { - return Status::Invalid( - "Attempted to call a set lookup function without SetLookupOptions"); - } using StateType = SetLookupState; - result.reset(new StateType(*options, ctx->exec_context()->memory_pool())); - return static_cast(result.get())->Init(); + result.reset(new StateType(ctx->exec_context()->memory_pool())); + return static_cast(result.get())->Init(options); } Status Visit(const DataType&) { return Init(); } @@ -157,7 +155,13 @@ struct InitStateVisitor { Status Visit(const FixedSizeBinaryType& type) { return Init(); } Status GetResult(std::unique_ptr* out) { - RETURN_NOT_OK(VisitTypeInline(*options->value_set.type(), this)); + if (!options.value_set.type()->Equals(arg_type)) { + ARROW_ASSIGN_OR_RAISE( + options.value_set, + Cast(options.value_set, CastOptions::Safe(arg_type), ctx->exec_context())); + } + + RETURN_NOT_OK(VisitTypeInline(*arg_type, this)); *out = std::move(result); return Status::OK(); } @@ -165,9 +169,14 @@ struct InitStateVisitor { std::unique_ptr InitSetLookup(KernelContext* ctx, const KernelInitArgs& args) { - InitStateVisitor visitor{ctx, static_cast(args.options)}; + if (args.options == nullptr) { + ctx->SetStatus(Status::Invalid( + "Attempted to call a set lookup function without SetLookupOptions")); + return nullptr; + } + std::unique_ptr result; - ctx->SetStatus(visitor.GetResult(&result)); + ctx->SetStatus(InitStateVisitor{ctx, args}.GetResult(&result)); return result; } @@ -185,7 +194,7 @@ struct IndexInVisitor { const auto& state = checked_cast&>(*ctx->state()); if (data.length != 0) { // skip_nulls is honored for consistency with other types - if (state.value_set_has_null && !state.options.skip_nulls) { + if (state.value_set_has_null) { RETURN_NOT_OK(this->builder.Reserve(data.length)); for (int64_t i = 0; i < data.length; ++i) { this->builder.UnsafeAppend(0); @@ -203,7 +212,6 @@ struct IndexInVisitor { const auto& state = checked_cast&>(*ctx->state()); - int32_t null_index = state.options.skip_nulls ? 
-1 : state.lookup_table.GetNull(); RETURN_NOT_OK(this->builder.Reserve(data.length)); VisitArrayDataInline( data, @@ -218,9 +226,9 @@ struct IndexInVisitor { } }, [&]() { - if (null_index != -1) { + if (state.null_index != -1) { // value_set included null - this->builder.UnsafeAppend(null_index); + this->builder.UnsafeAppend(state.null_index); } else { // value_set does not include null; output null this->builder.UnsafeAppendNull(); @@ -283,13 +291,8 @@ struct IsInVisitor { const auto& state = checked_cast&>(*ctx->state()); ArrayData* output = out->mutable_array(); // skip_nulls is honored for consistency with other types - if (state.value_set_has_null && !state.options.skip_nulls) { - BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, - output->length, true); - } else { - BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, - output->length, false); - } + BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, output->length, + state.value_set_has_null); return Status::OK(); } @@ -301,6 +304,7 @@ struct IsInVisitor { FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), output->offset, output->length); + VisitArrayDataInline( this->data, [&](T v) { @@ -312,7 +316,7 @@ struct IsInVisitor { writer.Next(); }, [&]() { - if (!state.options.skip_nulls && state.lookup_table.GetNull() != -1) { + if (state.null_index != -1) { writer.Set(); } else { writer.Clear(); @@ -414,6 +418,15 @@ class IndexInMetaBinary : public MetaFunction { } }; +struct SetLookupFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + EnsureDictionaryDecoded(values); + return DispatchExact(*values); + } +}; + const FunctionDoc is_in_doc{ "Find each element in a set of values", ("For each element in `values`, return true if it is found in a given\n" @@ -441,9 +454,10 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel isin_base; isin_base.init = InitSetLookup; - isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); + isin_base.exec = + TrivialScalarUnaryAsArraysExec(ExecIsIn, NullHandling::OUTPUT_NOT_NULL); isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL; - auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); + auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); @@ -458,11 +472,12 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel index_in_base; index_in_base.init = InitSetLookup; - index_in_base.exec = TrivialScalarUnaryAsArraysExec(ExecIndexIn); + index_in_base.exec = TrivialScalarUnaryAsArraysExec( + ExecIndexIn, NullHandling::COMPUTED_NO_PREALLOCATE); index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = - std::make_shared("index_in", Arity::Unary(), &index_in_doc); + std::make_shared("index_in", Arity::Unary(), &index_in_doc); AddBasicSetLookupKernels(index_in_base, /*output_type=*/int32(), index_in.get()); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index 40907da5a62..2285c1cb9ab 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -85,6 +85,21 @@ TEST_F(TestIsInKernel, CallBinary) { AssertArraysEqual(*expected, *out.make_array()); } +TEST_F(TestIsInKernel, 
ImplicitlyCastValueSet) { + auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]"); + + SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")}; + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("is_in", {input}, &opts)); + + auto expected = ArrayFromJSON(boolean(), ("[false, false, true, true, false," + "true, false, true, false]")); + AssertArraysEqual(*expected, *out.make_array()); + + // fails; value_set cannot be cast to int8 + opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")}; + ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts)); +} + template class TestIsInKernelPrimitive : public ::testing::Test {}; @@ -587,5 +602,19 @@ TEST_F(TestIndexInKernel, ChunkedArrayInvoke) { CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); } +TEST(TestSetLookup, DispatchBest) { + for (std::string name : {"is_in", "index_in"}) { + CheckDispatchBest(name, {int32()}, {int32()}); + CheckDispatchBest(name, {dictionary(int32(), utf8())}, {utf8()}); + } +} + +TEST(TestSetLookup, IsInWithImplicitCasts) { + SetLookupOptions opts{ArrayFromJSON(utf8(), R"(["b", null])")}; + CheckScalarUnary("is_in", + ArrayFromJSON(dictionary(int32(), utf8()), R"(["a", "b", "c", null])"), + ArrayFromJSON(boolean(), "[0, 1, 0, 1]"), &opts); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 5d54a8c1771..73e900351fb 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -24,6 +24,8 @@ #include "arrow/array.h" #include "arrow/chunked_array.h" #include "arrow/compute/exec.h" +#include "arrow/compute/function.h" +#include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/result.h" #include "arrow/testing/gtest_util.h" @@ -173,5 +175,22 @@ void CheckScalarBinary(std::string func_name, std::shared_ptr left_input, CheckScalar(std::move(func_name), {left_input, right_input}, expected, options); } +void CheckDispatchBest(std::string func_name, std::vector original_values, + std::vector expected_equivalent_values) { + ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name)); + + auto values = original_values; + ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values)); + + ASSERT_OK_AND_ASSIGN(auto expected_kernel, + function->DispatchExact(expected_equivalent_values)); + + EXPECT_EQ(actual_kernel, expected_kernel) + << " DispatchBest" << ValueDescr::ToString(original_values) << " => " + << actual_kernel->signature->ToString() << "\n" + << " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => " + << expected_kernel->signature->ToString(); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index c38c0ceb83c..767911888ac 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -143,5 +143,10 @@ void TestRandomPrimitiveCTypes() { DoTestFunctor::Test(duration(TimeUnit::MILLI)); } +// Check that DispatchBest on a given function yields the same Kernel as +// produced by DispatchExact on another set of ValueDescrs. 
+void CheckDispatchBest(std::string func_name, std::vector descrs, + std::vector exact_descrs); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 3d21f5b1494..93badbd3b25 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -57,13 +57,14 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { return arg; } -ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) { - return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { +ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec, + NullHandling::type null_handling) { + return [=](KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (out->is_array()) { return exec(ctx, batch, out); } - if (!batch[0].scalar()->is_valid) { + if (null_handling == NullHandling::INTERSECTION && !batch[0].scalar()->is_valid) { out->scalar()->is_valid = false; return; } diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h index aece5a97599..f614439ffb8 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.h +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -59,7 +59,8 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr); // the original exec, then the only element of the resulting array will be extracted as // the output scalar. This could be far more efficient, but instead of optimizing this // it'd be better to support scalar inputs "upstream" in original exec. -ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec); +ArrayKernelExec TrivialScalarUnaryAsArraysExec( + ArrayKernelExec exec, NullHandling::type null_handling = NullHandling::INTERSECTION); } // namespace internal } // namespace compute diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 16f706ed1a4..56339430ee9 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -21,6 +21,7 @@ #include #include "arrow/chunked_array.h" +#include "arrow/compute/api_vector.h" #include "arrow/compute/exec_internal.h" #include "arrow/dataset/expression_internal.h" #include "arrow/io/memory.h" @@ -306,7 +307,7 @@ size_t Expression::hash() const { } bool Expression::IsBound() const { - if (descr().type == nullptr) return false; + if (type() == nullptr) return false; if (auto call = this->call()) { if (call->kernel == nullptr) return false; @@ -359,7 +360,7 @@ bool Expression::IsNullLiteral() const { } bool Expression::IsSatisfiable() const { - if (descr().type && descr().type->id() == Type::NA) { + if (type() && type()->id() == Type::NA) { return false; } @@ -378,124 +379,62 @@ bool Expression::IsSatisfiable() const { namespace { -Result> InitKernelState( - const Expression::Call& call, compute::ExecContext* exec_context) { - if (!call.kernel->init) return nullptr; - - compute::KernelContext kernel_context(exec_context); - compute::KernelInitArgs kernel_init_args{call.kernel, GetDescriptors(call.arguments), - call.options.get()}; - - auto kernel_state = call.kernel->init(&kernel_context, kernel_init_args); - RETURN_NOT_OK(kernel_context.status()); - return std::move(kernel_state); -} - -Status InsertImplicitCasts(Expression::Call* call); - // Produce a bound Expression from unbound Call and bound arguments. 
-Result BindNonRecursive(const Expression::Call& call, - std::vector arguments, - bool insert_implicit_casts, +Result BindNonRecursive(Expression::Call call, bool insert_implicit_casts, compute::ExecContext* exec_context) { - DCHECK(std::all_of(arguments.begin(), arguments.end(), + DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - auto bound_call = call; - bound_call.arguments = std::move(arguments); - - if (insert_implicit_casts) { - RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); - } - - ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); - - auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); - - compute::KernelContext kernel_context(exec_context); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, - InitKernelState(bound_call, exec_context)); - kernel_context.SetState(bound_call.kernel_state.get()); - - ARROW_ASSIGN_OR_RAISE( - bound_call.descr, - bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); - - return Expression(std::move(bound_call)); -} - -Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { - if (expr->descr().type->Equals(to_type)) { - return Status::OK(); - } - - if (auto lit = expr->literal()) { - ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, to_type)); - *expr = literal(std::move(new_lit)); - return Status::OK(); - } - - Expression::Call with_cast; - with_cast.function_name = "cast"; - with_cast.options = std::make_shared( - compute::CastOptions::Safe(std::move(to_type))); + auto descrs = GetDescriptors(call.arguments); + ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context)); - compute::ExecContext exec_context; - ARROW_ASSIGN_OR_RAISE(*expr, - BindNonRecursive(with_cast, {std::move(*expr)}, - /*insert_implicit_casts=*/false, &exec_context)); - return Status::OK(); -} + if (!insert_implicit_casts) { + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(descrs)); + } else { + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&descrs)); -Status InsertImplicitCasts(Expression::Call* call) { - DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), - [](const Expression& argument) { return argument.IsBound(); })); + for (size_t i = 0; i < descrs.size(); ++i) { + if (descrs[i] == call.arguments[i].descr()) continue; - if (IsSameTypesBinary(call->function_name)) { - for (auto&& argument : call->arguments) { - if (auto value_type = GetDictionaryValueType(argument.descr().type)) { - RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); + if (descrs[i].shape != call.arguments[i].descr().shape) { + return Status::NotImplemented( + "Automatic broadcasting of scalars arguments to arrays in ", + Expression(std::move(call)).ToString()); } - } - - if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { - // argument 0 is scalar so casting is cheap - return MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); - } - // cast argument 1 unconditionally - return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); - } + if (auto lit = call.arguments[i].literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, descrs[i].type)); + call.arguments[i] = literal(std::move(new_lit)); + continue; + } - if (auto options = GetSetLookupOptions(*call)) { - if (auto value_type = GetDictionaryValueType(call->arguments[0].descr().type)) { - 
// DICTIONARY input is not supported; decode it. - RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &call->arguments[0])); - } + // construct an implicit cast Expression with which to replace this argument + Expression::Call implicit_cast; + implicit_cast.function_name = "cast"; + implicit_cast.arguments = {std::move(call.arguments[i])}; + implicit_cast.options = std::make_shared( + compute::CastOptions::Safe(descrs[i].type)); - if (options->value_set.type()->id() == Type::DICTIONARY) { - // DICTIONARY value_set is not supported; decode it. - auto new_options = std::make_shared(*options); - RETURN_NOT_OK(EnsureNotDictionary(&new_options->value_set)); - options = new_options.get(); - call->options = std::move(new_options); + ARROW_ASSIGN_OR_RAISE( + call.arguments[i], + BindNonRecursive(std::move(implicit_cast), + /*insert_implicit_casts=*/false, exec_context)); } + } - if (!options->value_set.type()->Equals(call->arguments[0].descr().type)) { - // The value_set is assumed smaller than inputs, casting it should be cheaper. - auto new_options = std::make_shared(*options); - ARROW_ASSIGN_OR_RAISE(new_options->value_set, - compute::Cast(std::move(new_options->value_set), - call->arguments[0].descr().type)); - options = new_options.get(); - call->options = std::move(new_options); - } + compute::KernelContext kernel_context(exec_context); + if (call.kernel->init) { + call.kernel_state = + call.kernel->init(&kernel_context, {call.kernel, descrs, call.options.get()}); - return Status::OK(); + RETURN_NOT_OK(kernel_context.status()); + kernel_context.SetState(call.kernel_state.get()); } - return Status::OK(); + ARROW_ASSIGN_OR_RAISE( + call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); + + return Expression(std::move(call)); } struct FieldPathGetDatumImpl { @@ -554,14 +493,11 @@ Result Expression::Bind(ValueDescr in, return Expression{Parameter{*ref, std::move(descr)}}; } - auto call = CallNotNull(*this); - - std::vector bound_arguments(call->arguments.size()); - for (size_t i = 0; i < bound_arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(bound_arguments[i], call->arguments[i].Bind(in, exec_context)); + auto call = *CallNotNull(*this); + for (auto& argument : call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); } - - return BindNonRecursive(*call, std::move(bound_arguments), + return BindNonRecursive(std::move(call), /*insert_implicit_casts=*/true, exec_context); } @@ -595,8 +531,8 @@ Result ExecuteScalarExpression(const Expression& expr, const Datum& input // Refernced field was present but didn't have the expected type. // Should we just error here? For now, pay dispatch cost and just cast. 
ARROW_ASSIGN_OR_RAISE( - field, compute::Cast(field, expr.descr().type, compute::CastOptions::Safe(), - exec_context)); + field, + compute::Cast(field, expr.type(), compute::CastOptions::Safe(), exec_context)); } return field; @@ -803,8 +739,24 @@ Result ReplaceFieldsWithKnownValues( if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { - ARROW_ASSIGN_OR_RAISE(Datum lit, - compute::Cast(it->second, expr.descr().type)); + Datum lit = it->second; + if (expr.type()->id() == Type::DICTIONARY) { + if (lit.is_scalar()) { + // FIXME the "right" way to support this is adding support for scalars to + // dictionary_encode and support for casting between index types to cast + ARROW_ASSIGN_OR_RAISE( + auto index, + Int32Scalar(0).CastTo( + checked_cast(*expr.type()).index_type())); + + ARROW_ASSIGN_OR_RAISE(auto dictionary, + MakeArrayFromScalar(*lit.scalar(), 1)); + + return literal( + DictionaryScalar::Make(std::move(index), std::move(dictionary))); + } + } + ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(it->second, expr.type())); return literal(std::move(lit)); } } @@ -900,7 +852,7 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont flipped_call.function_name = Comparison::GetName(Comparison::GetFlipped(*cmp)); - return BindNonRecursive(flipped_call, std::move(flipped_call.arguments), + return BindNonRecursive(flipped_call, /*insert_implicit_casts=*/false, exec_context); } } @@ -926,7 +878,10 @@ Result DirectComparisonSimplification(Expression expr, if (!cmp) return expr; if (!cmp_guarantee) return expr; - if (call->arguments[0] != guarantee.arguments[0]) return expr; + + const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); + const auto& guarantee_lhs = guarantee.arguments[0]; + if (lhs != guarantee_lhs) return expr; auto rhs = call->arguments[1].literal(); auto guarantee_rhs = guarantee.arguments[1].literal(); diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 984c846210f..13c714b2d72 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -50,8 +50,8 @@ class ARROW_DS_EXPORT Expression { std::shared_ptr> hash; // post-Bind properties: - const compute::Kernel* kernel = NULLPTR; std::shared_ptr function; + const compute::Kernel* kernel = NULLPTR; std::shared_ptr kernel_state; ValueDescr descr; }; @@ -104,6 +104,7 @@ class ARROW_DS_EXPORT Expression { /// The type and shape to which this expression will evaluate ValueDescr descr() const; + std::shared_ptr type() const { return descr().type; } // XXX someday // NullGeneralization::type nullable() const; diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 6a0f54fa8d5..24e60377f5a 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -86,16 +86,12 @@ struct Comparison { return nullptr; } - // Execute a simple Comparison between scalars, casting the RHS if types disagree + // Execute a simple Comparison between scalars static Result Execute(Datum l, Datum r) { if (!l.is_scalar() || !r.is_scalar()) { return Status::Invalid("Cannot Execute Comparison on non-scalars"); } - if (!l.type()->Equals(r.type())) { - ARROW_ASSIGN_OR_RAISE(r, compute::Cast(r, l.type())); - } - std::vector arguments{std::move(l), std::move(r)}; ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); @@ -109,6 +105,44 @@ struct Comparison { return less.scalar_as().value ? 
LESS : GREATER; } + // Given an Expression wrapped in casts which preserve ordering + // (for example, cast(field_ref("i16"), to_type=int32())), unwrap the inner Expression. + // This is used to destructure implicitly cast field_refs during Expression + // simplification. + static const Expression& StripOrderPreservingCasts(const Expression& expr) { + auto call = expr.call(); + if (!call) return expr; + if (call->function_name != "cast") return expr; + + const Expression& from = call->arguments[0]; + + auto from_id = from.type()->id(); + auto to_id = expr.type()->id(); + + if (is_floating(to_id)) { + if (is_integer(from_id) || is_floating(from_id)) { + return StripOrderPreservingCasts(from); + } + return expr; + } + + if (is_unsigned_integer(to_id)) { + if (is_unsigned_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) { + return StripOrderPreservingCasts(from); + } + return expr; + } + + if (is_signed_integer(to_id)) { + if (is_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) { + return StripOrderPreservingCasts(from); + } + return expr; + } + + return expr; + } + static type GetFlipped(type op) { switch (op) { case NA: @@ -182,14 +216,6 @@ inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } -inline bool IsSameTypesBinary(const std::string& function) { - if (Comparison::Get(function)) return true; - - static std::unordered_set set{"add", "subtract", "multiply", "divide"}; - - return set.find(function) != set.end(); -} - inline const compute::SetLookupOptions* GetSetLookupOptions( const Expression::Call& call) { if (!IsSetLookup(call.function_name)) return nullptr; @@ -206,38 +232,6 @@ inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call return checked_cast(call.options.get()); } -inline std::shared_ptr GetDictionaryValueType( - const std::shared_ptr& type) { - if (type && type->id() == Type::DICTIONARY) { - return checked_cast(*type).value_type(); - } - return nullptr; -} - -inline Status EnsureNotDictionary(ValueDescr* descr) { - if (auto value_type = GetDictionaryValueType(descr->type)) { - descr->type = std::move(value_type); - } - return Status::OK(); -} - -inline Status EnsureNotDictionary(Datum* datum) { - if (datum->type()->id() == Type::DICTIONARY) { - const auto& type = checked_cast(*datum->type()).value_type(); - ARROW_ASSIGN_OR_RAISE(*datum, compute::Cast(*datum, type)); - } - return Status::OK(); -} - -inline Status EnsureNotDictionary(Expression::Call* call) { - if (auto options = GetSetLookupOptions(*call)) { - auto new_options = *options; - RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set)); - call->options.reset(new compute::SetLookupOptions(std::move(new_options))); - } - return Status::OK(); -} - /// A helper for unboxing an Expression composed of associative function calls. /// Such expressions can frequently be rearranged to a semantically equivalent /// expression for more optimal execution or more straightforward manipulation. 
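Editor's note: the cast-stripping rules added above reduce to a predicate over type ids. Below is a minimal standalone sketch of that predicate, assuming only the type-id helpers this patch already relies on (is_floating, is_integer, is_unsigned_integer, is_signed_integer, and the bit_width helper added to type_traits.h later in this diff); the name IsOrderPreservingCast is illustrative and not part of the patch.

    #include "arrow/type_traits.h"

    namespace arrow {
    namespace dataset {

    // True when a cast from `from_id` to `to_id` is treated as order preserving,
    // mirroring the conditions under which StripOrderPreservingCasts() unwraps a
    // "cast" call wrapped around a field_ref.
    inline bool IsOrderPreservingCast(Type::type from_id, Type::type to_id) {
      if (is_floating(to_id)) {
        // integer or floating point -> floating point
        return is_integer(from_id) || is_floating(from_id);
      }
      if (is_unsigned_integer(to_id)) {
        // unsigned integer -> unsigned integer of at least equal width
        return is_unsigned_integer(from_id) && bit_width(to_id) >= bit_width(from_id);
      }
      if (is_signed_integer(to_id)) {
        // any integer -> signed integer of at least equal width
        return is_integer(from_id) && bit_width(to_id) >= bit_width(from_id);
      }
      return false;
    }

    }  // namespace dataset
    }  // namespace arrow

The StripOrderPreservingCasts expectations added to expression_test.cc below exercise exactly these cases.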
diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index da5c82425b3..ae62283b1d7 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -57,40 +57,96 @@ void ExpectResultsEqual(Actual&& actual, Expected&& expected) { MaybeExpected maybe_expected(std::forward(expected)); if (maybe_expected.ok()) { - ASSERT_OK_AND_ASSIGN(auto actual, maybe_actual); - EXPECT_EQ(actual, *maybe_expected); + EXPECT_EQ(maybe_actual, maybe_expected); } else { - EXPECT_EQ(maybe_actual.status().code(), expected.status().code()); - EXPECT_NE(maybe_actual.status().message().find(expected.status().message()), - std::string::npos) - << " actual: " << maybe_actual.status() << "\n" - << " expected: " << maybe_expected.status(); + EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT( + expected.status().code(), HasSubstr(expected.status().message()), maybe_actual); } } +const auto no_change = util::nullopt; + TEST(ExpressionUtils, Comparison) { auto Expect = [](Result expected, Datum l, Datum r) { ExpectResultsEqual(Comparison::Execute(l, r).Map(Comparison::GetName), expected); }; - Datum zero(0), one(1), two(2), null(std::make_shared()), str("hello"); + Datum zero(0), one(1), two(2), null(std::make_shared()); + Datum str("hello"), bin(std::make_shared(Buffer::FromString("hello"))); + Datum dict_str(DictionaryScalar::Make(std::make_shared(0), + ArrayFromJSON(utf8(), R"(["a", "b", "c"])"))); - Status parse_failure = Status::Invalid("Failed to parse"); + Status not_impl = Status::NotImplemented("no kernel matching input types"); Expect("equal", one, one); Expect("less", one, two); Expect("greater", one, zero); - // cast RHS to LHS type; "hello" > "1" - Expect("greater", str, one); - // cast RHS to LHS type; "hello" is not convertible to int - Expect(parse_failure, one, str); - Expect("na", one, null); - Expect("na", str, null); Expect("na", null, one); - // cast RHS to LHS type; "hello" is not convertible to int - Expect(parse_failure, null, str); + + // strings and ints are not comparable without explicit casts + Expect(not_impl, str, one); + Expect(not_impl, one, str); + Expect(not_impl, str, null); // not even null ints + + // string -> binary implicit cast allowed + Expect("equal", str, bin); + Expect("equal", bin, str); + + // dict_str -> string, implicit casts allowed + Expect("less", dict_str, str); + Expect("less", dict_str, bin); +} + +TEST(ExpressionUtils, StripOrderPreservingCasts) { + auto Expect = [](Expression expr, util::optional expected_stripped) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + if (!expected_stripped) { + expected_stripped = expr; + } else { + ASSERT_OK_AND_ASSIGN(expected_stripped, expected_stripped->Bind(*kBoringSchema)); + } + EXPECT_EQ(Comparison::StripOrderPreservingCasts(expr), *expected_stripped); + }; + + // Casting int to float preserves ordering. + // For example, let + // a = 3, b = 2, assert(a > b) + // After injecting a cast to float, this ordering still holds + // float(a) == 3.0, float(b) == 2.0, assert(float(a) > float(b)) + Expect(cast(field_ref("i32"), float32()), field_ref("i32")); + + // Casting an integral type to a wider integral type preserves ordering. 
+ Expect(cast(field_ref("i32"), int64()), field_ref("i32")); + Expect(cast(field_ref("i32"), int32()), field_ref("i32")); + Expect(cast(field_ref("i32"), int16()), no_change); + Expect(cast(field_ref("i32"), int8()), no_change); + + Expect(cast(field_ref("u32"), uint64()), field_ref("u32")); + Expect(cast(field_ref("u32"), uint32()), field_ref("u32")); + Expect(cast(field_ref("u32"), uint16()), no_change); + Expect(cast(field_ref("u32"), uint8()), no_change); + + Expect(cast(field_ref("u32"), int64()), field_ref("u32")); + Expect(cast(field_ref("u32"), int32()), field_ref("u32")); + Expect(cast(field_ref("u32"), int16()), no_change); + Expect(cast(field_ref("u32"), int8()), no_change); + + // Casting float to int can affect ordering. + // For example, let + // a = 3.5, b = 3.0, assert(a > b) + // After injecting a cast to integer, this ordering may no longer hold + // int(a) == 3, int(b) == 3, assert(!(int(a) > int(b))) + Expect(cast(field_ref("f32"), int32()), no_change); + + // casting any float type to another preserves ordering + Expect(cast(field_ref("f32"), float64()), field_ref("f32")); + Expect(cast(field_ref("f64"), float32()), field_ref("f64")); + + // casting signed integer to unsigned can alter ordering + Expect(cast(field_ref("i32"), uint32()), no_change); + Expect(cast(field_ref("i32"), uint64()), no_change); } TEST(Expression, ToString) { @@ -240,9 +296,9 @@ TEST(Expression, IsSatisfiable) { // When a top level conjunction contains an Expression which is certain to evaluate to // null, it can only evaluate to null or false. - auto null_or_false = and_(literal(null), field_ref("a")); - // This may appear in satisfiable filters if coalesced - EXPECT_TRUE(call("is_null", {null_or_false}).IsSatisfiable()); + auto never_true = and_(literal(null), field_ref("a")); + // This may appear in satisfiable filters if coalesced (for example, wrapped in fill_na) + EXPECT_TRUE(call("is_null", {never_true}).IsSatisfiable()); // ... but at the top level it is not satisfiable. // This special case arises when (for example) an absent column has made // one member of the conjunction always-null. 
This is fairly common and @@ -307,8 +363,6 @@ void ExpectBindsTo(Expression expr, util::optional expected, } } -const auto no_change = util::nullopt; - TEST(Expression, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set auto expr = field_ref("alpha"); @@ -342,42 +396,41 @@ TEST(Expression, BindCall) { ExpectBindsTo(expr, no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - // literal(3) may be safely cast to float32, so binding this expr casts that literal: ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}), call("add", {field_ref("f32"), literal(3.0F)})); - // literal(3.5) may not be safely cast to int32, so binding this expr fails: - ASSERT_RAISES(Invalid, - call("add", {field_ref("i32"), literal(3.5)}).Bind(*kBoringSchema)); + ExpectBindsTo(call("add", {field_ref("i32"), literal(3.5F)}), + call("add", {cast(field_ref("i32"), float32()), literal(3.5F)})); } TEST(Expression, BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { - // cast arguments to same type + // cast arguments to common numeric type + ExpectBindsTo(cmp(field_ref("i64"), field_ref("i32")), + cmp(field_ref("i64"), cast(field_ref("i32"), int64()))); + + ExpectBindsTo(cmp(field_ref("i64"), field_ref("f32")), + cmp(cast(field_ref("i64"), float32()), field_ref("f32"))); + ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")), - cmp(field_ref("i32"), cast(field_ref("i64"), int32()))); - // NB: RHS is cast unless LHS is scalar. + cmp(cast(field_ref("i32"), int64()), field_ref("i64"))); + + ExpectBindsTo(cmp(field_ref("i8"), field_ref("u32")), + cmp(cast(field_ref("i8"), int64()), cast(field_ref("u32"), int64()))); // cast dictionary to value type ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")), cmp(cast(field_ref("dict_str"), utf8()), field_ref("str"))); - } - // scalars are directly cast when possible: - auto ts_scalar = MakeScalar("1990-10-23")->CastTo(timestamp(TimeUnit::NANO)); - ExpectBindsTo(equal(field_ref("ts_ns"), literal("1990-10-23")), - equal(field_ref("ts_ns"), literal(*ts_scalar))); + ExpectBindsTo(cmp(field_ref("dict_i32"), literal(int64_t(4))), + cmp(cast(field_ref("dict_i32"), int64()), literal(int64_t(4)))); + } - // cast value_set to argument type - auto Opts = [](std::shared_ptr type) { - return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")}; - }; - ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(binary())), - call("is_in", {field_ref("str")}, Opts(utf8()))); + compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")}; - // dictionary decode set then cast to argument type - ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))), - call("is_in", {field_ref("str")}, Opts(utf8()))); + // cast dictionary to value type + ExpectBindsTo(call("is_in", {field_ref("dict_str")}, in_a), + call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a)); } TEST(Expression, BindNestedCall) { @@ -519,16 +572,6 @@ TEST(Expression, ExecuteDictionaryTransparent) { {"a": "", "b": ""}, {"a": "hi", "b": "hello"} ])")); - - Datum dict_set = ArrayFromJSON(dictionary(int32(), utf8()), R"(["a"])"); - AssertExecute(call("is_in", {field_ref("a")}, - compute::SetLookupOptions{dict_set, - /*skip_nulls=*/false}), - ArrayFromJSON(struct_({field("a", utf8())}), R"([ - {"a": "a"}, - {"a": "good"}, - {"a": null} - ])")); } void ExpectIdenticalIfUnchanged(Expression modified, Expression original) { @@ -874,6 +917,12 @@ TEST(Expression, SingleComparisonGuarantees) { 
.WithGuarantee(equal(i32, literal(5))) .Expect(false); + Simplify{ + equal(i32, literal(0.5)), + } + .WithGuarantee(greater_equal(i32, literal(1))) + .Expect(false); + // no simplification possible: Simplify{ not_equal(i32, literal(3)), @@ -949,27 +998,28 @@ TEST(Expression, SimplifyWithGuarantee) { .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), less_equal(field_ref("i32"), literal(1)))) .Expect(equal(field_ref("i32"), literal(0))); - Simplify{ - or_(equal(field_ref("f32"), literal("0")), equal(field_ref("i32"), literal(3)))} + + Simplify{or_(equal(field_ref("f32"), literal(0)), equal(field_ref("i32"), literal(3)))} .WithGuarantee(greater(field_ref("f32"), literal(0.0))) .Expect(equal(field_ref("i32"), literal(3))); // simplification can see through implicit casts - Simplify{or_({equal(field_ref("f32"), literal("0")), - call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ - ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})})} - .WithGuarantee(greater(field_ref("f32"), literal(0.0))) - .Expect(call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); + compute::SetLookupOptions in_123{ArrayFromJSON(int32(), "[1,2,3]"), true}; + Simplify{or_({equal(field_ref("f32"), literal(0)), + call("is_in", {field_ref("i64")}, in_123)})} + .WithGuarantee(greater(field_ref("f32"), literal(0.F))) + .Expect(call("is_in", {field_ref("i64")}, in_123)); + + Simplify{greater(field_ref("dict_i32"), literal(int64_t(1)))} + .WithGuarantee(equal(field_ref("dict_i32"), literal(0))) + .Expect(false); } TEST(Expression, SimplifyThenExecute) { auto filter = - or_({equal(field_ref("f32"), literal("0")), + or_({equal(field_ref("f32"), literal(0)), call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ - ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})}); + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})}); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); auto guarantee = greater(field_ref("f32"), literal(0.0)); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index f0d44cfe3d6..c72283312cb 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -50,16 +50,19 @@ namespace arrow { namespace dataset { const std::shared_ptr kBoringSchema = schema({ + field("bool", boolean()), + field("i8", int8()), field("i32", int32()), field("i32_req", int32(), /*nullable=*/false), + field("u32", uint32()), field("i64", int64()), - field("date64", date64()), field("f32", float32()), field("f32_req", float32(), /*nullable=*/false), field("f64", float64()), - field("bool", boolean()), + field("date64", date64()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), + field("dict_i32", dictionary(int32(), int32())), field("ts_ns", timestamp(TimeUnit::NANO)), }); diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index 786110996dc..dd10fce3e4d 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -211,6 +211,19 @@ static std::string FormatValueDescr(const ValueDescr& descr) { std::string ValueDescr::ToString() const { return FormatValueDescr(*this); } +std::string ValueDescr::ToString(const std::vector& descrs) { + std::stringstream ss; + ss << "("; + for (size_t i = 0; i < descrs.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << descrs[i].ToString(); + } + ss << ")"; + return ss.str(); +} + void PrintTo(const ValueDescr& descr, std::ostream* os) { *os << descr.ToString(); } std::string 
Datum::ToString() const { diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index fb783ea5261..6ba6af7f79e 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -89,6 +89,7 @@ struct ARROW_EXPORT ValueDescr { bool operator!=(const ValueDescr& other) const { return !(*this == other); } std::string ToString() const; + static std::string ToString(const std::vector&); ARROW_EXPORT friend void PrintTo(const ValueDescr&, std::ostream*); }; diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index 09dfd59c8d2..6504d950674 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -317,6 +317,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return ValueUnsafe(); } const T& operator*() const& { return ValueOrDie(); } + const T* operator->() const& { return &ValueOrDie(); } /// Gets a mutable reference to the stored `T` value. /// @@ -331,6 +332,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return ValueUnsafe(); } T& operator*() & { return ValueOrDie(); } + T* operator->() & { return &ValueOrDie(); } /// Moves and returns the internally-stored `T` value. /// @@ -453,9 +455,9 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { } }; -#define ARROW_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ - auto&& result_name = (rexpr); \ - ARROW_RETURN_NOT_OK((result_name).status()); \ +#define ARROW_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ + auto&& result_name = (rexpr); \ + ARROW_RETURN_IF_(!(result_name).ok(), (result_name).status(), ARROW_STRINGIFY(rexpr)); \ lhs = std::move(result_name).ValueUnsafe(); #define ARROW_ASSIGN_OR_RAISE_NAME(x, y) ARROW_CONCAT(x, y) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index eca711d7c4f..06fc6783ff3 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -268,6 +268,13 @@ Result> DictionaryScalar::GetEncodedValue() const { return value.dictionary->GetScalar(index_value); } +std::shared_ptr DictionaryScalar::Make(std::shared_ptr index, + std::shared_ptr dict) { + auto type = dictionary(index->type, dict->type()); + return std::make_shared(ValueType{std::move(index), std::move(dict)}, + std::move(type)); +} + template using scalar_constructor_has_arrow_type = std::is_constructible::ScalarType, std::shared_ptr>; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 2888874d292..e84e3fab900 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -448,6 +448,9 @@ struct ARROW_EXPORT DictionaryScalar : public Scalar { DictionaryScalar(ValueType value, std::shared_ptr type, bool is_valid = true) : Scalar(std::move(type), is_valid), value(std::move(value)) {} + static std::shared_ptr Make(std::shared_ptr index, + std::shared_ptr dict); + Result> GetEncodedValue() const; }; diff --git a/cpp/src/arrow/status.cc b/cpp/src/arrow/status.cc index 480bbd3e468..cfc5eb1e345 100644 --- a/cpp/src/arrow/status.cc +++ b/cpp/src/arrow/status.cc @@ -132,7 +132,7 @@ void Status::Abort(const std::string& message) const { void Status::AddContextLine(const char* filename, int line, const char* expr) { ARROW_CHECK(!ok()) << "Cannot add context line to ok status"; std::stringstream ss; - ss << "\nIn " << filename << ", line " << line << ", code: " << expr; + ss << "\n" << filename << ":" << line << " " << expr; state_->msg += ss.str(); } #endif diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 2e523eac2bb..cdb23a92899 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ 
b/cpp/src/arrow/testing/gtest_util.h @@ -79,15 +79,18 @@ EXPECT_THAT(_st.ToString(), (matcher)); \ } while (false) -#define ASSERT_OK(expr) \ - do { \ - auto _res = (expr); \ - ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \ - if (!_st.ok()) { \ - FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString(); \ - } \ +#define EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(code, matcher, expr) \ + do { \ + auto _res = (expr); \ + ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \ + EXPECT_EQ(_st.CodeAsString(), Status::CodeAsString(code)); \ + EXPECT_THAT(_st.ToString(), (matcher)); \ } while (false) +#define ASSERT_OK(expr) \ + for (::arrow::Status _st = ::arrow::internal::GenericToStatus((expr)); !_st.ok();) \ + FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString() + #define ASSERT_OK_NO_THROW(expr) ASSERT_NO_THROW(ASSERT_OK(expr)) #define ARROW_EXPECT_OK(expr) \ @@ -426,13 +429,6 @@ inline void BitmapFromVector(const std::vector& is_valid, ASSERT_OK(GetBitmapFromVector(is_valid, out)); } -template -void AssertSortedEquals(std::vector u, std::vector v) { - std::sort(u.begin(), u.end()); - std::sort(v.begin(), v.end()); - ASSERT_EQ(u, v); -} - ARROW_TESTING_EXPORT void SleepFor(double seconds); @@ -474,6 +470,17 @@ class ARROW_TESTING_EXPORT EnvVarGuard { #define LARGE_MEMORY_TEST(name) name #endif +inline void PrintTo(const Status& st, std::ostream* os) { *os << st.ToString(); } + +template +void PrintTo(const Result& result, std::ostream* os) { + if (result.ok()) { + ::testing::internal::UniversalPrint(result.ValueOrDie(), os); + } else { + *os << result.status(); + } +} + } // namespace arrow namespace nonstd { diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 2dcfc77c437..e872a31f31d 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -929,6 +929,52 @@ static inline bool is_fixed_width(Type::type type_id) { return is_primitive(type_id) || is_dictionary(type_id) || is_fixed_size_binary(type_id); } +static inline int bit_width(Type::type type_id) { + switch (type_id) { + case Type::BOOL: + return 1; + case Type::UINT8: + case Type::INT8: + return 8; + case Type::UINT16: + case Type::INT16: + return 16; + case Type::UINT32: + case Type::INT32: + case Type::DATE32: + case Type::TIME32: + return 32; + case Type::UINT64: + case Type::INT64: + case Type::DATE64: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::DURATION: + return 64; + + case Type::HALF_FLOAT: + return 16; + case Type::FLOAT: + return 32; + case Type::DOUBLE: + return 64; + + case Type::INTERVAL_MONTHS: + return 32; + case Type::INTERVAL_DAY_TIME: + return 64; + + case Type::DECIMAL128: + return 128; + case Type::DECIMAL256: + return 256; + + default: + break; + } + return 0; +} + static inline bool is_nested(Type::type type_id) { switch (type_id) { case Type::LIST: diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index c513ed5b0ab..4101c36ef8f 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -104,6 +104,57 @@ exact semantics of the function:: .. seealso:: :doc:`Compute API reference ` +Implicit casts +============== + +Functions may require conversion of their arguments before execution if a +kernel does not match the argument types precisely. For example comparison +of dictionary encoded arrays is not directly supported by any kernel, but an +implicit cast can be made allowing comparison against the decoded array. 
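Editor's note: as a concrete illustration of the dictionary case just described, here is a minimal sketch modeled on the ExpressionUtils.Comparison test added in this patch. The free function name is hypothetical; ArrayFromJSON is the testing helper used throughout the patch.

    #include <memory>

    #include "arrow/compute/api.h"
    #include "arrow/scalar.h"
    #include "arrow/testing/gtest_util.h"  // ArrayFromJSON

    arrow::Result<arrow::Datum> CompareDictToString() {
      using arrow::Datum;
      // A dictionary<int32, utf8> scalar encoding "a", and a plain utf8 scalar.
      auto dict_str = arrow::DictionaryScalar::Make(
          std::make_shared<arrow::Int32Scalar>(0),
          arrow::ArrayFromJSON(arrow::utf8(), R"(["a", "b", "c"])"));
      auto str = arrow::MakeScalar("hello");
      // No kernel compares dictionary against utf8 directly; the dictionary scalar
      // is implicitly decoded, so this evaluates "a" < "hello" and yields true.
      return arrow::compute::CallFunction("less", {Datum(dict_str), Datum(str)});
    }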
+
+Each function may define implicit cast behaviour as appropriate. For example,
+comparison and arithmetic kernels require identically typed arguments, and
+support execution against differing numeric types by promoting their arguments
+to a numeric type which can accommodate any value from either input.
+
+.. _common-numeric-type:
+
+Common numeric type
+~~~~~~~~~~~~~~~~~~~
+
+The common numeric type of a set of input numeric types is the smallest numeric
+type which can accommodate any value of any input. If any input is a floating
+point type, the common numeric type is the widest floating point type among the
+inputs. Otherwise the common numeric type is integral and is signed if any input
+is signed. For example:
+
++-------------------+----------------------+------------------------------------------------+
+| Input types       | Common numeric type  | Notes                                          |
++===================+======================+================================================+
+| int32, int32      | int32                |                                                |
++-------------------+----------------------+------------------------------------------------+
+| int16, int32      | int32                | Max width is 32, promote LHS to int32          |
++-------------------+----------------------+------------------------------------------------+
+| uint16, int32     | int32                | One input signed, override unsigned            |
++-------------------+----------------------+------------------------------------------------+
+| uint32, int32     | int64                | Widen to accommodate range of uint32           |
++-------------------+----------------------+------------------------------------------------+
+| uint16, uint32    | uint32               | All inputs unsigned, maintain unsigned         |
++-------------------+----------------------+------------------------------------------------+
+| int16, uint32     | int64                |                                                |
++-------------------+----------------------+------------------------------------------------+
+| uint64, int16     | int64                | int64 cannot accommodate all uint64 values     |
++-------------------+----------------------+------------------------------------------------+
+| float32, int32    | float32              | Promote RHS to float32                         |
++-------------------+----------------------+------------------------------------------------+
+| float32, float64  | float64              |                                                |
++-------------------+----------------------+------------------------------------------------+
+| float32, int64    | float32              | int64 is wider, still promotes to float32      |
++-------------------+----------------------+------------------------------------------------+
+
+In particular, note that comparing a ``uint64`` column to an ``int16`` column
+may emit an error if one of the ``uint64`` values cannot be expressed as the
+common type ``int64`` (for example, ``2 ** 63``).
.. _compute-function-list: @@ -196,9 +247,11 @@ Binary functions have the following semantics (which is sometimes called Arithmetic functions ~~~~~~~~~~~~~~~~~~~~ -These functions expect two inputs of the same type and apply a given binary +These functions expect two inputs of numeric type and apply a given binary operation to each pair of elements gathered from the inputs. If any of the input elements in a pair is null, the corresponding output element is null. +Inputs will be cast to the :ref:`common numeric type <common-numeric-type>` +(and dictionary decoded, if applicable) before the operation is applied. The default variant of these functions does not detect overflow (the result then typically wraps around). Each function is also available in an @@ -228,9 +281,12 @@ an ``Invalid`` :class:`Status` when overflow is detected. 
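Editor's note: a short usage sketch of the promotion rules tabulated above. The function name is illustrative; ArrayFromJSON comes from the testing headers used elsewhere in this patch.

    #include "arrow/compute/api.h"
    #include "arrow/testing/gtest_util.h"  // ArrayFromJSON

    arrow::Status CommonNumericTypeExample() {
      auto i8 = arrow::ArrayFromJSON(arrow::int8(), "[1, 2, 3]");
      auto i32 = arrow::ArrayFromJSON(arrow::int32(), "[10, 20, 30]");
      auto f32 = arrow::ArrayFromJSON(arrow::float32(), "[0.5, 1.5, 2.5]");

      // int8 + int32: the int8 argument is implicitly cast to the common type int32.
      ARROW_ASSIGN_OR_RAISE(arrow::Datum sum,
                            arrow::compute::CallFunction("add", {i8, i32}));
      // sum.type() is int32

      // int32 + float32: a floating point input makes the common type float32.
      ARROW_ASSIGN_OR_RAISE(arrow::Datum mixed,
                            arrow::compute::CallFunction("add", {i32, f32}));
      // mixed.type() is float32

      return arrow::Status::OK();
    }

This mirrors what the R test changes below expect, for example int8 plus int32 producing int32.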
Comparisons ~~~~~~~~~~~ -Those functions expect two inputs of the same type and apply a given -comparison operator. If any of the input elements in a pair is null, -the corresponding output element is null. +These functions expect two inputs of numeric type (in which case they will be +cast to the :ref:`common numeric type ` before comparison), +or two inputs of Binary- or String-like types, or two inputs of Temporal types. +If any input is dictionary encoded it will be expanded for the purposes of +comparison. If any of the input elements in a pair is null, the corresponding +output element is null. +--------------------------+------------+---------------------------------------------+---------------------+ | Function names | Arity | Input types | Output type | @@ -744,3 +800,4 @@ Structural transforms * \(2) For each value in the list child array, the index at which it is found in the list array is appended to the output. Nulls in the parent list array are discarded. + diff --git a/python/pyarrow/tests/parquet/test_dataset.py b/python/pyarrow/tests/parquet/test_dataset.py index 42ce187f58b..cc49f14030a 100644 --- a/python/pyarrow/tests/parquet/test_dataset.py +++ b/python/pyarrow/tests/parquet/test_dataset.py @@ -206,7 +206,7 @@ def test_filters_equivalency(tempdir, use_legacy_dataset): dataset = pq.ParquetDataset( base_path, filesystem=fs, filters=[('integer', '=', 1), ('string', '!=', 'b'), - ('boolean', '==', True)], + ('boolean', '==', 'True')], use_legacy_dataset=use_legacy_dataset, ) table = dataset.read() diff --git a/r/R/array.R b/r/R/array.R index ec2b545dfae..acb612be5ef 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -62,6 +62,7 @@ #' - `$type_id()`: type id #' - `$Equals(other)` : is this array equal to `other` #' - `$ApproxEquals(other)` : +#' - `$Diff(other)` : return a string expressing the difference between two arrays #' - `$data()`: return the underlying [ArrayData][ArrayData] #' - `$as_vector()`: convert to an R vector #' - `$ToString()`: string representation of the array @@ -95,6 +96,12 @@ Array <- R6Class("Array", ApproxEquals = function(other) { inherits(other, "Array") && Array__ApproxEquals(self, other) }, + Diff = function(other) { + if (!inherits(other, "Array")) { + other <- Array$create(other) + } + Array__Diff(self, other) + }, data = function() Array__data(self), as_vector = function() Array__as_vector(self), ToString = function() { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index d6a5f9356e8..ec0aae94f30 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -48,6 +48,10 @@ Array__ApproxEquals <- function(lhs, rhs){ .Call(`_arrow_Array__ApproxEquals`, lhs, rhs) } +Array__Diff <- function(lhs, rhs){ + .Call(`_arrow_Array__Diff`, lhs, rhs) +} + Array__data <- function(array){ .Call(`_arrow_Array__data`, array) } diff --git a/r/R/expression.R b/r/R/expression.R index 0198e0ebe6a..5475f7a44bc 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -144,16 +144,6 @@ eval_array_expression <- function(x) { a } }) - if (length(x$args) == 2L) { - # Insert implicit casts - if (inherits(x$args[[1]], "Scalar")) { - x$args[[1]] <- x$args[[1]]$cast(x$args[[2]]$type) - } else if (inherits(x$args[[2]], "Scalar")) { - x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) - } else if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Array")) { - x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) - } - } call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } diff --git a/r/man/array.Rd b/r/man/array.Rd index b133c073824..fbc91e4dc35 
100644 --- a/r/man/array.Rd +++ b/r/man/array.Rd @@ -60,6 +60,7 @@ a == a \item \verb{$type_id()}: type id \item \verb{$Equals(other)} : is this array equal to \code{other} \item \verb{$ApproxEquals(other)} : +\item \verb{$Diff(other)} : return a string expressing the difference between two arrays \item \verb{$data()}: return the underlying \link{ArrayData} \item \verb{$as_vector()}: convert to an R vector \item \verb{$ToString()}: string representation of the array diff --git a/r/src/array.cpp b/r/src/array.cpp index e96e286a073..9601ee43c03 100644 --- a/r/src/array.cpp +++ b/r/src/array.cpp @@ -141,6 +141,12 @@ bool Array__ApproxEquals(const std::shared_ptr& lhs, return lhs->ApproxEquals(rhs); } +// [[arrow::export]] +std::string Array__Diff(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->Diff(*rhs); +} + // [[arrow::export]] std::shared_ptr Array__data( const std::shared_ptr& array) { diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index ae90abd5adf..2fbfecacfa1 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -108,6 +108,15 @@ BEGIN_CPP11 END_CPP11 } // array.cpp +std::string Array__Diff(const std::shared_ptr& lhs, const std::shared_ptr& rhs); +extern "C" SEXP _arrow_Array__Diff(SEXP lhs_sexp, SEXP rhs_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type lhs(lhs_sexp); + arrow::r::Input&>::type rhs(rhs_sexp); + return cpp11::as_sexp(Array__Diff(lhs, rhs)); +END_CPP11 +} +// array.cpp std::shared_ptr Array__data(const std::shared_ptr& array); extern "C" SEXP _arrow_Array__data(SEXP array_sexp){ BEGIN_CPP11 @@ -3512,6 +3521,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Array__type_id", (DL_FUNC) &_arrow_Array__type_id, 1}, { "_arrow_Array__Equals", (DL_FUNC) &_arrow_Array__Equals, 2}, { "_arrow_Array__ApproxEquals", (DL_FUNC) &_arrow_Array__ApproxEquals, 2}, + { "_arrow_Array__Diff", (DL_FUNC) &_arrow_Array__Diff, 2}, { "_arrow_Array__data", (DL_FUNC) &_arrow_Array__data, 1}, { "_arrow_Array__RangeEquals", (DL_FUNC) &_arrow_Array__RangeEquals, 5}, { "_arrow_Array__View", (DL_FUNC) &_arrow_Array__View, 2}, diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index d37367d47c8..d3cd2eedf6d 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -18,32 +18,43 @@ test_that("Addition", { a <- Array$create(c(1:4, NA_integer_)) expect_type_equal(a, int32()) - expect_type_equal(a + 4, int32()) - expect_equal(a + 4, Array$create(c(5:8, NA_integer_))) - expect_identical(as.vector(a + 4), c(5:8, NA_integer_)) + expect_type_equal(a + 4L, int32()) + expect_type_equal(a + 4, float64()) + expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) + expect_identical(as.vector(a + 4L), c(5:8, NA_integer_)) expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) expect_vector(a + 4L, c(5:8, NA_integer_)) expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5))) - # overflow errors — this is slightly different from R's `NA` coercion when - # overflowing, but better than the alternative of silently restarting - casted <- a$cast(int8()) - expect_error(casted + 127) - expect_error(casted + 200) + a8 <- a$cast(int8()) + expect_type_equal(a8 + Scalar$create(1, int8()), int8()) + + # int8 will be promoted to int32 when added to int32 + expect_type_equal(a8 + 127L, int32()) + expect_equal(a8 + 127L, Array$create(c(128:131, NA_integer_))) + + b <- Array$create(c(4:1, NA_integer_)) + expect_type_equal(a8 + b, int32()) + expect_equal(a8 + b, Array$create(c(5L, 5L, 5L, 
5L, NA_integer_))) - skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919") expect_type_equal(a + 4.1, float64()) expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) }) test_that("Subtraction", { a <- Array$create(c(1:4, NA_integer_)) - expect_equal(a - 3, Array$create(c(-2:1, NA_integer_))) + expect_equal(a - 3L, Array$create(c(-2:1, NA_integer_))) + + expect_equal(Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_)) - a, + Array$create(c(4.1, 4.1, 4.1, 4.1, NA_real_))) }) test_that("Multiplication", { a <- Array$create(c(1:4, NA_integer_)) - expect_equal(a * 2, Array$create(c(1:4 * 2L, NA_integer_))) + expect_equal(a * 2L, Array$create(c(1:4 * 2L, NA_integer_))) + + expect_equal((a * 0.5) * 3L, + Array$create(c(1.5, 3, 4.5, 6, NA_real_))) }) test_that("Division", { diff --git a/r/tests/testthat/test-compute-vector.R b/r/tests/testthat/test-compute-vector.R index 4fe7fed4d1c..0b184889bee 100644 --- a/r/tests/testthat/test-compute-vector.R +++ b/r/tests/testthat/test-compute-vector.R @@ -43,6 +43,7 @@ test_that("compare ops with Array", { expect_array_compares(Array$create(c(NA, 1:5)), 4) expect_array_compares(Array$create(as.numeric(c(NA, 1:5))), 4) expect_array_compares(Array$create(c(NA, 1:5)), Array$create(rev(c(NA, 1:5)))) + expect_array_compares(Array$create(c(NA, 1:5)), Array$create(rev(c(NA, 1:5)), type=double())) }) test_that("compare ops with ChunkedArray", { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 1e0f9418eec..990f024212e 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -554,6 +554,7 @@ test_that("filter() on timestamp columns", { ) # Now with bare string date + skip("Implement more aggressive implicit casting for scalars (ARROW-11402)") expect_equivalent( ds %>% filter(ts >= "2015-05-04") %>% @@ -666,8 +667,6 @@ test_that("filter() with expressions", { ) ) - skip("Implicit casts aren't being inserted everywhere they need to be (ARROW-8919)") - # Error: NotImplemented: Function multiply_checked has no kernel matching input types (scalar[double], array[int32]) expect_equivalent( ds %>% select(chr, dbl, int) %>% @@ -680,8 +679,6 @@ test_that("filter() with expressions", { ) ) - skip("Implicit casts are only inserted for scalars (ARROW-8919)") - # Error: NotImplemented: Function add_checked has no kernel matching input types (array[double], array[int32]) expect_equivalent( ds %>% select(chr, dbl, int) %>% @@ -700,7 +697,7 @@ test_that("filter scalar validation doesn't crash (ARROW-7772)", { ds %>% filter(int == "fff", part == 1) %>% collect(), - "Failed to parse string: 'fff' as a scalar of type int32" + "equal has no kernel matching input types .array.int32., scalar.string.." ) })
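Editor's note: the updated R expectation just above (an error mentioning "no kernel matching input types" rather than a failed string parse) corresponds to the C++ compute behaviour after this change. A minimal sketch, with an illustrative function name:

    #include "arrow/compute/api.h"
    #include "arrow/scalar.h"
    #include "arrow/testing/gtest_util.h"  // ArrayFromJSON

    arrow::Status NoImplicitStringToIntCast() {
      auto ints = arrow::ArrayFromJSON(arrow::int32(), "[1, 2, 3]");
      // Strings are no longer parsed to match an integer column; dispatch simply
      // finds no kernel and reports NotImplemented, which is what the R test matches:
      //   "equal has no kernel matching input types (array[int32], scalar[string])"
      auto result = arrow::compute::CallFunction(
          "equal", {arrow::Datum(ints), arrow::Datum(arrow::MakeScalar("fff"))});
      return result.status();
    }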