From 1fb628b26cd563fc747efa6fc47bf175d9c62bf9 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 12 Jan 2021 08:37:17 -0500 Subject: [PATCH 01/22] ARROW-8919: [C++][Compute] Add Function::DispatchBest --- cpp/src/arrow/compute/exec_internal.h | 5 + cpp/src/arrow/compute/function.cc | 118 ++++++++++++------ cpp/src/arrow/compute/function.h | 23 ++-- cpp/src/arrow/compute/kernel.h | 4 +- .../arrow/compute/kernels/codegen_internal.cc | 31 +++++ .../arrow/compute/kernels/codegen_internal.h | 10 ++ .../compute/kernels/scalar_arithmetic.cc | 20 ++- .../compute/kernels/scalar_arithmetic_test.cc | 21 +++- .../arrow/compute/kernels/scalar_compare.cc | 20 ++- .../compute/kernels/scalar_compare_test.cc | 118 +++++++++--------- cpp/src/arrow/compute/kernels/test_util.cc | 32 +++++ cpp/src/arrow/compute/kernels/test_util.h | 5 + cpp/src/arrow/dataset/expression.cc | 29 +++-- cpp/src/arrow/dataset/expression.h | 1 + cpp/src/arrow/dataset/expression_internal.h | 16 --- 15 files changed, 308 insertions(+), 145 deletions(-) diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index a74e5c8d8fa..cce8386d93a 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -132,6 +132,11 @@ class ARROW_EXPORT KernelExecutor { ARROW_EXPORT Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out); +/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. +ARROW_EXPORT +const Kernel* DispatchExactImpl(const Function* func, + const std::vector& values); + } // namespace detail } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 7a868853db7..979d121fc30 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -27,6 +27,9 @@ #include "arrow/util/cpu_info.h" namespace arrow { + +using internal::checked_cast; + namespace compute { static const FunctionDoc kEmptyFunctionDoc{}; @@ -44,8 +47,30 @@ Status Function::CheckArity(int passed_num_args) const { return Status::OK(); } -template -std::string FormatArgTypes(const std::vector& descrs) { +namespace { + +Status ValidateDispatch(const Function* func, const std::vector& values) { + if (func->kind() == Function::META) { + return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); + } + + const int passed_num_args = static_cast(values.size()); + const Arity arity = func->arity(); + + if (arity.is_varargs && passed_num_args < arity.num_args) { + return Status::Invalid("VarArgs function needs at least ", arity.num_args, + " arguments but passed only ", passed_num_args); + } + + if (!arity.is_varargs && passed_num_args != arity.num_args) { + return Status::Invalid("Function accepts ", arity.num_args, " arguments but passed ", + passed_num_args); + } + + return Status::OK(); +} + +Status NoMatchingKernel(const Function* func, const std::vector& descrs) { std::stringstream ss; ss << "("; for (size_t i = 0; i < descrs.size(); ++i) { @@ -55,28 +80,20 @@ std::string FormatArgTypes(const std::vector& descrs) { ss << descrs[i].ToString(); } ss << ")"; - return ss.str(); + + return Status::NotImplemented("Function ", func->name(), + " has no kernel matching input types ", ss.str()); } -template -Result DispatchExactImpl(const Function& func, - const std::vector& kernels, - const std::vector& values) { - const int passed_num_args = static_cast(values.size()); - const KernelType* kernel_matches[SimdLevel::MAX] = {NULL}; +template +const KernelType* DispatchExactImpl(const std::vector& kernels, + const std::vector& values) { + const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr}; // Validate arity - const Arity arity = func.arity(); - if (arity.is_varargs && passed_num_args < arity.num_args) { - return Status::Invalid("VarArgs function needs at least ", arity.num_args, - " arguments but passed only ", passed_num_args); - } else if (!arity.is_varargs && passed_num_args != arity.num_args) { - return Status::Invalid("Function accepts ", arity.num_args, " arguments but passed ", - passed_num_args); - } for (const auto& kernel : kernels) { - if (kernel.signature->MatchesInputs(values)) { - kernel_matches[kernel.simd_level] = &kernel; + if (kernel->signature->MatchesInputs(values)) { + kernel_matches[kernel->simd_level] = kernel; } } @@ -102,9 +119,51 @@ Result DispatchExactImpl(const Function& func, return kernel_matches[SimdLevel::NONE]; } - return Status::NotImplemented("Function ", func.name(), - " has no kernel matching input types ", - FormatArgTypes(values)); + return nullptr; +} + +const Kernel* DispatchExactImpl(const Function* func, + const std::vector& values) { + if (func->kind() == Function::SCALAR) { + return DispatchExactImpl(checked_cast(func)->kernels(), + values); + } + + if (func->kind() == Function::VECTOR) { + return DispatchExactImpl(checked_cast(func)->kernels(), + values); + } + + if (func->kind() == Function::SCALAR_AGGREGATE) { + return DispatchExactImpl( + checked_cast(func)->kernels(), values); + } + + return nullptr; +} + +} // namespace + +Result Function::DispatchExact( + const std::vector& values) const { + RETURN_NOT_OK(ValidateDispatch(this, values)); + + if (auto kernel = DispatchExactImpl(this, values)) { + return kernel; + } + return NoMatchingKernel(this, values); +} + +Result Function::DispatchBest(std::vector* values) const { + RETURN_NOT_OK(ValidateDispatch(this, *values)); + + // first try for an exact match + if (auto kernel = DispatchExactImpl(this, *values)) { + return kernel; + } + + // XXX permit generic conversions here, for example dict -> decoded, null -> any? + return DispatchExact(*values); } Result Function::Execute(const std::vector& args, @@ -187,11 +246,6 @@ Status ScalarFunction::AddKernel(ScalarKernel kernel) { return Status::OK(); } -Result ScalarFunction::DispatchExact( - const std::vector& values) const { - return DispatchExactImpl(*this, kernels_, values); -} - Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { RETURN_NOT_OK(CheckArity(static_cast(in_types.size()))); @@ -214,11 +268,6 @@ Status VectorFunction::AddKernel(VectorKernel kernel) { return Status::OK(); } -Result VectorFunction::DispatchExact( - const std::vector& values) const { - return DispatchExactImpl(*this, kernels_, values); -} - Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); if (arity_.is_varargs && !kernel.signature->is_varargs()) { @@ -228,11 +277,6 @@ Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { return Status::OK(); } -Result ScalarAggregateFunction::DispatchExact( - const std::vector& values) const { - return DispatchExactImpl(*this, kernels_, values); -} - Result MetaFunction::Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index e8e732027c9..23ecd6f160e 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -162,7 +162,15 @@ class ARROW_EXPORT Function { /// /// NB: This function is overridden in CastFunction. virtual Result DispatchExact( - const std::vector& values) const = 0; + const std::vector& values) const; + + /// \brief Return a best-match kernel that can execute the function given the argument + /// types, after implicit casts are applied. + /// + /// \param[in,out] values Argument types. An element may be modified to indicate that + /// the returned kernel only approximately matches the input value descriptors; callers + /// are responsible for casting inputs to the type and shape required by the kernel. + virtual Result DispatchBest(std::vector* values) const; /// \brief Execute the function eagerly with the passed input arguments with /// kernel dispatch, batch iteration, and memory allocation details taken @@ -249,9 +257,6 @@ class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl { /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(ScalarKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; }; /// \brief A function that executes general array operations that may yield @@ -276,9 +281,6 @@ class ARROW_EXPORT VectorFunction : public detail::FunctionImpl { /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(VectorKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; }; class ARROW_EXPORT ScalarAggregateFunction @@ -294,9 +296,6 @@ class ARROW_EXPORT ScalarAggregateFunction /// \brief Add a kernel (function implementation). Returns error if the /// kernel's signature does not match the function's arity. Status AddKernel(ScalarAggregateKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; }; /// \brief A function that dispatches to other functions. Must implement @@ -311,10 +310,6 @@ class ARROW_EXPORT MetaFunction : public Function { Result Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const override; - Result DispatchExact(const std::vector&) const override { - return Status::NotImplemented("DispatchExact for a MetaFunction's Kernels"); - } - protected: virtual Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 67cb5df7908..c8f9cacfb34 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -566,7 +566,7 @@ struct Kernel { /// output array values (as opposed to scalar values in the case of aggregate /// functions). struct ArrayKernel : public Kernel { - ArrayKernel() {} + ArrayKernel() = default; ArrayKernel(std::shared_ptr sig, ArrayKernelExec exec, KernelInit init = NULLPTR) @@ -614,7 +614,7 @@ using VectorFinalize = std::function*)>; /// (which have different defaults from ScalarKernel), and some other /// execution-related options. struct VectorKernel : public ArrayKernel { - VectorKernel() {} + VectorKernel() = default; VectorKernel(std::shared_ptr sig, ArrayKernelExec exec) : ArrayKernel(std::move(sig), exec) {} diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index a5941ea2200..cd9f5bfc876 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -179,6 +179,37 @@ Result FirstType(KernelContext*, const std::vector& desc return descrs[0]; } +std::shared_ptr CommonNumeric(const std::vector& descrs) { + for (const auto& descr : descrs) { + auto id = descr.type->id(); + if (!is_floating(id) && !is_integer(id)) { + // a common numeric type is only possible if all types are numeric + return nullptr; + } + } + for (const auto& descr : descrs) { + if (descr.type->id() == Type::DOUBLE) return float64(); + } + for (const auto& descr : descrs) { + if (descr.type->id() == Type::FLOAT) return float32(); + } + + bool at_least_one_signed = false; + int max_width = 0; + + for (const auto& descr : descrs) { + at_least_one_signed |= is_signed_integer(descr.type->id()); + max_width = + std::max(max_width, checked_cast(*descr.type).bit_width()); + } + + if (max_width == 64) return at_least_one_signed ? int64() : uint64(); + if (max_width == 32) return at_least_one_signed ? int32() : uint32(); + if (max_width == 16) return at_least_one_signed ? int16() : uint16(); + DCHECK_EQ(max_width, 8); + return at_least_one_signed ? int8() : uint8(); +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index c3a6b4b9772..3d03b11585b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1186,6 +1186,16 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) { // END of kernel generator-dispatchers // ---------------------------------------------------------------------- +inline void EnsureDictionaryDecoded(std::vector* descrs) { + for (ValueDescr& descr : *descrs) { + if (descr.type->id() == Type::DICTIONARY) { + descr.type = checked_cast(*descr.type).value_type(); + } + } +} + +std::shared_ptr CommonNumeric(const std::vector& descrs); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index fc18da7cf13..38995b16e96 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -264,10 +264,26 @@ ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) { } } +struct ArithmeticFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + EnsureDictionaryDecoded(values); + + if (auto type = CommonNumeric(*values)) { + for (auto& descr : *values) { + descr.type = type; + } + } + + return DispatchExact(*values); + } +}; + template std::shared_ptr MakeArithmeticFunction(std::string name, const FunctionDoc* doc) { - auto func = std::make_shared(name, Arity::Binary(), doc); + auto func = std::make_shared(name, Arity::Binary(), doc); for (const auto& ty : NumericTypes()) { auto exec = NumericEqualTypesBinary(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); @@ -280,7 +296,7 @@ std::shared_ptr MakeArithmeticFunction(std::string name, template std::shared_ptr MakeArithmeticFunctionNotNull(std::string name, const FunctionDoc* doc) { - auto func = std::make_shared(name, Arity::Binary(), doc); + auto func = std::make_shared(name, Arity::Binary(), doc); for (const auto& ty : NumericTypes()) { auto exec = NumericEqualTypesBinary(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index a19abe82873..2e852bc54ab 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -67,7 +67,7 @@ class TestBinaryArithmetic : public TestBase { using BinaryFunction = std::function(const Datum&, const Datum&, ArithmeticOptions, ExecContext*)>; - void SetUp() { options_.check_overflow = false; } + void SetUp() override { options_.check_overflow = false; } std::shared_ptr MakeNullScalar() { return arrow::MakeNullScalar(type_singleton()); @@ -637,5 +637,24 @@ TYPED_TEST(TestBinaryArithmeticFloating, Mul) { this->AssertBinop(Multiply, "[null, 2.0]", this->MakeNullScalar(), "[null, null]"); } +TEST(TestBinaryArithmetic, DispatchBest) { + for (std::string name : {"add", "subtract", "multiply", "divide"}) { + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); + + CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, + {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, + {float64(), float64()}); + } + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index cf32c888e8e..45e3406996c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -72,10 +72,26 @@ void AddGenericCompare(const std::shared_ptr& ty, ScalarFunction* func applicator::ScalarBinaryEqualTypes::Exec)); } +struct CompareFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + EnsureDictionaryDecoded(values); + + if (auto type = CommonNumeric(*values)) { + for (auto& descr : *values) { + descr.type = type; + } + } + + return DispatchExact(*values); + } +}; + template std::shared_ptr MakeCompareFunction(std::string name, const FunctionDoc* doc) { - auto func = std::make_shared(name, Arity::Binary(), doc); + auto func = std::make_shared(name, Arity::Binary(), doc); DCHECK_OK(func->AddKernel( {boolean(), boolean()}, boolean(), @@ -136,7 +152,7 @@ std::shared_ptr MakeCompareFunction(std::string name, std::shared_ptr MakeFlippedFunction(std::string name, const ScalarFunction& func, const FunctionDoc* doc) { - auto flipped_func = std::make_shared(name, Arity::Binary(), doc); + auto flipped_func = std::make_shared(name, Arity::Binary(), doc); for (const ScalarKernel* kernel : func.kernels()) { ScalarKernel flipped_kernel = *kernel; flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 6f742fe7bfd..1b91bf54813 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -451,6 +451,23 @@ TEST(TestCompareTimestamps, Basics) { CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]"); } +TEST(TestCompareKernel, DispatchBest) { + for (std::string name : + {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); + + CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, + {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, + {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), utf8()), utf8()}, {utf8(), utf8()}); + } +} + class TestStringCompareKernel : public ::testing::Test {}; TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { @@ -459,85 +476,74 @@ TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { CompareOptions eq(CompareOperator::EQUAL); ValidateCompare(eq, "[]", one, "[]"); ValidateCompare(eq, "[null]", one, "[null]"); - ValidateCompare(eq, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare( - eq, "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", one, "[0,1,0,0,0,0]"); - ValidateCompare( - eq, "[\"five\",\"four\",\"three\",\"two\",\"one\",\"zero\"]", one, "[0,0,0,0,1,0]"); - ValidateCompare(eq, "[null,\"zero\",\"one\",\"one\"]", one, "[null,0,1,1]"); + ValidateCompare(eq, R"(["zero","zero","one","one","two","two"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(eq, R"(["zero","one","two","three","four","five"])", one, + "[0,1,0,0,0,0]"); + ValidateCompare(eq, R"(["five","four","three","two","one","zero"])", one, + "[0,0,0,0,1,0]"); + ValidateCompare(eq, R"([null,"zero","one","one"])", one, "[null,0,1,1]"); Datum na(std::make_shared()); - ValidateCompare(eq, "[null,\"zero\",\"one\",\"one\"]", na, + ValidateCompare(eq, R"([null,"zero","one","one"])", na, "[null,null,null,null]"); - ValidateCompare(eq, na, "[null,\"zero\",\"one\",\"one\"]", + ValidateCompare(eq, na, R"([null,"zero","one","one"])", "[null,null,null,null]"); CompareOptions neq(CompareOperator::NOT_EQUAL); ValidateCompare(neq, "[]", one, "[]"); ValidateCompare(neq, "[null]", one, "[null]"); - ValidateCompare(neq, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare(neq, - "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", - one, "[1,0,1,1,1,1]"); - ValidateCompare(neq, - "[\"five\",\"four\",\"three\",\"two\",\"one\",\"zero\"]", - one, "[1,1,1,1,0,1]"); - ValidateCompare(neq, "[null,\"zero\",\"one\",\"one\"]", one, - "[null,1,0,0]"); + ValidateCompare(neq, R"(["zero","zero","one","one","two","two"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(neq, R"(["zero","one","two","three","four","five"])", one, + "[1,0,1,1,1,1]"); + ValidateCompare(neq, R"(["five","four","three","two","one","zero"])", one, + "[1,1,1,1,0,1]"); + ValidateCompare(neq, R"([null,"zero","one","one"])", one, "[null,1,0,0]"); CompareOptions gt(CompareOperator::GREATER); ValidateCompare(gt, "[]", one, "[]"); ValidateCompare(gt, "[null]", one, "[null]"); - ValidateCompare(gt, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare( - gt, "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", one, "[1,0,1,1,0,0]"); - ValidateCompare(gt, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare(gt, "[null,\"zero\",\"one\",\"one\"]", one, "[null,1,0,0]"); + ValidateCompare(gt, R"(["zero","zero","one","one","two","two"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(gt, R"(["zero","one","two","three","four","five"])", one, + "[1,0,1,1,0,0]"); + ValidateCompare(gt, R"(["four","five","six","seven","eight","nine"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(gt, R"([null,"zero","one","one"])", one, "[null,1,0,0]"); CompareOptions gte(CompareOperator::GREATER_EQUAL); ValidateCompare(gte, "[]", one, "[]"); ValidateCompare(gte, "[null]", one, "[null]"); - ValidateCompare(gte, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[1,1,1,1,1,1]"); - ValidateCompare(gte, - "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", - one, "[1,1,1,1,0,0]"); - ValidateCompare(gte, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare(gte, "[null,\"zero\",\"one\",\"one\"]", one, - "[null,1,1,1]"); + ValidateCompare(gte, R"(["zero","zero","one","one","two","two"])", one, + "[1,1,1,1,1,1]"); + ValidateCompare(gte, R"(["zero","one","two","three","four","five"])", one, + "[1,1,1,1,0,0]"); + ValidateCompare(gte, R"(["four","five","six","seven","eight","nine"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(gte, R"([null,"zero","one","one"])", one, "[null,1,1,1]"); CompareOptions lt(CompareOperator::LESS); ValidateCompare(lt, "[]", one, "[]"); ValidateCompare(lt, "[null]", one, "[null]"); - ValidateCompare(lt, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[0,0,0,0,0,0]"); - ValidateCompare( - lt, "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", one, "[0,0,0,0,1,1]"); - ValidateCompare(lt, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare(lt, "[null,\"zero\",\"one\",\"one\"]", one, "[null,0,0,0]"); + ValidateCompare(lt, R"(["zero","zero","one","one","two","two"])", one, + "[0,0,0,0,0,0]"); + ValidateCompare(lt, R"(["zero","one","two","three","four","five"])", one, + "[0,0,0,0,1,1]"); + ValidateCompare(lt, R"(["four","five","six","seven","eight","nine"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(lt, R"([null,"zero","one","one"])", one, "[null,0,0,0]"); CompareOptions lte(CompareOperator::LESS_EQUAL); ValidateCompare(lte, "[]", one, "[]"); ValidateCompare(lte, "[null]", one, "[null]"); - ValidateCompare(lte, "[\"zero\",\"zero\",\"one\",\"one\",\"two\",\"two\"]", - one, "[0,0,1,1,0,0]"); - ValidateCompare(lte, - "[\"zero\",\"one\",\"two\",\"three\",\"four\",\"five\"]", - one, "[0,1,0,0,1,1]"); - ValidateCompare(lte, - "[\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\"]", - one, "[1,1,0,0,1,1]"); - ValidateCompare(lte, "[null,\"zero\",\"one\",\"one\"]", one, - "[null,0,1,1]"); + ValidateCompare(lte, R"(["zero","zero","one","one","two","two"])", one, + "[0,0,1,1,0,0]"); + ValidateCompare(lte, R"(["zero","one","two","three","four","five"])", one, + "[0,1,0,0,1,1]"); + ValidateCompare(lte, R"(["four","five","six","seven","eight","nine"])", one, + "[1,1,0,0,1,1]"); + ValidateCompare(lte, R"([null,"zero","one","one"])", one, "[null,0,1,1]"); } TEST_F(TestStringCompareKernel, RandomCompareArrayScalar) { @@ -563,7 +569,7 @@ TEST_F(TestStringCompareKernel, RandomCompareArrayArray) { for (size_t i = 3; i < 5; i++) { for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - const int64_t length = static_cast(1ULL << i); + auto length = static_cast(1ULL << i); auto lhs = Datum(rand.String(length << i, 0, 16, null_probability)); auto rhs = Datum(rand.String(length << i, 0, 16, null_probability)); auto options = CompareOptions(op); diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 5d54a8c1771..c8dd00250dc 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -24,6 +24,8 @@ #include "arrow/array.h" #include "arrow/chunked_array.h" #include "arrow/compute/exec.h" +#include "arrow/compute/function.h" +#include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/result.h" #include "arrow/testing/gtest_util.h" @@ -173,5 +175,35 @@ void CheckScalarBinary(std::string func_name, std::shared_ptr left_input, CheckScalar(std::move(func_name), {left_input, right_input}, expected, options); } +void CheckDispatchBest(std::string func_name, std::vector original_values, + std::vector expected_equivalent_values) { + ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(func_name)); + + auto values = original_values; + ASSERT_OK_AND_ASSIGN(auto actual_kernel, function->DispatchBest(&values)); + + ASSERT_OK_AND_ASSIGN(auto expected_kernel, + function->DispatchExact(expected_equivalent_values)); + + auto Format = [](const std::vector& descrs) { + std::stringstream ss; + ss << "("; + for (size_t i = 0; i < descrs.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << descrs[i].ToString(); + } + ss << ")"; + return ss.str(); + }; + + EXPECT_EQ(actual_kernel, expected_kernel) + << "DispatchBest" << Format(original_values) << " => " + << actual_kernel->signature->ToString() << "\n" + << "DispatchExact" << Format(expected_equivalent_values) << " => " + << expected_kernel->signature->ToString(); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index c38c0ceb83c..767911888ac 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -143,5 +143,10 @@ void TestRandomPrimitiveCTypes() { DoTestFunctor::Test(duration(TimeUnit::MILLI)); } +// Check that DispatchBest on a given function yields the same Kernel as +// produced by DispatchExact on another set of ValueDescrs. +void CheckDispatchBest(std::string func_name, std::vector descrs, + std::vector exact_descrs); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 16f706ed1a4..ddec3acf2f9 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -306,7 +306,7 @@ size_t Expression::hash() const { } bool Expression::IsBound() const { - if (descr().type == nullptr) return false; + if (type() == nullptr) return false; if (auto call = this->call()) { if (call->kernel == nullptr) return false; @@ -359,7 +359,7 @@ bool Expression::IsNullLiteral() const { } bool Expression::IsSatisfiable() const { - if (descr().type && descr().type->id() == Type::NA) { + if (type() && type()->id() == Type::NA) { return false; } @@ -426,7 +426,7 @@ Result BindNonRecursive(const Expression::Call& call, } Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { - if (expr->descr().type->Equals(to_type)) { + if (expr->type()->Equals(to_type)) { return Status::OK(); } @@ -454,22 +454,22 @@ Status InsertImplicitCasts(Expression::Call* call) { if (IsSameTypesBinary(call->function_name)) { for (auto&& argument : call->arguments) { - if (auto value_type = GetDictionaryValueType(argument.descr().type)) { + if (auto value_type = GetDictionaryValueType(argument.type())) { RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); } } if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { // argument 0 is scalar so casting is cheap - return MaybeInsertCast(call->arguments[1].descr().type, &call->arguments[0]); + return MaybeInsertCast(call->arguments[1].type(), &call->arguments[0]); } // cast argument 1 unconditionally - return MaybeInsertCast(call->arguments[0].descr().type, &call->arguments[1]); + return MaybeInsertCast(call->arguments[0].type(), &call->arguments[1]); } if (auto options = GetSetLookupOptions(*call)) { - if (auto value_type = GetDictionaryValueType(call->arguments[0].descr().type)) { + if (auto value_type = GetDictionaryValueType(call->arguments[0].type())) { // DICTIONARY input is not supported; decode it. RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &call->arguments[0])); } @@ -482,12 +482,12 @@ Status InsertImplicitCasts(Expression::Call* call) { call->options = std::move(new_options); } - if (!options->value_set.type()->Equals(call->arguments[0].descr().type)) { + if (!options->value_set.type()->Equals(call->arguments[0].type())) { // The value_set is assumed smaller than inputs, casting it should be cheaper. auto new_options = std::make_shared(*options); - ARROW_ASSIGN_OR_RAISE(new_options->value_set, - compute::Cast(std::move(new_options->value_set), - call->arguments[0].descr().type)); + ARROW_ASSIGN_OR_RAISE( + new_options->value_set, + compute::Cast(std::move(new_options->value_set), call->arguments[0].type())); options = new_options.get(); call->options = std::move(new_options); } @@ -595,8 +595,8 @@ Result ExecuteScalarExpression(const Expression& expr, const Datum& input // Refernced field was present but didn't have the expected type. // Should we just error here? For now, pay dispatch cost and just cast. ARROW_ASSIGN_OR_RAISE( - field, compute::Cast(field, expr.descr().type, compute::CastOptions::Safe(), - exec_context)); + field, + compute::Cast(field, expr.type(), compute::CastOptions::Safe(), exec_context)); } return field; @@ -803,8 +803,7 @@ Result ReplaceFieldsWithKnownValues( if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { - ARROW_ASSIGN_OR_RAISE(Datum lit, - compute::Cast(it->second, expr.descr().type)); + ARROW_ASSIGN_OR_RAISE(Datum lit, compute::Cast(it->second, expr.type())); return literal(std::move(lit)); } } diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 984c846210f..0d28d3d3f3d 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -104,6 +104,7 @@ class ARROW_DS_EXPORT Expression { /// The type and shape to which this expression will evaluate ValueDescr descr() const; + std::shared_ptr type() const { return descr().type; } // XXX someday // NullGeneralization::type nullable() const; diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index 6a0f54fa8d5..a4cb63c5b50 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -214,13 +214,6 @@ inline std::shared_ptr GetDictionaryValueType( return nullptr; } -inline Status EnsureNotDictionary(ValueDescr* descr) { - if (auto value_type = GetDictionaryValueType(descr->type)) { - descr->type = std::move(value_type); - } - return Status::OK(); -} - inline Status EnsureNotDictionary(Datum* datum) { if (datum->type()->id() == Type::DICTIONARY) { const auto& type = checked_cast(*datum->type()).value_type(); @@ -229,15 +222,6 @@ inline Status EnsureNotDictionary(Datum* datum) { return Status::OK(); } -inline Status EnsureNotDictionary(Expression::Call* call) { - if (auto options = GetSetLookupOptions(*call)) { - auto new_options = *options; - RETURN_NOT_OK(EnsureNotDictionary(&new_options.value_set)); - call->options.reset(new compute::SetLookupOptions(std::move(new_options))); - } - return Status::OK(); -} - /// A helper for unboxing an Expression composed of associative function calls. /// Such expressions can frequently be rearranged to a semantically equivalent /// expression for more optimal execution or more straightforward manipulation. From e6f08407a86ed6a1af832f3b93b1d706f89f6acf Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 14 Jan 2021 12:25:40 -0500 Subject: [PATCH 02/22] support implicit casts in Function::Execute, CallFunction --- cpp/src/arrow/compute/function.cc | 21 +++++++++++++++--- .../compute/kernels/scalar_arithmetic_test.cc | 21 ++++++++++++++++++ .../compute/kernels/scalar_compare_test.cc | 22 +++++++++++++++++++ 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 979d121fc30..0019c430298 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -21,6 +21,7 @@ #include #include +#include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec_internal.h" #include "arrow/datum.h" @@ -166,15 +167,18 @@ Result Function::DispatchBest(std::vector* values) co return DispatchExact(*values); } -Result Function::Execute(const std::vector& args, +Result Function::Execute(const std::vector& original_args, const FunctionOptions* options, ExecContext* ctx) const { if (options == nullptr) { options = default_options(); } if (ctx == nullptr) { ExecContext default_ctx; - return Execute(args, options, &default_ctx); + return Execute(original_args, options, &default_ctx); } + // make a local copy to accommodate implicit casts + auto args = original_args; + // type-check Datum arguments here. Really we'd like to avoid this as much as // possible RETURN_NOT_OK(detail::CheckAllValues(args)); @@ -183,7 +187,18 @@ Result Function::Execute(const std::vector& args, inputs[i] = args[i].descr(); } - ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchExact(inputs)); + ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs)); + for (size_t i = 0; i != args.size(); ++i) { + if (inputs[i] != args[i].descr()) { + if (inputs[i].shape != args[i].shape()) { + return Status::NotImplemented( + "Automatic broadcasting of scalars to arrays for function ", name()); + } + + ARROW_ASSIGN_OR_RAISE(args[i], Cast(args[i], inputs[i].type)); + } + } + std::unique_ptr state; KernelContext kernel_ctx{ctx}; diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 2e852bc54ab..edaccc4e8f8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -656,5 +656,26 @@ TEST(TestBinaryArithmetic, DispatchBest) { } } +TEST(TestBinaryArithmetic, AddWithImplicitCasts) { + CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + ArrayFromJSON(float64(), "[0.25, 0.5, 0.75, 1.0]"), + ArrayFromJSON(float64(), "[0.25, 1.5, 2.75, null]")); + + CheckScalarBinary("add", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(int32(), "[-13, 4, 21, null]")); + + CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(int32(), "[3, 5, 7, null]")); + + // Not currently implemented since it would invoke a double implicit cast: + // dictionary(int32, int8) -> int8 -> int32 + // CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, + // null]"), + // ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + // ArrayFromJSON(int32(), "[3, 5, 7, null]")); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 1b91bf54813..c36db631055 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -468,6 +468,28 @@ TEST(TestCompareKernel, DispatchBest) { } } +TEST(TestCompareKernel, GreaterWithImplicitCasts) { + CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + ArrayFromJSON(float64(), "[0.5, 1.0, 1.5, 2.0]"), + ArrayFromJSON(boolean(), "[false, false, true, null]")); + + CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(boolean(), "[false, false, true, null]")); + + CheckScalarBinary("greater", + ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(boolean(), "[false, false, false, null]")); + + // Not currently implemented since it would invoke a double implicit cast: + // dictionary(int32, int8) -> int8 -> int32 + // CheckScalarBinary("greater", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, + // null]"), + // ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + // ArrayFromJSON(boolean(), "[false, false, false, null]")); +} + class TestStringCompareKernel : public ::testing::Test {}; TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { From 5cdd710a5d61169f2d55a21090c138bd02c636df Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 14 Jan 2021 20:34:17 -0500 Subject: [PATCH 03/22] first pass at integrating DispatchBest into Expressions --- .../arrow/compute/kernels/codegen_internal.cc | 6 +- cpp/src/arrow/dataset/expression.cc | 157 +++++++----------- cpp/src/arrow/dataset/expression.h | 2 +- cpp/src/arrow/dataset/expression_internal.h | 58 ++++--- cpp/src/arrow/dataset/expression_test.cc | 118 ++++++++----- cpp/src/arrow/dataset/test_util.h | 6 +- cpp/src/arrow/type_traits.h | 40 +++++ 7 files changed, 224 insertions(+), 163 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index cd9f5bfc876..764fc584eea 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -198,9 +198,9 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { int max_width = 0; for (const auto& descr : descrs) { - at_least_one_signed |= is_signed_integer(descr.type->id()); - max_width = - std::max(max_width, checked_cast(*descr.type).bit_width()); + auto id = descr.type->id(); + at_least_one_signed |= is_signed_integer(id); + max_width = std::max(bit_width(id), max_width); } if (max_width == 64) return at_least_one_signed ? int64() : uint64(); diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index ddec3acf2f9..1a5e17a993f 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -391,111 +391,78 @@ Result> InitKernelState( return std::move(kernel_state); } -Status InsertImplicitCasts(Expression::Call* call); - -// Produce a bound Expression from unbound Call and bound arguments. -Result BindNonRecursive(const Expression::Call& call, - std::vector arguments, - bool insert_implicit_casts, - compute::ExecContext* exec_context) { - DCHECK(std::all_of(arguments.begin(), arguments.end(), - [](const Expression& argument) { return argument.IsBound(); })); - - auto bound_call = call; - bound_call.arguments = std::move(arguments); - - if (insert_implicit_casts) { - RETURN_NOT_OK(InsertImplicitCasts(&bound_call)); - } - - ARROW_ASSIGN_OR_RAISE(bound_call.function, GetFunction(bound_call, exec_context)); - - auto descrs = GetDescriptors(bound_call.arguments); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel, bound_call.function->DispatchExact(descrs)); - - compute::KernelContext kernel_context(exec_context); - ARROW_ASSIGN_OR_RAISE(bound_call.kernel_state, - InitKernelState(bound_call, exec_context)); - kernel_context.SetState(bound_call.kernel_state.get()); - - ARROW_ASSIGN_OR_RAISE( - bound_call.descr, - bound_call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); - - return Expression(std::move(bound_call)); -} - -Status MaybeInsertCast(std::shared_ptr to_type, Expression* expr) { - if (expr->type()->Equals(to_type)) { - return Status::OK(); - } +Status ImplicitCastFunctionOptions(Expression::Call* call) { + if (auto options = GetSetLookupOptions(*call)) { + if (options->value_set.type()->Equals(call->arguments[0].type())) { + return Status::OK(); + } - if (auto lit = expr->literal()) { - ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, to_type)); - *expr = literal(std::move(new_lit)); + // The value_set is assumed smaller than inputs, casting it should be cheaper. + auto new_options = std::make_shared(*options); + ARROW_ASSIGN_OR_RAISE( + new_options->value_set, + compute::Cast(std::move(new_options->value_set), call->arguments[0].type())); + options = new_options.get(); + call->options = std::move(new_options); return Status::OK(); } - Expression::Call with_cast; - with_cast.function_name = "cast"; - with_cast.options = std::make_shared( - compute::CastOptions::Safe(std::move(to_type))); - - compute::ExecContext exec_context; - ARROW_ASSIGN_OR_RAISE(*expr, - BindNonRecursive(with_cast, {std::move(*expr)}, - /*insert_implicit_casts=*/false, &exec_context)); return Status::OK(); } -Status InsertImplicitCasts(Expression::Call* call) { - DCHECK(std::all_of(call->arguments.begin(), call->arguments.end(), +// Produce a bound Expression from unbound Call and bound arguments. +Result BindNonRecursive(Expression::Call call, bool insert_implicit_casts, + compute::ExecContext* exec_context) { + DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - if (IsSameTypesBinary(call->function_name)) { - for (auto&& argument : call->arguments) { - if (auto value_type = GetDictionaryValueType(argument.type())) { - RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &argument)); - } - } + auto descrs = GetDescriptors(call.arguments); + ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context)); - if (call->arguments[0].descr().shape == ValueDescr::SCALAR) { - // argument 0 is scalar so casting is cheap - return MaybeInsertCast(call->arguments[1].type(), &call->arguments[0]); - } + if (!insert_implicit_casts) { + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(descrs)); + } else { + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&descrs)); - // cast argument 1 unconditionally - return MaybeInsertCast(call->arguments[0].type(), &call->arguments[1]); - } + for (size_t i = 0; i < descrs.size(); ++i) { + if (descrs[i] == call.arguments[i].descr()) continue; - if (auto options = GetSetLookupOptions(*call)) { - if (auto value_type = GetDictionaryValueType(call->arguments[0].type())) { - // DICTIONARY input is not supported; decode it. - RETURN_NOT_OK(MaybeInsertCast(std::move(value_type), &call->arguments[0])); - } + if (descrs[i].shape != call.arguments[i].descr().shape) { + return Status::NotImplemented( + "Automatic broadcasting of scalars arguments to arrays in ", + Expression(std::move(call)).ToString()); + } - if (options->value_set.type()->id() == Type::DICTIONARY) { - // DICTIONARY value_set is not supported; decode it. - auto new_options = std::make_shared(*options); - RETURN_NOT_OK(EnsureNotDictionary(&new_options->value_set)); - options = new_options.get(); - call->options = std::move(new_options); - } + if (auto lit = call.arguments[i].literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, descrs[i].type)); + call.arguments[i] = literal(std::move(new_lit)); + continue; + } + + // construct an implicit cast Expression with which to replace this argument + Expression::Call implicit_cast; + implicit_cast.function_name = "cast"; + implicit_cast.arguments = {std::move(call.arguments[i])}; + implicit_cast.options = std::make_shared( + compute::CastOptions::Safe(descrs[i].type)); - if (!options->value_set.type()->Equals(call->arguments[0].type())) { - // The value_set is assumed smaller than inputs, casting it should be cheaper. - auto new_options = std::make_shared(*options); ARROW_ASSIGN_OR_RAISE( - new_options->value_set, - compute::Cast(std::move(new_options->value_set), call->arguments[0].type())); - options = new_options.get(); - call->options = std::move(new_options); + call.arguments[i], + BindNonRecursive(std::move(implicit_cast), + /*insert_implicit_casts=*/false, exec_context)); } - return Status::OK(); + RETURN_NOT_OK(ImplicitCastFunctionOptions(&call)); } - return Status::OK(); + compute::KernelContext kernel_context(exec_context); + ARROW_ASSIGN_OR_RAISE(call.kernel_state, InitKernelState(call, exec_context)); + kernel_context.SetState(call.kernel_state.get()); + + ARROW_ASSIGN_OR_RAISE( + call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); + + return Expression(std::move(call)); } struct FieldPathGetDatumImpl { @@ -554,14 +521,11 @@ Result Expression::Bind(ValueDescr in, return Expression{Parameter{*ref, std::move(descr)}}; } - auto call = CallNotNull(*this); - - std::vector bound_arguments(call->arguments.size()); - for (size_t i = 0; i < bound_arguments.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(bound_arguments[i], call->arguments[i].Bind(in, exec_context)); + auto call = *CallNotNull(*this); + for (auto& argument : call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); } - - return BindNonRecursive(*call, std::move(bound_arguments), + return BindNonRecursive(std::move(call), /*insert_implicit_casts=*/true, exec_context); } @@ -899,7 +863,7 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont flipped_call.function_name = Comparison::GetName(Comparison::GetFlipped(*cmp)); - return BindNonRecursive(flipped_call, std::move(flipped_call.arguments), + return BindNonRecursive(flipped_call, /*insert_implicit_casts=*/false, exec_context); } } @@ -925,7 +889,10 @@ Result DirectComparisonSimplification(Expression expr, if (!cmp) return expr; if (!cmp_guarantee) return expr; - if (call->arguments[0] != guarantee.arguments[0]) return expr; + + const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); + const auto& guarantee_lhs = guarantee.arguments[0]; + if (lhs != guarantee_lhs) return expr; auto rhs = call->arguments[1].literal(); auto guarantee_rhs = guarantee.arguments[1].literal(); diff --git a/cpp/src/arrow/dataset/expression.h b/cpp/src/arrow/dataset/expression.h index 0d28d3d3f3d..13c714b2d72 100644 --- a/cpp/src/arrow/dataset/expression.h +++ b/cpp/src/arrow/dataset/expression.h @@ -50,8 +50,8 @@ class ARROW_DS_EXPORT Expression { std::shared_ptr> hash; // post-Bind properties: - const compute::Kernel* kernel = NULLPTR; std::shared_ptr function; + const compute::Kernel* kernel = NULLPTR; std::shared_ptr kernel_state; ValueDescr descr; }; diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index a4cb63c5b50..d8c6b7445df 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -109,6 +109,40 @@ struct Comparison { return less.scalar_as().value ? LESS : GREATER; } + static const Expression& StripOrderPreservingCasts(const Expression& expr) { + auto call = expr.call(); + if (!call) return expr; + if (call->function_name != "cast") return expr; + + const Expression& from = call->arguments[0]; + + auto from_id = from.type()->id(); + auto to_id = expr.type()->id(); + + if (is_floating(to_id)) { + if (is_integer(from_id) || is_floating(from_id)) { + return StripOrderPreservingCasts(from); + } + return expr; + } + + if (is_unsigned_integer(to_id)) { + if (is_unsigned_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) { + return StripOrderPreservingCasts(from); + } + return expr; + } + + if (is_signed_integer(to_id)) { + if (is_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) { + return StripOrderPreservingCasts(from); + } + return expr; + } + + return expr; + } + static type GetFlipped(type op) { switch (op) { case NA: @@ -182,14 +216,6 @@ inline bool IsSetLookup(const std::string& function) { return function == "is_in" || function == "index_in"; } -inline bool IsSameTypesBinary(const std::string& function) { - if (Comparison::Get(function)) return true; - - static std::unordered_set set{"add", "subtract", "multiply", "divide"}; - - return set.find(function) != set.end(); -} - inline const compute::SetLookupOptions* GetSetLookupOptions( const Expression::Call& call) { if (!IsSetLookup(call.function_name)) return nullptr; @@ -206,22 +232,6 @@ inline const compute::StrptimeOptions* GetStrptimeOptions(const Expression::Call return checked_cast(call.options.get()); } -inline std::shared_ptr GetDictionaryValueType( - const std::shared_ptr& type) { - if (type && type->id() == Type::DICTIONARY) { - return checked_cast(*type).value_type(); - } - return nullptr; -} - -inline Status EnsureNotDictionary(Datum* datum) { - if (datum->type()->id() == Type::DICTIONARY) { - const auto& type = checked_cast(*datum->type()).value_type(); - ARROW_ASSIGN_OR_RAISE(*datum, compute::Cast(*datum, type)); - } - return Status::OK(); -} - /// A helper for unboxing an Expression composed of associative function calls. /// Such expressions can frequently be rearranged to a semantically equivalent /// expression for more optimal execution or more straightforward manipulation. diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index da5c82425b3..35f488aaaea 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -68,6 +68,8 @@ void ExpectResultsEqual(Actual&& actual, Expected&& expected) { } } +const auto no_change = util::nullopt; + TEST(ExpressionUtils, Comparison) { auto Expect = [](Result expected, Datum l, Datum r) { ExpectResultsEqual(Comparison::Execute(l, r).Map(Comparison::GetName), expected); @@ -93,6 +95,51 @@ TEST(ExpressionUtils, Comparison) { Expect(parse_failure, null, str); } +TEST(ExpressionUtils, StripOrderPreservingCasts) { + auto Expect = [](Expression expr, util::optional expected_stripped) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); + if (!expected_stripped) { + expected_stripped = expr; + } else { + ASSERT_OK_AND_ASSIGN(expected_stripped, expected_stripped->Bind(*kBoringSchema)); + } + EXPECT_EQ(Comparison::StripOrderPreservingCasts(expr), *expected_stripped); + }; + + // Casting int to float preserves ordering. + // For example, let + // a = 3, b = 2, assert(a > b) + // After injecting a cast to float, this ordering still holds + // float(a) == 3.0, float(b) == 2.0, assert(float(a) > float(b)) + Expect(cast(field_ref("i32"), float32()), field_ref("i32")); + + // Casting an integral type to a wider integral type preserves ordering. + Expect(cast(field_ref("i32"), int64()), field_ref("i32")); + Expect(cast(field_ref("i32"), int32()), field_ref("i32")); + Expect(cast(field_ref("i32"), int16()), no_change); + Expect(cast(field_ref("i32"), int8()), no_change); + + Expect(cast(field_ref("u32"), uint64()), field_ref("u32")); + Expect(cast(field_ref("u32"), uint32()), field_ref("u32")); + Expect(cast(field_ref("u32"), uint16()), no_change); + Expect(cast(field_ref("u32"), uint8()), no_change); + + // Casting float to int can affect ordering. + // For example, let + // a = 3.5, b = 3.0, assert(a > b) + // After injecting a cast to integer, this ordering may no longer hold + // int(a) == 3, int(b) == 3, assert(!(int(a) > int(b))) + Expect(cast(field_ref("f32"), int32()), no_change); + + // casting any float type to another preserves ordering + Expect(cast(field_ref("f32"), float64()), field_ref("f32")); + Expect(cast(field_ref("f64"), float32()), field_ref("f64")); + + // casting signed integer to unsigned can alter ordering + Expect(cast(field_ref("i32"), uint32()), no_change); + Expect(cast(field_ref("i32"), uint64()), no_change); +} + TEST(Expression, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "alpha"); @@ -307,8 +354,6 @@ void ExpectBindsTo(Expression expr, util::optional expected, } } -const auto no_change = util::nullopt; - TEST(Expression, BindFieldRef) { // an unbound field_ref does not have the output ValueDescr set auto expr = field_ref("alpha"); @@ -342,42 +387,44 @@ TEST(Expression, BindCall) { ExpectBindsTo(expr, no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - // literal(3) may be safely cast to float32, so binding this expr casts that literal: ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}), call("add", {field_ref("f32"), literal(3.0F)})); - // literal(3.5) may not be safely cast to int32, so binding this expr fails: - ASSERT_RAISES(Invalid, - call("add", {field_ref("i32"), literal(3.5)}).Bind(*kBoringSchema)); + ExpectBindsTo(call("add", {field_ref("i32"), literal(3.5F)}), + call("add", {cast(field_ref("i32"), float32()), literal(3.5F)})); } TEST(Expression, BindWithImplicitCasts) { for (auto cmp : {equal, not_equal, less, less_equal, greater, greater_equal}) { - // cast arguments to same type + // cast arguments to common numeric type + ExpectBindsTo(cmp(field_ref("i64"), field_ref("i32")), + cmp(field_ref("i64"), cast(field_ref("i32"), int64()))); + + ExpectBindsTo(cmp(field_ref("i64"), field_ref("f32")), + cmp(cast(field_ref("i64"), float32()), field_ref("f32"))); + ExpectBindsTo(cmp(field_ref("i32"), field_ref("i64")), - cmp(field_ref("i32"), cast(field_ref("i64"), int32()))); - // NB: RHS is cast unless LHS is scalar. + cmp(cast(field_ref("i32"), int64()), field_ref("i64"))); + + ExpectBindsTo(cmp(field_ref("i8"), field_ref("u32")), + cmp(cast(field_ref("i8"), int32()), cast(field_ref("u32"), int32()))); // cast dictionary to value type ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")), cmp(cast(field_ref("dict_str"), utf8()), field_ref("str"))); } - // scalars are directly cast when possible: - auto ts_scalar = MakeScalar("1990-10-23")->CastTo(timestamp(TimeUnit::NANO)); - ExpectBindsTo(equal(field_ref("ts_ns"), literal("1990-10-23")), - equal(field_ref("ts_ns"), literal(*ts_scalar))); - - // cast value_set to argument type auto Opts = [](std::shared_ptr type) { return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")}; }; + + // cast value_set to argument type ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(binary())), call("is_in", {field_ref("str")}, Opts(utf8()))); - // dictionary decode set then cast to argument type - ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(dictionary(int32(), binary()))), - call("is_in", {field_ref("str")}, Opts(utf8()))); + // cast dictionary to value type + ExpectBindsTo(call("is_in", {field_ref("dict_str")}, Opts(utf8())), + call("is_in", {cast(field_ref("dict_str"), utf8())}, Opts(utf8()))); } TEST(Expression, BindNestedCall) { @@ -519,16 +566,6 @@ TEST(Expression, ExecuteDictionaryTransparent) { {"a": "", "b": ""}, {"a": "hi", "b": "hello"} ])")); - - Datum dict_set = ArrayFromJSON(dictionary(int32(), utf8()), R"(["a"])"); - AssertExecute(call("is_in", {field_ref("a")}, - compute::SetLookupOptions{dict_set, - /*skip_nulls=*/false}), - ArrayFromJSON(struct_({field("a", utf8())}), R"([ - {"a": "a"}, - {"a": "good"}, - {"a": null} - ])")); } void ExpectIdenticalIfUnchanged(Expression modified, Expression original) { @@ -874,6 +911,12 @@ TEST(Expression, SingleComparisonGuarantees) { .WithGuarantee(equal(i32, literal(5))) .Expect(false); + Simplify{ + equal(i32, literal(0.5)), + } + .WithGuarantee(greater_equal(i32, literal(1))) + .Expect(false); + // no simplification possible: Simplify{ not_equal(i32, literal(3)), @@ -949,27 +992,26 @@ TEST(Expression, SimplifyWithGuarantee) { .WithGuarantee(and_(greater_equal(field_ref("i32"), literal(0)), less_equal(field_ref("i32"), literal(1)))) .Expect(equal(field_ref("i32"), literal(0))); - Simplify{ - or_(equal(field_ref("f32"), literal("0")), equal(field_ref("i32"), literal(3)))} + + Simplify{or_(equal(field_ref("f32"), literal(0)), equal(field_ref("i32"), literal(3)))} .WithGuarantee(greater(field_ref("f32"), literal(0.0))) .Expect(equal(field_ref("i32"), literal(3))); // simplification can see through implicit casts - Simplify{or_({equal(field_ref("f32"), literal("0")), - call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ - ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})})} - .WithGuarantee(greater(field_ref("f32"), literal(0.0))) + Simplify{ + or_({equal(field_ref("f32"), literal(0)), + call("is_in", {field_ref("i64")}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})})} + .WithGuarantee(greater(field_ref("f32"), literal(0.F))) .Expect(call("is_in", {field_ref("i64")}, compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); } TEST(Expression, SimplifyThenExecute) { auto filter = - or_({equal(field_ref("f32"), literal("0")), + or_({equal(field_ref("f32"), literal(0)), call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ - ArrayFromJSON(dictionary(int32(), int32()), "[1,2,3]"), true})}); + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})}); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); auto guarantee = greater(field_ref("f32"), literal(0.0)); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index f0d44cfe3d6..b3e9ae424da 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -50,14 +50,16 @@ namespace arrow { namespace dataset { const std::shared_ptr kBoringSchema = schema({ + field("bool", boolean()), + field("i8", int8()), field("i32", int32()), field("i32_req", int32(), /*nullable=*/false), + field("u32", uint32()), field("i64", int64()), - field("date64", date64()), field("f32", float32()), field("f32_req", float32(), /*nullable=*/false), field("f64", float64()), - field("bool", boolean()), + field("date64", date64()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), field("ts_ns", timestamp(TimeUnit::NANO)), diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 2dcfc77c437..48dce38e87a 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -929,6 +929,46 @@ static inline bool is_fixed_width(Type::type type_id) { return is_primitive(type_id) || is_dictionary(type_id) || is_fixed_size_binary(type_id); } +static inline int bit_width(Type::type type_id) { + switch (type_id) { + case Type::BOOL: + return 1; + case Type::UINT8: + case Type::INT8: + return 8; + case Type::UINT16: + case Type::INT16: + return 16; + case Type::UINT32: + case Type::INT32: + case Type::DATE32: + case Type::TIME32: + return 32; + case Type::UINT64: + case Type::INT64: + case Type::DATE64: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::DURATION: + return 64; + + case Type::HALF_FLOAT: + return 16; + case Type::FLOAT: + return 32; + case Type::DOUBLE: + return 64; + + case Type::INTERVAL_MONTHS: + return 32; + case Type::INTERVAL_DAY_TIME: + return 64; + default: + break; + } + return 0; +} + static inline bool is_nested(Type::type type_id) { switch (type_id) { case Type::LIST: From 6385032d3eddcf7643094c90301340471650de52 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 15 Jan 2021 16:55:33 -0500 Subject: [PATCH 04/22] add DispatchBest to SetLookup kernels --- cpp/src/arrow/compute/kernels/scalar_set_lookup.cc | 13 +++++++++++-- .../compute/kernels/scalar_set_lookup_test.cc | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index 93fa34c1694..41e2face6ec 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -414,6 +414,15 @@ class IndexInMetaBinary : public MetaFunction { } }; +struct SetLookupFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + EnsureDictionaryDecoded(values); + return DispatchExact(*values); + } +}; + const FunctionDoc is_in_doc{ "Find each element in a set of values", ("For each element in `values`, return true if it is found in a given\n" @@ -443,7 +452,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { isin_base.init = InitSetLookup; isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL; - auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); + auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); AddBasicSetLookupKernels(isin_base, /*output_type=*/boolean(), is_in.get()); @@ -462,7 +471,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = - std::make_shared("index_in", Arity::Unary(), &index_in_doc); + std::make_shared("index_in", Arity::Unary(), &index_in_doc); AddBasicSetLookupKernels(index_in_base, /*output_type=*/int32(), index_in.get()); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index 40907da5a62..5c6c5838b68 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -587,5 +587,19 @@ TEST_F(TestIndexInKernel, ChunkedArrayInvoke) { CheckIndexInChunked(input, value_set, expected, /*skip_nulls=*/true); } +TEST(TestSetLookup, DispatchBest) { + for (std::string name : {"is_in", "index_in"}) { + CheckDispatchBest(name, {int32()}, {int32()}); + CheckDispatchBest(name, {dictionary(int32(), utf8())}, {utf8()}); + } +} + +TEST(TestSetLookup, IsInWithImplicitCasts) { + SetLookupOptions opts{ArrayFromJSON(utf8(), R"(["b", "d"])")}; + CheckScalarUnary("is_in", + ArrayFromJSON(dictionary(int32(), utf8()), R"(["a", "b", "c", null])"), + ArrayFromJSON(boolean(), "[0, 1, 0, null]"), &opts); +} + } // namespace compute } // namespace arrow From 058e15a163b25a417e3991b5e6ddb51c7b0e9052 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Jan 2021 11:39:40 -0500 Subject: [PATCH 05/22] repair implicit cast is_in execution test --- cpp/src/arrow/compute/kernels/scalar_set_lookup.cc | 6 ++++-- cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc | 4 ++-- cpp/src/arrow/compute/kernels/util_internal.cc | 7 ++++--- cpp/src/arrow/compute/kernels/util_internal.h | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index 41e2face6ec..fcfedd572f4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -450,7 +450,8 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel isin_base; isin_base.init = InitSetLookup; - isin_base.exec = TrivialScalarUnaryAsArraysExec(ExecIsIn); + isin_base.exec = + TrivialScalarUnaryAsArraysExec(ExecIsIn, NullHandling::OUTPUT_NOT_NULL); isin_base.null_handling = NullHandling::OUTPUT_NOT_NULL; auto is_in = std::make_shared("is_in", Arity::Unary(), &is_in_doc); @@ -467,7 +468,8 @@ void RegisterScalarSetLookup(FunctionRegistry* registry) { { ScalarKernel index_in_base; index_in_base.init = InitSetLookup; - index_in_base.exec = TrivialScalarUnaryAsArraysExec(ExecIndexIn); + index_in_base.exec = TrivialScalarUnaryAsArraysExec( + ExecIndexIn, NullHandling::COMPUTED_NO_PREALLOCATE); index_in_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; index_in_base.mem_allocation = MemAllocation::NO_PREALLOCATE; auto index_in = diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index 5c6c5838b68..5b87d09dbbc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -595,10 +595,10 @@ TEST(TestSetLookup, DispatchBest) { } TEST(TestSetLookup, IsInWithImplicitCasts) { - SetLookupOptions opts{ArrayFromJSON(utf8(), R"(["b", "d"])")}; + SetLookupOptions opts{ArrayFromJSON(utf8(), R"(["b", null])")}; CheckScalarUnary("is_in", ArrayFromJSON(dictionary(int32(), utf8()), R"(["a", "b", "c", null])"), - ArrayFromJSON(boolean(), "[0, 1, 0, null]"), &opts); + ArrayFromJSON(boolean(), "[0, 1, 0, 1]"), &opts); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/util_internal.cc b/cpp/src/arrow/compute/kernels/util_internal.cc index 3d21f5b1494..93badbd3b25 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.cc +++ b/cpp/src/arrow/compute/kernels/util_internal.cc @@ -57,13 +57,14 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr) { return arg; } -ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec) { - return [exec](KernelContext* ctx, const ExecBatch& batch, Datum* out) { +ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec, + NullHandling::type null_handling) { + return [=](KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (out->is_array()) { return exec(ctx, batch, out); } - if (!batch[0].scalar()->is_valid) { + if (null_handling == NullHandling::INTERSECTION && !batch[0].scalar()->is_valid) { out->scalar()->is_valid = false; return; } diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h index aece5a97599..f614439ffb8 100644 --- a/cpp/src/arrow/compute/kernels/util_internal.h +++ b/cpp/src/arrow/compute/kernels/util_internal.h @@ -59,7 +59,8 @@ PrimitiveArg GetPrimitiveArg(const ArrayData& arr); // the original exec, then the only element of the resulting array will be extracted as // the output scalar. This could be far more efficient, but instead of optimizing this // it'd be better to support scalar inputs "upstream" in original exec. -ArrayKernelExec TrivialScalarUnaryAsArraysExec(ArrayKernelExec exec); +ArrayKernelExec TrivialScalarUnaryAsArraysExec( + ArrayKernelExec exec, NullHandling::type null_handling = NullHandling::INTERSECTION); } // namespace internal } // namespace compute From b8525a47dbb49a65991622c73b3914696a6adfc4 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Jan 2021 12:53:23 -0500 Subject: [PATCH 06/22] add support for null -> * cast to arithmetic and compare --- cpp/src/arrow/compute/exec_internal.h | 5 -- cpp/src/arrow/compute/function.cc | 59 ++++++------------- .../arrow/compute/kernels/codegen_internal.cc | 22 +++++++ .../arrow/compute/kernels/codegen_internal.h | 13 ++-- cpp/src/arrow/compute/kernels/common.h | 12 ++++ .../compute/kernels/scalar_arithmetic.cc | 8 ++- .../compute/kernels/scalar_arithmetic_test.cc | 13 ++++ .../arrow/compute/kernels/scalar_compare.cc | 8 ++- .../compute/kernels/scalar_compare_test.cc | 14 +++++ cpp/src/arrow/datum.cc | 13 ++++ cpp/src/arrow/datum.h | 1 + 11 files changed, 114 insertions(+), 54 deletions(-) diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index cce8386d93a..a74e5c8d8fa 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -132,11 +132,6 @@ class ARROW_EXPORT KernelExecutor { ARROW_EXPORT Status PropagateNulls(KernelContext* ctx, const ExecBatch& batch, ArrayData* out); -/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. -ARROW_EXPORT -const Kernel* DispatchExactImpl(const Function* func, - const std::vector& values); - } // namespace detail } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 0019c430298..42903cf3627 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -41,49 +41,22 @@ Status Function::CheckArity(int passed_num_args) const { if (arity_.is_varargs && passed_num_args < arity_.num_args) { return Status::Invalid("VarArgs function needs at least ", arity_.num_args, " arguments but kernel accepts only ", passed_num_args); - } else if (!arity_.is_varargs && passed_num_args != arity_.num_args) { - return Status::Invalid("Function accepts ", arity_.num_args, - " arguments but kernel accepts ", passed_num_args); - } - return Status::OK(); -} - -namespace { - -Status ValidateDispatch(const Function* func, const std::vector& values) { - if (func->kind() == Function::META) { - return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); - } - - const int passed_num_args = static_cast(values.size()); - const Arity arity = func->arity(); - - if (arity.is_varargs && passed_num_args < arity.num_args) { - return Status::Invalid("VarArgs function needs at least ", arity.num_args, - " arguments but passed only ", passed_num_args); } - if (!arity.is_varargs && passed_num_args != arity.num_args) { - return Status::Invalid("Function accepts ", arity.num_args, " arguments but passed ", - passed_num_args); + if (!arity_.is_varargs && passed_num_args != arity_.num_args) { + return Status::Invalid("Function accepts ", arity_.num_args, + " arguments but kernel accepts ", passed_num_args); } return Status::OK(); } -Status NoMatchingKernel(const Function* func, const std::vector& descrs) { - std::stringstream ss; - ss << "("; - for (size_t i = 0; i < descrs.size(); ++i) { - if (i > 0) { - ss << ", "; - } - ss << descrs[i].ToString(); - } - ss << ")"; +namespace detail { +Status NoMatchingKernel(const Function* func, const std::vector& descrs) { return Status::NotImplemented("Function ", func->name(), - " has no kernel matching input types ", ss.str()); + " has no kernel matching input types ", + ValueDescr::ToString(descrs)); } template @@ -143,23 +116,29 @@ const Kernel* DispatchExactImpl(const Function* func, return nullptr; } -} // namespace +} // namespace detail Result Function::DispatchExact( const std::vector& values) const { - RETURN_NOT_OK(ValidateDispatch(this, values)); + if (kind_ == Function::META) { + return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); + } + RETURN_NOT_OK(CheckArity(static_cast(values.size()))); - if (auto kernel = DispatchExactImpl(this, values)) { + if (auto kernel = detail::DispatchExactImpl(this, values)) { return kernel; } - return NoMatchingKernel(this, values); + return detail::NoMatchingKernel(this, values); } Result Function::DispatchBest(std::vector* values) const { - RETURN_NOT_OK(ValidateDispatch(this, *values)); + if (kind_ == Function::META) { + return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); + } + RETURN_NOT_OK(CheckArity(static_cast(values->size()))); // first try for an exact match - if (auto kernel = DispatchExactImpl(this, *values)) { + if (auto kernel = detail::DispatchExactImpl(this, *values)) { return kernel; } diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 764fc584eea..0936f6e752d 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -179,6 +179,28 @@ Result FirstType(KernelContext*, const std::vector& desc return descrs[0]; } +void EnsureDictionaryDecoded(std::vector* descrs) { + for (ValueDescr& descr : *descrs) { + if (descr.type->id() == Type::DICTIONARY) { + descr.type = checked_cast(*descr.type).value_type(); + } + } +} + +void ReplaceNullWithOtherType(std::vector* descrs) { + DCHECK_EQ(descrs->size(), 2); + + if (descrs->at(0).type->id() == Type::NA) { + descrs->at(0).type = descrs->at(1).type; + return; + } + + if (descrs->at(1).type->id() == Type::NA) { + descrs->at(1).type = descrs->at(0).type; + return; + } +} + std::shared_ptr CommonNumeric(const std::vector& descrs) { for (const auto& descr : descrs) { auto id = descr.type->id(); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 3d03b11585b..c50cb2669df 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1186,14 +1186,13 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) { // END of kernel generator-dispatchers // ---------------------------------------------------------------------- -inline void EnsureDictionaryDecoded(std::vector* descrs) { - for (ValueDescr& descr : *descrs) { - if (descr.type->id() == Type::DICTIONARY) { - descr.type = checked_cast(*descr.type).value_type(); - } - } -} +ARROW_EXPORT +void EnsureDictionaryDecoded(std::vector* descrs); +ARROW_EXPORT +void ReplaceNullWithOtherType(std::vector* descrs); + +ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/common.h b/cpp/src/arrow/compute/kernels/common.h index 21244320f38..6566555e7f1 100644 --- a/cpp/src/arrow/compute/kernels/common.h +++ b/cpp/src/arrow/compute/kernels/common.h @@ -51,4 +51,16 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; +namespace compute { +namespace detail { + +/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. +ARROW_EXPORT +const Kernel* DispatchExactImpl(const Function* func, const std::vector&); + +ARROW_EXPORT +Status NoMatchingKernel(const Function* func, const std::vector&); + +} // namespace detail +} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 38995b16e96..059bf0aa6b7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -268,7 +268,10 @@ struct ArithmeticFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + EnsureDictionaryDecoded(values); + ReplaceNullWithOtherType(values); if (auto type = CommonNumeric(*values)) { for (auto& descr : *values) { @@ -276,7 +279,10 @@ struct ArithmeticFunction : ScalarFunction { } } - return DispatchExact(*values); + if (auto kernel = arrow::compute::detail::DispatchExactImpl(this, *values)) { + return kernel; + } + return arrow::compute::detail::NoMatchingKernel(this, *values); } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index edaccc4e8f8..f14a28aaa6c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -643,13 +643,22 @@ TEST(TestBinaryArithmetic, DispatchBest) { name += suffix; CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + + CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); + + CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, {float64(), float64()}); } @@ -669,6 +678,10 @@ TEST(TestBinaryArithmetic, AddWithImplicitCasts) { ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), ArrayFromJSON(int32(), "[3, 5, 7, null]")); + CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + std::make_shared(4), + ArrayFromJSON(int32(), "[null, null, null, null]")); + // Not currently implemented since it would invoke a double implicit cast: // dictionary(int32, int8) -> int8 -> int32 // CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 45e3406996c..75eebdf4bc3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -76,7 +76,10 @@ struct CompareFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + EnsureDictionaryDecoded(values); + ReplaceNullWithOtherType(values); if (auto type = CommonNumeric(*values)) { for (auto& descr : *values) { @@ -84,7 +87,10 @@ struct CompareFunction : ScalarFunction { } } - return DispatchExact(*values); + if (auto kernel = arrow::compute::detail::DispatchExactImpl(this, *values)) { + return kernel; + } + return arrow::compute::detail::NoMatchingKernel(this, *values); } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index c36db631055..5fc05330cdd 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -455,15 +455,25 @@ TEST(TestCompareKernel, DispatchBest) { for (std::string name : {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + + CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); + + CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, {float64(), float64()}); + CheckDispatchBest(name, {dictionary(int8(), utf8()), utf8()}, {utf8(), utf8()}); } } @@ -482,6 +492,10 @@ TEST(TestCompareKernel, GreaterWithImplicitCasts) { ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), ArrayFromJSON(boolean(), "[false, false, false, null]")); + CheckScalarBinary("greater", ArrayFromJSON(int32(), "[0, 1, 2, null]"), + std::make_shared(4), + ArrayFromJSON(boolean(), "[null, null, null, null]")); + // Not currently implemented since it would invoke a double implicit cast: // dictionary(int32, int8) -> int8 -> int32 // CheckScalarBinary("greater", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index 786110996dc..dd10fce3e4d 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -211,6 +211,19 @@ static std::string FormatValueDescr(const ValueDescr& descr) { std::string ValueDescr::ToString() const { return FormatValueDescr(*this); } +std::string ValueDescr::ToString(const std::vector& descrs) { + std::stringstream ss; + ss << "("; + for (size_t i = 0; i < descrs.size(); ++i) { + if (i > 0) { + ss << ", "; + } + ss << descrs[i].ToString(); + } + ss << ")"; + return ss.str(); +} + void PrintTo(const ValueDescr& descr, std::ostream* os) { *os << descr.ToString(); } std::string Datum::ToString() const { diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index fb783ea5261..6ba6af7f79e 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -89,6 +89,7 @@ struct ARROW_EXPORT ValueDescr { bool operator!=(const ValueDescr& other) const { return !(*this == other); } std::string ToString() const; + static std::string ToString(const std::vector&); ARROW_EXPORT friend void PrintTo(const ValueDescr&, std::ostream*); }; From e60e55568465bbaa6e6f6c748358c925d2bf9ecb Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Jan 2021 13:56:40 -0500 Subject: [PATCH 07/22] use explicit schema to avoid inferring bool as str --- python/pyarrow/tests/parquet/test_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/tests/parquet/test_dataset.py b/python/pyarrow/tests/parquet/test_dataset.py index 42ce187f58b..48c62c7e458 100644 --- a/python/pyarrow/tests/parquet/test_dataset.py +++ b/python/pyarrow/tests/parquet/test_dataset.py @@ -191,6 +191,11 @@ def test_filters_equivalency(tempdir, use_legacy_dataset): ['string', string_keys], ['boolean', boolean_keys] ] + schema = pa.schema({ + 'integer': pa.int32(), + 'string': pa.string(), + 'boolean', pa.boolean() + }) df = pd.DataFrame({ 'integer': np.array(integer_keys, dtype='i4').repeat(15), @@ -204,7 +209,7 @@ def test_filters_equivalency(tempdir, use_legacy_dataset): # Old filters syntax: # integer == 1 AND string != b AND boolean == True dataset = pq.ParquetDataset( - base_path, filesystem=fs, + base_path, filesystem=fs, schema=schema, filters=[('integer', '=', 1), ('string', '!=', 'b'), ('boolean', '==', True)], use_legacy_dataset=use_legacy_dataset, From 2528d95dee27ed26d762f25d5c954e6dd93f0782 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 19 Jan 2021 14:42:15 -0500 Subject: [PATCH 08/22] apply implicit casts to R binding --- .../compute/kernels/scalar_cast_numeric.cc | 2 +- r/R/expression.R | 10 --------- r/tests/testthat/test-compute-arith.R | 21 +++++++++---------- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 6e550fb12c0..4520230f2ae 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -62,7 +62,7 @@ Status CheckFloatTruncation(const Datum& input, const Datum& output) { return is_valid && static_cast(out_val) != in_val; }; auto GetErrorMessage = [&](InT val) { - return Status::Invalid("Float value ", val, " was truncated converting to", + return Status::Invalid("Float value ", val, " was truncated converting to ", *output.type()); }; diff --git a/r/R/expression.R b/r/R/expression.R index 0198e0ebe6a..5475f7a44bc 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -144,16 +144,6 @@ eval_array_expression <- function(x) { a } }) - if (length(x$args) == 2L) { - # Insert implicit casts - if (inherits(x$args[[1]], "Scalar")) { - x$args[[1]] <- x$args[[1]]$cast(x$args[[2]]$type) - } else if (inherits(x$args[[2]], "Scalar")) { - x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) - } else if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Array")) { - x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type) - } - } call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index d37367d47c8..3bfa8c2b41e 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -18,32 +18,31 @@ test_that("Addition", { a <- Array$create(c(1:4, NA_integer_)) expect_type_equal(a, int32()) - expect_type_equal(a + 4, int32()) - expect_equal(a + 4, Array$create(c(5:8, NA_integer_))) - expect_identical(as.vector(a + 4), c(5:8, NA_integer_)) + expect_type_equal(a + 4L, int32()) + expect_type_equal(a + 4, float64()) + expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) + expect_identical(as.vector(a + 4L), c(5:8, NA_integer_)) expect_equal(a + 4L, Array$create(c(5:8, NA_integer_))) expect_vector(a + 4L, c(5:8, NA_integer_)) expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5))) - # overflow errors — this is slightly different from R's `NA` coercion when - # overflowing, but better than the alternative of silently restarting - casted <- a$cast(int8()) - expect_error(casted + 127) - expect_error(casted + 200) + a8 <- a$cast(int8()) + expect_type_equal(a8 + Scalar$create(1, int8()), int8()) + expect_type_equal(a8 + 127L, int32()) + expect_type_equal(a8 + 200L, int32()) - skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919") expect_type_equal(a + 4.1, float64()) expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) }) test_that("Subtraction", { a <- Array$create(c(1:4, NA_integer_)) - expect_equal(a - 3, Array$create(c(-2:1, NA_integer_))) + expect_equal(a - 3L, Array$create(c(-2:1, NA_integer_))) }) test_that("Multiplication", { a <- Array$create(c(1:4, NA_integer_)) - expect_equal(a * 2, Array$create(c(1:4 * 2L, NA_integer_))) + expect_equal(a * 2L, Array$create(c(1:4 * 2L, NA_integer_))) }) test_that("Division", { From ecd778cebf728d98a7be1d06494b8f60cf9e95a5 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 20 Jan 2021 20:48:02 -0500 Subject: [PATCH 09/22] ensure value_set is cast to the input type --- .../compute/kernels/scalar_set_lookup.cc | 74 ++++++++++--------- .../compute/kernels/scalar_set_lookup_test.cc | 15 ++++ cpp/src/arrow/dataset/expression.cc | 21 ------ cpp/src/arrow/dataset/expression_test.cc | 22 ++---- 4 files changed, 61 insertions(+), 71 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc index fcfedd572f4..ffc1e11a7be 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup.cc @@ -18,6 +18,7 @@ #include "arrow/array/array_base.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/util/bit_util.h" @@ -36,10 +37,9 @@ namespace { template struct SetLookupState : public KernelState { - explicit SetLookupState(const SetLookupOptions& options, MemoryPool* pool) - : options(options), lookup_table(pool, 0) {} + explicit SetLookupState(MemoryPool* pool) : lookup_table(pool, 0) {} - Status Init() { + Status Init(const SetLookupOptions& options) { if (options.value_set.kind() == Datum::ARRAY) { RETURN_NOT_OK(AddArrayValueSet(*options.value_set.array())); } else if (options.value_set.kind() == Datum::CHUNKED_ARRAY) { @@ -53,7 +53,9 @@ struct SetLookupState : public KernelState { if (lookup_table.size() != options.value_set.length()) { return Status::NotImplemented("duplicate values in value_set"); } - value_set_has_null = (lookup_table.GetNull() >= 0); + if (!options.skip_nulls) { + null_index = lookup_table.GetNull(); + } return Status::OK(); } @@ -72,22 +74,19 @@ struct SetLookupState : public KernelState { } using MemoTable = typename HashTraits::MemoTableType; - const SetLookupOptions& options; MemoTable lookup_table; - bool value_set_has_null; + int32_t null_index = -1; }; template <> struct SetLookupState : public KernelState { - explicit SetLookupState(const SetLookupOptions& options, MemoryPool*) - : options(options) {} + explicit SetLookupState(MemoryPool*) {} - Status Init() { - this->value_set_has_null = (options.value_set.length() > 0); + Status Init(const SetLookupOptions& options) { + value_set_has_null = (options.value_set.length() > 0) && !options.skip_nulls; return Status::OK(); } - const SetLookupOptions& options; bool value_set_has_null; }; @@ -118,21 +117,20 @@ struct UnsignedIntType<8> { // Constructing the type requires a type parameter struct InitStateVisitor { KernelContext* ctx; - const SetLookupOptions* options; + SetLookupOptions options; + const std::shared_ptr& arg_type; std::unique_ptr result; - InitStateVisitor(KernelContext* ctx, const SetLookupOptions* options) - : ctx(ctx), options(options) {} + InitStateVisitor(KernelContext* ctx, const KernelInitArgs& args) + : ctx(ctx), + options(*checked_cast(args.options)), + arg_type(args.inputs[0].type) {} template Status Init() { - if (options == nullptr) { - return Status::Invalid( - "Attempted to call a set lookup function without SetLookupOptions"); - } using StateType = SetLookupState; - result.reset(new StateType(*options, ctx->exec_context()->memory_pool())); - return static_cast(result.get())->Init(); + result.reset(new StateType(ctx->exec_context()->memory_pool())); + return static_cast(result.get())->Init(options); } Status Visit(const DataType&) { return Init(); } @@ -157,7 +155,13 @@ struct InitStateVisitor { Status Visit(const FixedSizeBinaryType& type) { return Init(); } Status GetResult(std::unique_ptr* out) { - RETURN_NOT_OK(VisitTypeInline(*options->value_set.type(), this)); + if (!options.value_set.type()->Equals(arg_type)) { + ARROW_ASSIGN_OR_RAISE( + options.value_set, + Cast(options.value_set, CastOptions::Safe(arg_type), ctx->exec_context())); + } + + RETURN_NOT_OK(VisitTypeInline(*arg_type, this)); *out = std::move(result); return Status::OK(); } @@ -165,9 +169,14 @@ struct InitStateVisitor { std::unique_ptr InitSetLookup(KernelContext* ctx, const KernelInitArgs& args) { - InitStateVisitor visitor{ctx, static_cast(args.options)}; + if (args.options == nullptr) { + ctx->SetStatus(Status::Invalid( + "Attempted to call a set lookup function without SetLookupOptions")); + return nullptr; + } + std::unique_ptr result; - ctx->SetStatus(visitor.GetResult(&result)); + ctx->SetStatus(InitStateVisitor{ctx, args}.GetResult(&result)); return result; } @@ -185,7 +194,7 @@ struct IndexInVisitor { const auto& state = checked_cast&>(*ctx->state()); if (data.length != 0) { // skip_nulls is honored for consistency with other types - if (state.value_set_has_null && !state.options.skip_nulls) { + if (state.value_set_has_null) { RETURN_NOT_OK(this->builder.Reserve(data.length)); for (int64_t i = 0; i < data.length; ++i) { this->builder.UnsafeAppend(0); @@ -203,7 +212,6 @@ struct IndexInVisitor { const auto& state = checked_cast&>(*ctx->state()); - int32_t null_index = state.options.skip_nulls ? -1 : state.lookup_table.GetNull(); RETURN_NOT_OK(this->builder.Reserve(data.length)); VisitArrayDataInline( data, @@ -218,9 +226,9 @@ struct IndexInVisitor { } }, [&]() { - if (null_index != -1) { + if (state.null_index != -1) { // value_set included null - this->builder.UnsafeAppend(null_index); + this->builder.UnsafeAppend(state.null_index); } else { // value_set does not include null; output null this->builder.UnsafeAppendNull(); @@ -283,13 +291,8 @@ struct IsInVisitor { const auto& state = checked_cast&>(*ctx->state()); ArrayData* output = out->mutable_array(); // skip_nulls is honored for consistency with other types - if (state.value_set_has_null && !state.options.skip_nulls) { - BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, - output->length, true); - } else { - BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, - output->length, false); - } + BitUtil::SetBitsTo(output->buffers[1]->mutable_data(), output->offset, output->length, + state.value_set_has_null); return Status::OK(); } @@ -301,6 +304,7 @@ struct IsInVisitor { FirstTimeBitmapWriter writer(output->buffers[1]->mutable_data(), output->offset, output->length); + VisitArrayDataInline( this->data, [&](T v) { @@ -312,7 +316,7 @@ struct IsInVisitor { writer.Next(); }, [&]() { - if (!state.options.skip_nulls && state.lookup_table.GetNull() != -1) { + if (state.null_index != -1) { writer.Set(); } else { writer.Clear(); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index 5b87d09dbbc..2285c1cb9ab 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -85,6 +85,21 @@ TEST_F(TestIsInKernel, CallBinary) { AssertArraysEqual(*expected, *out.make_array()); } +TEST_F(TestIsInKernel, ImplicitlyCastValueSet) { + auto input = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6, 7, 8]"); + + SetLookupOptions opts{ArrayFromJSON(int32(), "[2, 3, 5, 7]")}; + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction("is_in", {input}, &opts)); + + auto expected = ArrayFromJSON(boolean(), ("[false, false, true, true, false," + "true, false, true, false]")); + AssertArraysEqual(*expected, *out.make_array()); + + // fails; value_set cannot be cast to int8 + opts = SetLookupOptions{ArrayFromJSON(float32(), "[2.5, 3.1, 5.0]")}; + ASSERT_RAISES(Invalid, CallFunction("is_in", {input}, &opts)); +} + template class TestIsInKernelPrimitive : public ::testing::Test {}; diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 1a5e17a993f..261464f7ba6 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -391,25 +391,6 @@ Result> InitKernelState( return std::move(kernel_state); } -Status ImplicitCastFunctionOptions(Expression::Call* call) { - if (auto options = GetSetLookupOptions(*call)) { - if (options->value_set.type()->Equals(call->arguments[0].type())) { - return Status::OK(); - } - - // The value_set is assumed smaller than inputs, casting it should be cheaper. - auto new_options = std::make_shared(*options); - ARROW_ASSIGN_OR_RAISE( - new_options->value_set, - compute::Cast(std::move(new_options->value_set), call->arguments[0].type())); - options = new_options.get(); - call->options = std::move(new_options); - return Status::OK(); - } - - return Status::OK(); -} - // Produce a bound Expression from unbound Call and bound arguments. Result BindNonRecursive(Expression::Call call, bool insert_implicit_casts, compute::ExecContext* exec_context) { @@ -451,8 +432,6 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ BindNonRecursive(std::move(implicit_cast), /*insert_implicit_casts=*/false, exec_context)); } - - RETURN_NOT_OK(ImplicitCastFunctionOptions(&call)); } compute::KernelContext kernel_context(exec_context); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 35f488aaaea..4b05d7251e2 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -414,17 +414,11 @@ TEST(Expression, BindWithImplicitCasts) { cmp(cast(field_ref("dict_str"), utf8()), field_ref("str"))); } - auto Opts = [](std::shared_ptr type) { - return compute::SetLookupOptions{ArrayFromJSON(type, R"(["a"])")}; - }; - - // cast value_set to argument type - ExpectBindsTo(call("is_in", {field_ref("str")}, Opts(binary())), - call("is_in", {field_ref("str")}, Opts(utf8()))); + compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")}; // cast dictionary to value type - ExpectBindsTo(call("is_in", {field_ref("dict_str")}, Opts(utf8())), - call("is_in", {cast(field_ref("dict_str"), utf8())}, Opts(utf8()))); + ExpectBindsTo(call("is_in", {field_ref("dict_str")}, in_a), + call("is_in", {cast(field_ref("dict_str"), utf8())}, in_a)); } TEST(Expression, BindNestedCall) { @@ -998,13 +992,11 @@ TEST(Expression, SimplifyWithGuarantee) { .Expect(equal(field_ref("i32"), literal(3))); // simplification can see through implicit casts - Simplify{ - or_({equal(field_ref("f32"), literal(0)), - call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true})})} + compute::SetLookupOptions in_123{ArrayFromJSON(int32(), "[1,2,3]"), true}; + Simplify{or_({equal(field_ref("f32"), literal(0)), + call("is_in", {field_ref("i64")}, in_123)})} .WithGuarantee(greater(field_ref("f32"), literal(0.F))) - .Expect(call("is_in", {field_ref("i64")}, - compute::SetLookupOptions{ArrayFromJSON(int64(), "[1,2,3]"), true})); + .Expect(call("is_in", {field_ref("i64")}, in_123)); } TEST(Expression, SimplifyThenExecute) { From 003ef40f5bcc051981ba1c4c10e038296ff55f1c Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 20 Jan 2021 20:54:17 -0500 Subject: [PATCH 10/22] always check for an exact match first --- cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 7 ++++--- cpp/src/arrow/compute/kernels/scalar_compare.cc | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 059bf0aa6b7..17db5b06502 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -270,6 +270,9 @@ struct ArithmeticFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + EnsureDictionaryDecoded(values); ReplaceNullWithOtherType(values); @@ -279,9 +282,7 @@ struct ArithmeticFunction : ScalarFunction { } } - if (auto kernel = arrow::compute::detail::DispatchExactImpl(this, *values)) { - return kernel; - } + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 75eebdf4bc3..b658d3ea024 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -78,6 +78,9 @@ struct CompareFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + EnsureDictionaryDecoded(values); ReplaceNullWithOtherType(values); @@ -87,9 +90,7 @@ struct CompareFunction : ScalarFunction { } } - if (auto kernel = arrow::compute::detail::DispatchExactImpl(this, *values)) { - return kernel; - } + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); } }; From 7ebb067c1df50f9be00e0fe32e75b3afa94409ca Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 20 Jan 2021 21:23:33 -0500 Subject: [PATCH 11/22] add implicit cast between timestamp-like types to comparison --- .../arrow/compute/kernels/codegen_internal.cc | 22 +++++++++++++++++++ .../arrow/compute/kernels/codegen_internal.h | 3 +++ .../arrow/compute/kernels/scalar_compare.cc | 4 ++++ .../compute/kernels/scalar_compare_test.cc | 16 ++++++++++++-- 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 0936f6e752d..baf7ade2584 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -232,6 +232,28 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { return at_least_one_signed ? int8() : uint8(); } +std::shared_ptr CommonTimestamp(const std::vector& descrs) { + TimeUnit::type finest_unit = TimeUnit::SECOND; + + for (const auto& descr : descrs) { + auto id = descr.type->id(); + // a common timestamp is only possible if all types are timestamp like + switch (id) { + case Type::DATE32: + case Type::DATE64: + continue; + case Type::TIMESTAMP: + finest_unit = + std::max(finest_unit, checked_cast(*descr.type).unit()); + continue; + default: + return nullptr; + } + } + + return timestamp(finest_unit); +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index c50cb2669df..a547b000c80 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1195,6 +1195,9 @@ void ReplaceNullWithOtherType(std::vector* descrs); ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); +ARROW_EXPORT +std::shared_ptr CommonTimestamp(const std::vector& descrs); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index b658d3ea024..c58643d58b2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -88,6 +88,10 @@ struct CompareFunction : ScalarFunction { for (auto& descr : *values) { descr.type = type; } + } else if (auto type = CommonTimestamp(*values)) { + for (auto& descr : *values) { + descr.type = type; + } } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 5fc05330cdd..9276532a704 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -475,6 +475,12 @@ TEST(TestCompareKernel, DispatchBest) { {float64(), float64()}); CheckDispatchBest(name, {dictionary(int8(), utf8()), utf8()}, {utf8(), utf8()}); + + CheckDispatchBest(name, {timestamp(TimeUnit::MICRO), date64()}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); + + CheckDispatchBest(name, {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO)}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); } } @@ -496,10 +502,16 @@ TEST(TestCompareKernel, GreaterWithImplicitCasts) { std::make_shared(4), ArrayFromJSON(boolean(), "[null, null, null, null]")); + CheckScalarBinary("greater", + ArrayFromJSON(timestamp(TimeUnit::SECOND), + R"(["1970-01-01","2000-02-29","1900-02-28"])"), + ArrayFromJSON(date64(), "[86400000, 0, 86400000]"), + ArrayFromJSON(boolean(), "[false, true, false]")); + // Not currently implemented since it would invoke a double implicit cast: // dictionary(int32, int8) -> int8 -> int32 - // CheckScalarBinary("greater", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, - // null]"), + // CheckScalarBinary("greater", + // ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), // ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), // ArrayFromJSON(boolean(), "[false, false, false, null]")); } From 8100d212042cbb407266bf6e2fb20cad38f51468 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 22 Jan 2021 11:54:32 -0500 Subject: [PATCH 12/22] support dictionary(X) -> Y casts if X -> Y --- cpp/src/arrow/compute/cast.cc | 75 ++++++++----------- cpp/src/arrow/compute/cast.h | 12 ++- cpp/src/arrow/compute/function.cc | 44 +++++++---- cpp/src/arrow/compute/function.h | 3 +- .../arrow/compute/kernels/codegen_internal.cc | 40 ++++++++++ .../arrow/compute/kernels/codegen_internal.h | 6 ++ .../compute/kernels/scalar_arithmetic.cc | 6 +- .../compute/kernels/scalar_cast_boolean.cc | 1 + .../compute/kernels/scalar_cast_internal.cc | 42 +++++------ .../compute/kernels/scalar_cast_string.cc | 26 +++---- .../compute/kernels/scalar_cast_temporal.cc | 2 +- .../arrow/compute/kernels/scalar_cast_test.cc | 74 ++++++++++++++++++ .../arrow/compute/kernels/scalar_compare.cc | 12 ++- .../compute/kernels/scalar_compare_test.cc | 3 + cpp/src/arrow/dataset/expression.cc | 20 ++++- cpp/src/arrow/dataset/expression_internal.h | 6 +- cpp/src/arrow/dataset/expression_test.cc | 51 ++++++++----- cpp/src/arrow/dataset/test_util.h | 1 + cpp/src/arrow/result.h | 8 +- cpp/src/arrow/scalar.cc | 7 ++ cpp/src/arrow/scalar.h | 3 + cpp/src/arrow/status.cc | 2 +- cpp/src/arrow/testing/gtest_util.h | 35 +++++---- python/pyarrow/tests/parquet/test_dataset.py | 9 +-- r/tests/testthat/test-dataset.R | 7 +- 25 files changed, 326 insertions(+), 169 deletions(-) diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index fd705ff973b..4462256164f 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -124,26 +124,15 @@ void RegisterScalarCast(FunctionRegistry* registry) { } // namespace internal -struct CastFunction::CastFunctionImpl { - Type::type out_type; - std::unordered_set in_types; -}; - -CastFunction::CastFunction(std::string name, Type::type out_type) - : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr) { - impl_.reset(new CastFunctionImpl()); - impl_->out_type = out_type; -} - -CastFunction::~CastFunction() = default; - -Type::type CastFunction::out_type_id() const { return impl_->out_type; } +CastFunction::CastFunction(std::string name, Type::type out_type_id) + : ScalarFunction(std::move(name), Arity::Unary(), /*doc=*/nullptr), + out_type_id_(out_type_id) {} Status CastFunction::AddKernel(Type::type in_type_id, ScalarKernel kernel) { // We use the same KernelInit for every cast kernel.init = internal::CastState::Init; RETURN_NOT_OK(ScalarFunction::AddKernel(kernel)); - impl_->in_types.insert(static_cast(in_type_id)); + in_type_ids_.push_back(in_type_id); return Status::OK(); } @@ -159,19 +148,10 @@ Status CastFunction::AddKernel(Type::type in_type_id, std::vector in_ return AddKernel(in_type_id, std::move(kernel)); } -bool CastFunction::CanCastTo(const DataType& out_type) const { - return impl_->in_types.find(static_cast(out_type.id())) != impl_->in_types.end(); -} - Result CastFunction::DispatchExact( const std::vector& values) const { - const int passed_num_args = static_cast(values.size()); + RETURN_NOT_OK(CheckArity(values)); - // Validate arity - if (passed_num_args != 1) { - return Status::Invalid("Cast functions accept 1 argument but passed ", - passed_num_args); - } std::vector candidate_kernels; for (const auto& kernel : kernels_) { if (kernel.signature->MatchesInputs(values)) { @@ -181,25 +161,28 @@ Result CastFunction::DispatchExact( if (candidate_kernels.size() == 0) { return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(), - " to ", ToTypeName(impl_->out_type), " using function ", + " to ", ToTypeName(out_type_id_), " using function ", this->name()); - } else if (candidate_kernels.size() == 1) { + } + + if (candidate_kernels.size() == 1) { // One match, return it return candidate_kernels[0]; - } else { - // Now we are in a casting scenario where we may have both a EXACT_TYPE and - // a SAME_TYPE_ID. So we will see if there is an exact match among the - // candidate kernels and if not we will just return the first one - for (auto kernel : candidate_kernels) { - const InputType& arg0 = kernel->signature->in_types()[0]; - if (arg0.kind() == InputType::EXACT_TYPE) { - // Bingo. Return it - return kernel; - } + } + + // Now we are in a casting scenario where we may have both a EXACT_TYPE and + // a SAME_TYPE_ID. So we will see if there is an exact match among the + // candidate kernels and if not we will just return the first one + for (auto kernel : candidate_kernels) { + const InputType& arg0 = kernel->signature->in_types()[0]; + if (arg0.kind() == InputType::EXACT_TYPE) { + // Bingo. Return it + return kernel; } - // We didn't find an exact match. So just return some kernel that matches - return candidate_kernels[0]; } + + // We didn't find an exact match. So just return some kernel that matches + return candidate_kernels[0]; } Result Cast(const Datum& value, const CastOptions& options, ExecContext* ctx) { @@ -225,13 +208,21 @@ Result> GetCastFunction( } bool CanCast(const DataType& from_type, const DataType& to_type) { - // TODO internal::EnsureInitCastTable(); - auto it = internal::g_cast_table.find(static_cast(from_type.id())); + auto it = internal::g_cast_table.find(static_cast(to_type.id())); if (it == internal::g_cast_table.end()) { return false; } - return it->second->CanCastTo(to_type); + + const CastFunction* function = it->second.get(); + DCHECK_EQ(function->out_type_id(), to_type.id()); + + for (auto from_id : function->in_type_ids()) { + // XXX should probably check the output type as well + if (from_type.id() == from_id) return true; + } + + return false; } } // namespace compute diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 0b9d9caf882..60111df7bf2 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -82,10 +82,10 @@ struct ARROW_EXPORT CastOptions : public FunctionOptions { // the same execution machinery class CastFunction : public ScalarFunction { public: - CastFunction(std::string name, Type::type out_type); - ~CastFunction() override; + CastFunction(std::string name, Type::type out_type_id); - Type::type out_type_id() const; + Type::type out_type_id() const { return out_type_id_; } + const std::vector& in_type_ids() const { return in_type_ids_; } Status AddKernel(Type::type in_type_id, std::vector in_types, OutputType out_type, ArrayKernelExec exec, @@ -96,14 +96,12 @@ class CastFunction : public ScalarFunction { // function to CastInit Status AddKernel(Type::type in_type_id, ScalarKernel kernel); - bool CanCastTo(const DataType& out_type) const; - Result DispatchExact( const std::vector& values) const override; private: - struct CastFunctionImpl; - std::unique_ptr impl_; + std::vector in_type_ids_; + const Type::type out_type_id_; }; ARROW_EXPORT diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 42903cf3627..ccfbeedfe33 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -37,20 +37,32 @@ static const FunctionDoc kEmptyFunctionDoc{}; const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; } -Status Function::CheckArity(int passed_num_args) const { - if (arity_.is_varargs && passed_num_args < arity_.num_args) { - return Status::Invalid("VarArgs function needs at least ", arity_.num_args, - " arguments but kernel accepts only ", passed_num_args); +Status CheckArityImpl(const Function* function, int passed_num_args, + const char* passed_num_args_label) { + if (function->arity().is_varargs && passed_num_args < function->arity().num_args) { + return Status::Invalid("VarArgs function needs at least ", function->arity().num_args, + " arguments but ", passed_num_args_label, " only ", + passed_num_args); } - if (!arity_.is_varargs && passed_num_args != arity_.num_args) { - return Status::Invalid("Function accepts ", arity_.num_args, - " arguments but kernel accepts ", passed_num_args); + if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) { + return Status::Invalid("Function ", function->name(), " accepts ", + function->arity().num_args, " arguments but ", + passed_num_args_label, " ", passed_num_args); } return Status::OK(); } +Status Function::CheckArity(const std::vector& in_types) const { + return CheckArityImpl(this, static_cast(in_types.size()), "kernel accepts"); +} + +Status Function::CheckArity(const std::vector& descrs) const { + return CheckArityImpl(this, static_cast(descrs.size()), + "attempted to look up kernel(s) with"); +} + namespace detail { Status NoMatchingKernel(const Function* func, const std::vector& descrs) { @@ -123,7 +135,7 @@ Result Function::DispatchExact( if (kind_ == Function::META) { return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); } - RETURN_NOT_OK(CheckArity(static_cast(values.size()))); + RETURN_NOT_OK(CheckArity(values)); if (auto kernel = detail::DispatchExactImpl(this, values)) { return kernel; @@ -135,7 +147,7 @@ Result Function::DispatchBest(std::vector* values) co if (kind_ == Function::META) { return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); } - RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + RETURN_NOT_OK(CheckArity(*values)); // first try for an exact match if (auto kernel = detail::DispatchExactImpl(this, *values)) { @@ -220,7 +232,7 @@ Status Function::Validate() const { Status ScalarFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(static_cast(in_types.size()))); + RETURN_NOT_OK(CheckArity(in_types)); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -232,7 +244,7 @@ Status ScalarFunction::AddKernel(std::vector in_types, OutputType out } Status ScalarFunction::AddKernel(ScalarKernel kernel) { - RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -242,7 +254,7 @@ Status ScalarFunction::AddKernel(ScalarKernel kernel) { Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(static_cast(in_types.size()))); + RETURN_NOT_OK(CheckArity(in_types)); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -254,7 +266,7 @@ Status VectorFunction::AddKernel(std::vector in_types, OutputType out } Status VectorFunction::AddKernel(VectorKernel kernel) { - RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -263,7 +275,7 @@ Status VectorFunction::AddKernel(VectorKernel kernel) { } Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { - RETURN_NOT_OK(CheckArity(static_cast(kernel.signature->in_types().size()))); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -274,7 +286,9 @@ Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { Result MetaFunction::Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { - RETURN_NOT_OK(CheckArity(static_cast(args.size()))); + RETURN_NOT_OK( + CheckArityImpl(this, static_cast(args.size()), "attempted to Execute with")); + if (options == nullptr) { options = default_options(); } diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 23ecd6f160e..0ae4d22e200 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -199,7 +199,8 @@ class ARROW_EXPORT Function { doc_(doc ? doc : &FunctionDoc::Empty()), default_options_(default_options) {} - Status CheckArity(int passed_num_args) const; + Status CheckArity(const std::vector&) const; + Status CheckArity(const std::vector&) const; std::string name_; Function::Kind kind_; diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index baf7ade2584..2df44ee8f9a 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -201,6 +201,13 @@ void ReplaceNullWithOtherType(std::vector* descrs) { } } +void ReplaceTypes(const std::shared_ptr& type, + std::vector* descrs) { + for (auto& descr : *descrs) { + descr.type = type; + } +} + std::shared_ptr CommonNumeric(const std::vector& descrs) { for (const auto& descr : descrs) { auto id = descr.type->id(); @@ -254,6 +261,39 @@ std::shared_ptr CommonTimestamp(const std::vector& descrs) return timestamp(finest_unit); } +std::shared_ptr CommonBinary(const std::vector& descrs) { + bool all_utf8 = true, all_offset32 = true; + + for (const auto& descr : descrs) { + auto id = descr.type->id(); + // a common varbinary type is only possible if all types are binary like + switch (id) { + case Type::STRING: + continue; + case Type::BINARY: + all_utf8 = false; + continue; + case Type::LARGE_STRING: + all_offset32 = false; + continue; + case Type::LARGE_BINARY: + all_offset32 = false; + all_utf8 = false; + continue; + default: + return nullptr; + } + } + + if (all_utf8) { + if (all_offset32) return utf8(); + return large_utf8(); + } + + if (all_offset32) return binary(); + return large_binary(); +} + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index a547b000c80..f39ffdcca11 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1192,12 +1192,18 @@ void EnsureDictionaryDecoded(std::vector* descrs); ARROW_EXPORT void ReplaceNullWithOtherType(std::vector* descrs); +ARROW_EXPORT +void ReplaceTypes(const std::shared_ptr&, std::vector* descrs); + ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); ARROW_EXPORT std::shared_ptr CommonTimestamp(const std::vector& descrs); +ARROW_EXPORT +std::shared_ptr CommonBinary(const std::vector& descrs); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 17db5b06502..7abaa1c1a59 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -268,7 +268,7 @@ struct ArithmeticFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { - RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + RETURN_NOT_OK(CheckArity(*values)); using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -277,9 +277,7 @@ struct ArithmeticFunction : ScalarFunction { ReplaceNullWithOtherType(values); if (auto type = CommonNumeric(*values)) { - for (auto& descr : *values) { - descr.type = type; - } + ReplaceTypes(type, values); } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc b/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc index 07026db83be..e529d3791aa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_boolean.cc @@ -50,6 +50,7 @@ struct ParseBooleanString { std::vector> GetBooleanCasts() { auto func = std::make_shared("cast_boolean", Type::BOOL); AddCommonCasts(Type::BOOL, boolean(), func.get()); + AddZeroCopyCast(Type::BOOL, boolean(), boolean(), func.get()); for (const auto& ty : NumericTypes()) { ArrayKernelExec exec = diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index f8dde20e3aa..7221722d53a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -149,17 +149,13 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type, const Dat // ---------------------------------------------------------------------- void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (out->is_scalar()) { - KERNEL_ASSIGN_OR_RAISE(*out, ctx, - batch[0].scalar_as().GetEncodedValue()); - return; - } + DCHECK(out->is_array()); DictionaryArray dict_arr(batch[0].array()); const CastOptions& options = checked_cast(*ctx->state()).options; const auto& dict_type = *dict_arr.dictionary()->type(); - if (!dict_type.Equals(options.to_type)) { + if (!dict_type.Equals(options.to_type) && !CanCast(dict_type, *options.to_type)) { ctx->SetStatus(Status::Invalid("Cast type ", options.to_type->ToString(), " incompatible with dictionary type ", dict_type.ToString())); @@ -169,6 +165,10 @@ void UnpackDictionary(KernelContext* ctx, const ExecBatch& batch, Datum* out) { KERNEL_ASSIGN_OR_RAISE(*out, ctx, Take(Datum(dict_arr.dictionary()), Datum(dict_arr.indices()), TakeOptions::Defaults(), ctx->exec_context())); + + if (!dict_type.Equals(options.to_type)) { + KERNEL_ASSIGN_OR_RAISE(*out, ctx, Cast(*out, options)); + } } void OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -224,28 +224,23 @@ Result ResolveOutputFromOptions(KernelContext* ctx, OutputType kOutputTargetType(ResolveOutputFromOptions); void ZeroCopyCastExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { - // Make a copy of the buffers into a destination array without carrying - // the type - const ArrayData& input = *batch[0].array(); - ArrayData* output = out->mutable_array(); - output->length = input.length; - output->SetNullCount(input.null_count); - output->buffers = input.buffers; - output->offset = input.offset; - output->child_data = input.child_data; - } else { - ctx->SetStatus( - Status::NotImplemented("This cast not yet implemented for " - "scalar input")); - } + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + // Make a copy of the buffers into a destination array without carrying + // the type + const ArrayData& input = *batch[0].array(); + ArrayData* output = out->mutable_array(); + output->length = input.length; + output->SetNullCount(input.null_count); + output->buffers = input.buffers; + output->offset = input.offset; + output->child_data = input.child_data; } void AddZeroCopyCast(Type::type in_type_id, InputType in_type, OutputType out_type, CastFunction* func) { auto sig = KernelSignature::Make({in_type}, out_type); ScalarKernel kernel; - kernel.exec = ZeroCopyCastExec; + kernel.exec = TrivialScalarUnaryAsArraysExec(ZeroCopyCastExec); kernel.signature = sig; kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; @@ -268,7 +263,8 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun // XXX: Uses Take and does its own memory allocation for the moment. We can // fix this later. DCHECK_OK(func->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, out_ty, - UnpackDictionary, NullHandling::COMPUTED_NO_PREALLOCATE, + TrivialScalarUnaryAsArraysExec(UnpackDictionary), + NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 7d502f046fc..b339018072e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -215,41 +215,41 @@ void AddBinaryToBinaryCast(CastFunction* func) { auto out_ty = TypeTraits::type_singleton(); DCHECK_OK(func->AddKernel( - OutType::type_id, {in_ty}, out_ty, + InType::type_id, {in_ty}, out_ty, TrivialScalarUnaryAsArraysExec(BinaryToBinaryCastFunctor::Exec), NullHandling::COMPUTED_NO_PREALLOCATE)); } +template +void AddBinaryToBinaryCast(CastFunction* func) { + AddBinaryToBinaryCast(func); + AddBinaryToBinaryCast(func); + AddBinaryToBinaryCast(func); + AddBinaryToBinaryCast(func); +} + } // namespace std::vector> GetBinaryLikeCasts() { auto cast_binary = std::make_shared("cast_binary", Type::BINARY); AddCommonCasts(Type::BINARY, binary(), cast_binary.get()); - AddBinaryToBinaryCast(cast_binary.get()); - AddBinaryToBinaryCast(cast_binary.get()); - AddBinaryToBinaryCast(cast_binary.get()); + AddBinaryToBinaryCast(cast_binary.get()); auto cast_large_binary = std::make_shared("cast_large_binary", Type::LARGE_BINARY); AddCommonCasts(Type::LARGE_BINARY, large_binary(), cast_large_binary.get()); - AddBinaryToBinaryCast(cast_large_binary.get()); - AddBinaryToBinaryCast(cast_large_binary.get()); - AddBinaryToBinaryCast(cast_large_binary.get()); + AddBinaryToBinaryCast(cast_large_binary.get()); auto cast_string = std::make_shared("cast_string", Type::STRING); AddCommonCasts(Type::STRING, utf8(), cast_string.get()); AddNumberToStringCasts(cast_string.get()); - AddBinaryToBinaryCast(cast_string.get()); - AddBinaryToBinaryCast(cast_string.get()); - AddBinaryToBinaryCast(cast_string.get()); + AddBinaryToBinaryCast(cast_string.get()); auto cast_large_string = std::make_shared("cast_large_string", Type::LARGE_STRING); AddCommonCasts(Type::LARGE_STRING, large_utf8(), cast_large_string.get()); AddNumberToStringCasts(cast_large_string.get()); - AddBinaryToBinaryCast(cast_large_string.get()); - AddBinaryToBinaryCast(cast_large_string.get()); - AddBinaryToBinaryCast(cast_large_string.get()); + AddBinaryToBinaryCast(cast_large_string.get()); auto cast_fsb = std::make_shared("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index e470f9f90de..d7d1faf7ae5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -322,7 +322,7 @@ struct CastFunctor::value>> template void AddCrossUnitCast(CastFunction* func) { ScalarKernel kernel; - kernel.exec = CastFunctor::Exec; + kernel.exec = TrivialScalarUnaryAsArraysExec(CastFunctor::Exec); kernel.signature = KernelSignature::Make({InputType(Type::type_id)}, kOutputTargetType); DCHECK_OK(func->AddKernel(Type::type_id, std::move(kernel))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 350728793e6..bc7110d1c10 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -75,6 +75,9 @@ static std::vector> kNumericTypes = { uint8(), int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}; +static std::vector> kBaseBinaryTypes = { + binary(), utf8(), large_binary(), large_utf8()}; + static void AssertBufferSame(const Array& left, const Array& right, int buffer_index) { ASSERT_EQ(left.data()->buffers[buffer_index].get(), right.data()->buffers[buffer_index].get()); @@ -403,6 +406,77 @@ class TestCast : public TestBase { } }; +TEST_F(TestCast, CanCast) { + auto ExpectCanCast = [](std::shared_ptr from, + std::vector> to_set, + bool expected = true) { + for (auto to : to_set) { + EXPECT_EQ(CanCast(*from, *to), expected) << " from: " << from->ToString() << "\n" + << " to: " << to->ToString(); + } + }; + + auto ExpectCannotCast = [ExpectCanCast](std::shared_ptr from, + std::vector> to_set) { + ExpectCanCast(from, to_set, /*expected=*/false); + }; + + ExpectCanCast(null(), {boolean()}); + ExpectCanCast(null(), kNumericTypes); + ExpectCanCast(null(), kBaseBinaryTypes); + ExpectCanCast( + null(), {date32(), date64(), time32(TimeUnit::MILLI), timestamp(TimeUnit::SECOND)}); + ExpectCanCast(dictionary(uint16(), null()), {null()}); + + ExpectCanCast(boolean(), {boolean()}); + ExpectCanCast(boolean(), kNumericTypes); + ExpectCanCast(boolean(), {utf8(), large_utf8()}); + ExpectCanCast(dictionary(int32(), boolean()), {boolean()}); + + ExpectCannotCast(boolean(), {null()}); + ExpectCannotCast(boolean(), {binary(), large_binary()}); + ExpectCannotCast(boolean(), {date32(), date64(), time32(TimeUnit::MILLI), + timestamp(TimeUnit::SECOND)}); + + for (auto from_numeric : kNumericTypes) { + ExpectCanCast(from_numeric, {boolean()}); + ExpectCanCast(from_numeric, kNumericTypes); + ExpectCanCast(from_numeric, {utf8(), large_utf8()}); + ExpectCanCast(dictionary(int32(), from_numeric), {from_numeric}); + + ExpectCannotCast(from_numeric, {null()}); + } + + for (auto from_base_binary : kBaseBinaryTypes) { + ExpectCanCast(from_base_binary, {boolean()}); + ExpectCanCast(from_base_binary, kNumericTypes); + ExpectCanCast(from_base_binary, kBaseBinaryTypes); + ExpectCanCast(dictionary(int64(), from_base_binary), {from_base_binary}); + + // any cast which is valid for the dictionary is valid for the DictionaryArray + ExpectCanCast(dictionary(uint32(), from_base_binary), kBaseBinaryTypes); + ExpectCanCast(dictionary(int16(), from_base_binary), kNumericTypes); + + ExpectCannotCast(from_base_binary, {null()}); + } + + ExpectCanCast(utf8(), {timestamp(TimeUnit::MILLI)}); + ExpectCanCast(large_utf8(), {timestamp(TimeUnit::NANO)}); + ExpectCannotCast(timestamp(TimeUnit::MICRO), + kBaseBinaryTypes); // no formatting supported + + ExpectCannotCast(fixed_size_binary(3), + {fixed_size_binary(3)}); // FIXME missing identity cast + + auto smallint = std::make_shared(); + ASSERT_OK(RegisterExtensionType(smallint)); + ExpectCanCast(smallint, {int16()}); // cast storage + ExpectCanCast(smallint, + kNumericTypes); // any cast which is valid for storage is supported + ExpectCannotCast(null(), {smallint}); // FIXME missing common cast from null + ASSERT_OK(UnregisterExtensionType("smallint")); +} + TEST_F(TestCast, SameTypeZeroCopy) { std::shared_ptr arr = ArrayFromJSON(int32(), "[0, null, 2, 3, 4]"); ASSERT_OK_AND_ASSIGN(std::shared_ptr result, Cast(*arr, int32())); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index c58643d58b2..58d3e6fc781 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -76,7 +76,7 @@ struct CompareFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { - RETURN_NOT_OK(CheckArity(static_cast(values->size()))); + RETURN_NOT_OK(CheckArity(*values)); using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -85,13 +85,11 @@ struct CompareFunction : ScalarFunction { ReplaceNullWithOtherType(values); if (auto type = CommonNumeric(*values)) { - for (auto& descr : *values) { - descr.type = type; - } + ReplaceTypes(type, values); } else if (auto type = CommonTimestamp(*values)) { - for (auto& descr : *values) { - descr.type = type; - } + ReplaceTypes(type, values); + } else if (auto type = CommonBinary(*values)) { + ReplaceTypes(type, values); } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 9276532a704..42419ae1b1c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -481,6 +481,9 @@ TEST(TestCompareKernel, DispatchBest) { CheckDispatchBest(name, {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO)}, {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); + + CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()}); + CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()}); } } diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 261464f7ba6..69fefc1b4d0 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -21,6 +21,7 @@ #include #include "arrow/chunked_array.h" +#include "arrow/compute/api_vector.h" #include "arrow/compute/exec_internal.h" #include "arrow/dataset/expression_internal.h" #include "arrow/io/memory.h" @@ -746,7 +747,24 @@ Result ReplaceFieldsWithKnownValues( if (auto ref = expr.field_ref()) { auto it = known_values.find(*ref); if (it != known_values.end()) { - ARROW_ASSIGN_OR_RAISE(Datum lit, compute::Cast(it->second, expr.type())); + Datum lit = it->second; + if (expr.type()->id() == Type::DICTIONARY) { + if (lit.is_scalar()) { + // FIXME the "right" way to support this is adding support for scalars to + // dictionary_encode and support for casting between index types to cast + ARROW_ASSIGN_OR_RAISE( + auto index, + Int32Scalar(0).CastTo( + checked_cast(*expr.type()).index_type())); + + ARROW_ASSIGN_OR_RAISE(auto dictionary, + MakeArrayFromScalar(*lit.scalar(), 1)); + + return literal( + DictionaryScalar::Make(std::move(index), std::move(dictionary))); + } + } + ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(it->second, expr.type())); return literal(std::move(lit)); } } diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index d8c6b7445df..f17dc2f2af1 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -86,16 +86,12 @@ struct Comparison { return nullptr; } - // Execute a simple Comparison between scalars, casting the RHS if types disagree + // Execute a simple Comparison between scalars static Result Execute(Datum l, Datum r) { if (!l.is_scalar() || !r.is_scalar()) { return Status::Invalid("Cannot Execute Comparison on non-scalars"); } - if (!l.type()->Equals(r.type())) { - ARROW_ASSIGN_OR_RAISE(r, compute::Cast(r, l.type())); - } - std::vector arguments{std::move(l), std::move(r)}; ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments)); diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 4b05d7251e2..9fc5774f3a6 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -57,14 +57,10 @@ void ExpectResultsEqual(Actual&& actual, Expected&& expected) { MaybeExpected maybe_expected(std::forward(expected)); if (maybe_expected.ok()) { - ASSERT_OK_AND_ASSIGN(auto actual, maybe_actual); - EXPECT_EQ(actual, *maybe_expected); + EXPECT_EQ(maybe_actual, maybe_expected); } else { - EXPECT_EQ(maybe_actual.status().code(), expected.status().code()); - EXPECT_NE(maybe_actual.status().message().find(expected.status().message()), - std::string::npos) - << " actual: " << maybe_actual.status() << "\n" - << " expected: " << maybe_expected.status(); + EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT( + expected.status().code(), HasSubstr(expected.status().message()), maybe_actual); } } @@ -75,24 +71,32 @@ TEST(ExpressionUtils, Comparison) { ExpectResultsEqual(Comparison::Execute(l, r).Map(Comparison::GetName), expected); }; - Datum zero(0), one(1), two(2), null(std::make_shared()), str("hello"); + Datum zero(0), one(1), two(2), null(std::make_shared()); + Datum str("hello"), bin(std::make_shared(Buffer::FromString("hello"))); + Datum dict_str(DictionaryScalar::Make(std::make_shared(0), + ArrayFromJSON(utf8(), R"(["a", "b", "c"])"))); - Status parse_failure = Status::Invalid("Failed to parse"); + Status not_impl = Status::NotImplemented("no kernel matching input types"); Expect("equal", one, one); Expect("less", one, two); Expect("greater", one, zero); - // cast RHS to LHS type; "hello" > "1" - Expect("greater", str, one); - // cast RHS to LHS type; "hello" is not convertible to int - Expect(parse_failure, one, str); - Expect("na", one, null); - Expect("na", str, null); Expect("na", null, one); - // cast RHS to LHS type; "hello" is not convertible to int - Expect(parse_failure, null, str); + + // strings and ints are not comparable without explicit casts + Expect(not_impl, str, one); + Expect(not_impl, one, str); + Expect(not_impl, str, null); // not even null ints + + // string -> binary implicit cast allowed + Expect("equal", str, bin); + Expect("equal", bin, str); + + // dict_str -> string, implicit casts allowed + Expect("less", dict_str, str); + Expect("less", dict_str, bin); } TEST(ExpressionUtils, StripOrderPreservingCasts) { @@ -287,9 +291,9 @@ TEST(Expression, IsSatisfiable) { // When a top level conjunction contains an Expression which is certain to evaluate to // null, it can only evaluate to null or false. - auto null_or_false = and_(literal(null), field_ref("a")); - // This may appear in satisfiable filters if coalesced - EXPECT_TRUE(call("is_null", {null_or_false}).IsSatisfiable()); + auto never_true = and_(literal(null), field_ref("a")); + // This may appear in satisfiable filters if coalesced (for example, wrapped in fill_na) + EXPECT_TRUE(call("is_null", {never_true}).IsSatisfiable()); // ... but at the top level it is not satisfiable. // This special case arises when (for example) an absent column has made // one member of the conjunction always-null. This is fairly common and @@ -412,6 +416,9 @@ TEST(Expression, BindWithImplicitCasts) { // cast dictionary to value type ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")), cmp(cast(field_ref("dict_str"), utf8()), field_ref("str"))); + + ExpectBindsTo(cmp(field_ref("dict_i32"), literal(int64_t(4))), + cmp(cast(field_ref("dict_i32"), int64()), literal(int64_t(4)))); } compute::SetLookupOptions in_a{ArrayFromJSON(utf8(), R"(["a"])")}; @@ -997,6 +1004,10 @@ TEST(Expression, SimplifyWithGuarantee) { call("is_in", {field_ref("i64")}, in_123)})} .WithGuarantee(greater(field_ref("f32"), literal(0.F))) .Expect(call("is_in", {field_ref("i64")}, in_123)); + + Simplify{greater(field_ref("dict_i32"), literal(int64_t(1)))} + .WithGuarantee(equal(field_ref("dict_i32"), literal(0))) + .Expect(false); } TEST(Expression, SimplifyThenExecute) { diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index b3e9ae424da..c72283312cb 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -62,6 +62,7 @@ const std::shared_ptr kBoringSchema = schema({ field("date64", date64()), field("str", utf8()), field("dict_str", dictionary(int32(), utf8())), + field("dict_i32", dictionary(int32(), int32())), field("ts_ns", timestamp(TimeUnit::NANO)), }); diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index 09dfd59c8d2..6504d950674 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -317,6 +317,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return ValueUnsafe(); } const T& operator*() const& { return ValueOrDie(); } + const T* operator->() const& { return &ValueOrDie(); } /// Gets a mutable reference to the stored `T` value. /// @@ -331,6 +332,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return ValueUnsafe(); } T& operator*() & { return ValueOrDie(); } + T* operator->() & { return &ValueOrDie(); } /// Moves and returns the internally-stored `T` value. /// @@ -453,9 +455,9 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { } }; -#define ARROW_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ - auto&& result_name = (rexpr); \ - ARROW_RETURN_NOT_OK((result_name).status()); \ +#define ARROW_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ + auto&& result_name = (rexpr); \ + ARROW_RETURN_IF_(!(result_name).ok(), (result_name).status(), ARROW_STRINGIFY(rexpr)); \ lhs = std::move(result_name).ValueUnsafe(); #define ARROW_ASSIGN_OR_RAISE_NAME(x, y) ARROW_CONCAT(x, y) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index eca711d7c4f..06fc6783ff3 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -268,6 +268,13 @@ Result> DictionaryScalar::GetEncodedValue() const { return value.dictionary->GetScalar(index_value); } +std::shared_ptr DictionaryScalar::Make(std::shared_ptr index, + std::shared_ptr dict) { + auto type = dictionary(index->type, dict->type()); + return std::make_shared(ValueType{std::move(index), std::move(dict)}, + std::move(type)); +} + template using scalar_constructor_has_arrow_type = std::is_constructible::ScalarType, std::shared_ptr>; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 2888874d292..e84e3fab900 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -448,6 +448,9 @@ struct ARROW_EXPORT DictionaryScalar : public Scalar { DictionaryScalar(ValueType value, std::shared_ptr type, bool is_valid = true) : Scalar(std::move(type), is_valid), value(std::move(value)) {} + static std::shared_ptr Make(std::shared_ptr index, + std::shared_ptr dict); + Result> GetEncodedValue() const; }; diff --git a/cpp/src/arrow/status.cc b/cpp/src/arrow/status.cc index 480bbd3e468..cfc5eb1e345 100644 --- a/cpp/src/arrow/status.cc +++ b/cpp/src/arrow/status.cc @@ -132,7 +132,7 @@ void Status::Abort(const std::string& message) const { void Status::AddContextLine(const char* filename, int line, const char* expr) { ARROW_CHECK(!ok()) << "Cannot add context line to ok status"; std::stringstream ss; - ss << "\nIn " << filename << ", line " << line << ", code: " << expr; + ss << "\n" << filename << ":" << line << " " << expr; state_->msg += ss.str(); } #endif diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 2e523eac2bb..cdb23a92899 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -79,15 +79,18 @@ EXPECT_THAT(_st.ToString(), (matcher)); \ } while (false) -#define ASSERT_OK(expr) \ - do { \ - auto _res = (expr); \ - ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \ - if (!_st.ok()) { \ - FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString(); \ - } \ +#define EXPECT_RAISES_WITH_CODE_AND_MESSAGE_THAT(code, matcher, expr) \ + do { \ + auto _res = (expr); \ + ::arrow::Status _st = ::arrow::internal::GenericToStatus(_res); \ + EXPECT_EQ(_st.CodeAsString(), Status::CodeAsString(code)); \ + EXPECT_THAT(_st.ToString(), (matcher)); \ } while (false) +#define ASSERT_OK(expr) \ + for (::arrow::Status _st = ::arrow::internal::GenericToStatus((expr)); !_st.ok();) \ + FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString() + #define ASSERT_OK_NO_THROW(expr) ASSERT_NO_THROW(ASSERT_OK(expr)) #define ARROW_EXPECT_OK(expr) \ @@ -426,13 +429,6 @@ inline void BitmapFromVector(const std::vector& is_valid, ASSERT_OK(GetBitmapFromVector(is_valid, out)); } -template -void AssertSortedEquals(std::vector u, std::vector v) { - std::sort(u.begin(), u.end()); - std::sort(v.begin(), v.end()); - ASSERT_EQ(u, v); -} - ARROW_TESTING_EXPORT void SleepFor(double seconds); @@ -474,6 +470,17 @@ class ARROW_TESTING_EXPORT EnvVarGuard { #define LARGE_MEMORY_TEST(name) name #endif +inline void PrintTo(const Status& st, std::ostream* os) { *os << st.ToString(); } + +template +void PrintTo(const Result& result, std::ostream* os) { + if (result.ok()) { + ::testing::internal::UniversalPrint(result.ValueOrDie(), os); + } else { + *os << result.status(); + } +} + } // namespace arrow namespace nonstd { diff --git a/python/pyarrow/tests/parquet/test_dataset.py b/python/pyarrow/tests/parquet/test_dataset.py index 48c62c7e458..cc49f14030a 100644 --- a/python/pyarrow/tests/parquet/test_dataset.py +++ b/python/pyarrow/tests/parquet/test_dataset.py @@ -191,11 +191,6 @@ def test_filters_equivalency(tempdir, use_legacy_dataset): ['string', string_keys], ['boolean', boolean_keys] ] - schema = pa.schema({ - 'integer': pa.int32(), - 'string': pa.string(), - 'boolean', pa.boolean() - }) df = pd.DataFrame({ 'integer': np.array(integer_keys, dtype='i4').repeat(15), @@ -209,9 +204,9 @@ def test_filters_equivalency(tempdir, use_legacy_dataset): # Old filters syntax: # integer == 1 AND string != b AND boolean == True dataset = pq.ParquetDataset( - base_path, filesystem=fs, schema=schema, + base_path, filesystem=fs, filters=[('integer', '=', 1), ('string', '!=', 'b'), - ('boolean', '==', True)], + ('boolean', '==', 'True')], use_legacy_dataset=use_legacy_dataset, ) table = dataset.read() diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 1e0f9418eec..e0143bcf66b 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -554,6 +554,7 @@ test_that("filter() on timestamp columns", { ) # Now with bare string date + skip("Implement more aggressive implicit casting for scalars") expect_equivalent( ds %>% filter(ts >= "2015-05-04") %>% @@ -666,8 +667,6 @@ test_that("filter() with expressions", { ) ) - skip("Implicit casts aren't being inserted everywhere they need to be (ARROW-8919)") - # Error: NotImplemented: Function multiply_checked has no kernel matching input types (scalar[double], array[int32]) expect_equivalent( ds %>% select(chr, dbl, int) %>% @@ -680,8 +679,6 @@ test_that("filter() with expressions", { ) ) - skip("Implicit casts are only inserted for scalars (ARROW-8919)") - # Error: NotImplemented: Function add_checked has no kernel matching input types (array[double], array[int32]) expect_equivalent( ds %>% select(chr, dbl, int) %>% @@ -700,7 +697,7 @@ test_that("filter scalar validation doesn't crash (ARROW-7772)", { ds %>% filter(int == "fff", part == 1) %>% collect(), - "Failed to parse string: 'fff' as a scalar of type int32" + "equal has no kernel matching input types .array.int32., scalar.string.." ) }) From c1de51d81b7914d5dd6b3b52dce96b85f06e512d Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 22 Jan 2021 12:57:26 -0500 Subject: [PATCH 13/22] describe implicit cast behavior in compute.rst --- docs/source/cpp/compute.rst | 44 +++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index c513ed5b0ab..038027522fc 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -196,9 +196,11 @@ Binary functions have the following semantics (which is sometimes called Arithmetic functions ~~~~~~~~~~~~~~~~~~~~ -These functions expect two inputs of the same type and apply a given binary +These functions expect two inputs of numeric type and apply a given binary operation to each pair of elements gathered from the inputs. If any of the input elements in a pair is null, the corresponding output element is null. +Inputs will be cast to the :ref:`common numeric type ` +(and dictionary decoded, if applicable) before the operation is applied. The default variant of these functions does not detect overflow (the result then typically wraps around). Each function is also available in an @@ -228,9 +230,12 @@ an ``Invalid`` :class:`Status` when overflow is detected. Comparisons ~~~~~~~~~~~ -Those functions expect two inputs of the same type and apply a given -comparison operator. If any of the input elements in a pair is null, -the corresponding output element is null. +These functions expect two inputs of numeric type (in which case they will be +cast to the :ref:`common numeric type ` before comparison), +or two inputs of Binary- or String-like types, or two inputs of Temporal types. +If any input is dictionary encoded it will be expanded for the purposes of +comparison. If any of the input elements in a pair is null, the corresponding +output element is null. +--------------------------+------------+---------------------------------------------+---------------------+ | Function names | Arity | Input types | Output type | @@ -744,3 +749,34 @@ Structural transforms * \(2) For each value in the list child array, the index at which it is found in the list array is appended to the output. Nulls in the parent list array are discarded. + +.. _common-numeric-type: + +Common numeric type +~~~~~~~~~~~~~~~~~~~ + +The common numeric type of a set of input numeric types is the smallest numeric +type which can accommodate any value of any input. If any input is a floating +point type the common numeric type is the widest floating point type among the +inputs. Otherwise the common numeric type is integral, is signed if any input +is signed, and its width is the maximum width of any input. For example: + ++-------------------+----------------------+-------------------------------------------+ +| Input types | Common numeric type | Notes | ++===================+======================+===========================================+ +| int32, int32 | int32 | | ++-------------------+----------------------+-------------------------------------------+ +| int16, int32 | int32 | Max width is 32, promote LHS to int32 | ++-------------------+----------------------+-------------------------------------------+ +| uint16, uint32 | uint32 | All inputs unsigned, maintain unsigned | ++-------------------+----------------------+-------------------------------------------+ +| uint16, int32 | int32 | One input signed, override unsigned | ++-------------------+----------------------+-------------------------------------------+ +| int16, uint32 | int32 | | ++-------------------+----------------------+-------------------------------------------+ +| float32, int32 | float32 | Promote RHS to float32 | ++-------------------+----------------------+-------------------------------------------+ +| float32, float64 | float64 | | ++-------------------+----------------------+-------------------------------------------+ +| float32, int64 | float32 | int64 is wider, still promotes to float32 | ++-------------------+----------------------+-------------------------------------------+ From db5ae2f1edf9abebc0f6831f940d577303de0f0c Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 22 Jan 2021 13:15:21 -0500 Subject: [PATCH 14/22] msvc: linkage fix --- cpp/src/arrow/compute/function.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index ccfbeedfe33..3bac73e6563 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -24,6 +24,7 @@ #include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec_internal.h" +#include "arrow/compute/kernels/common.h" #include "arrow/datum.h" #include "arrow/util/cpu_info.h" From 0852305ab9d616b0e290e66f7e0fff0c22d40d17 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 27 Jan 2021 13:03:33 -0500 Subject: [PATCH 15/22] review comments --- cpp/src/arrow/compute/function.cc | 6 +++--- docs/source/cpp/compute.rst | 7 +++++++ r/R/array.R | 7 +++++++ r/R/arrowExports.R | 4 ++++ r/man/array.Rd | 1 + r/src/array.cpp | 6 ++++++ r/src/arrowExports.cpp | 10 ++++++++++ r/tests/testthat/test-compute-arith.R | 14 +++++++++++++- r/tests/testthat/test-dataset.R | 2 +- 9 files changed, 52 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 3bac73e6563..ade4ae769e4 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -41,9 +41,9 @@ const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; } Status CheckArityImpl(const Function* function, int passed_num_args, const char* passed_num_args_label) { if (function->arity().is_varargs && passed_num_args < function->arity().num_args) { - return Status::Invalid("VarArgs function needs at least ", function->arity().num_args, - " arguments but ", passed_num_args_label, " only ", - passed_num_args); + return Status::Invalid("VarArgs function ", function->name(), " needs at least ", + function->arity().num_args, " arguments but ", + passed_num_args_label, " only ", passed_num_args); } if (!function->arity().is_varargs && passed_num_args != function->arity().num_args) { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 038027522fc..79e25294ebb 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -766,6 +766,8 @@ is signed, and its width is the maximum width of any input. For example: +===================+======================+===========================================+ | int32, int32 | int32 | | +-------------------+----------------------+-------------------------------------------+ +| uint32, int32 | int32 | One input signed, override unsigned | ++-------------------+----------------------+-------------------------------------------+ | int16, int32 | int32 | Max width is 32, promote LHS to int32 | +-------------------+----------------------+-------------------------------------------+ | uint16, uint32 | uint32 | All inputs unsigned, maintain unsigned | @@ -780,3 +782,8 @@ is signed, and its width is the maximum width of any input. For example: +-------------------+----------------------+-------------------------------------------+ | float32, int64 | float32 | int64 is wider, still promotes to float32 | +-------------------+----------------------+-------------------------------------------+ + +In particulary, note that comparing a `uint32` column to an `int32` column may +emit an error if one of the LHS' values cannot be expressed as the common type +`int32` (for example, `2 ** 31`). This tradeoff is made to keep the results of +arithmetic operations narrow. diff --git a/r/R/array.R b/r/R/array.R index ec2b545dfae..acb612be5ef 100644 --- a/r/R/array.R +++ b/r/R/array.R @@ -62,6 +62,7 @@ #' - `$type_id()`: type id #' - `$Equals(other)` : is this array equal to `other` #' - `$ApproxEquals(other)` : +#' - `$Diff(other)` : return a string expressing the difference between two arrays #' - `$data()`: return the underlying [ArrayData][ArrayData] #' - `$as_vector()`: convert to an R vector #' - `$ToString()`: string representation of the array @@ -95,6 +96,12 @@ Array <- R6Class("Array", ApproxEquals = function(other) { inherits(other, "Array") && Array__ApproxEquals(self, other) }, + Diff = function(other) { + if (!inherits(other, "Array")) { + other <- Array$create(other) + } + Array__Diff(self, other) + }, data = function() Array__data(self), as_vector = function() Array__as_vector(self), ToString = function() { diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index d6a5f9356e8..ec0aae94f30 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -48,6 +48,10 @@ Array__ApproxEquals <- function(lhs, rhs){ .Call(`_arrow_Array__ApproxEquals`, lhs, rhs) } +Array__Diff <- function(lhs, rhs){ + .Call(`_arrow_Array__Diff`, lhs, rhs) +} + Array__data <- function(array){ .Call(`_arrow_Array__data`, array) } diff --git a/r/man/array.Rd b/r/man/array.Rd index b133c073824..fbc91e4dc35 100644 --- a/r/man/array.Rd +++ b/r/man/array.Rd @@ -60,6 +60,7 @@ a == a \item \verb{$type_id()}: type id \item \verb{$Equals(other)} : is this array equal to \code{other} \item \verb{$ApproxEquals(other)} : +\item \verb{$Diff(other)} : return a string expressing the difference between two arrays \item \verb{$data()}: return the underlying \link{ArrayData} \item \verb{$as_vector()}: convert to an R vector \item \verb{$ToString()}: string representation of the array diff --git a/r/src/array.cpp b/r/src/array.cpp index e96e286a073..9601ee43c03 100644 --- a/r/src/array.cpp +++ b/r/src/array.cpp @@ -141,6 +141,12 @@ bool Array__ApproxEquals(const std::shared_ptr& lhs, return lhs->ApproxEquals(rhs); } +// [[arrow::export]] +std::string Array__Diff(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + return lhs->Diff(*rhs); +} + // [[arrow::export]] std::shared_ptr Array__data( const std::shared_ptr& array) { diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index ae90abd5adf..2fbfecacfa1 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -108,6 +108,15 @@ BEGIN_CPP11 END_CPP11 } // array.cpp +std::string Array__Diff(const std::shared_ptr& lhs, const std::shared_ptr& rhs); +extern "C" SEXP _arrow_Array__Diff(SEXP lhs_sexp, SEXP rhs_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type lhs(lhs_sexp); + arrow::r::Input&>::type rhs(rhs_sexp); + return cpp11::as_sexp(Array__Diff(lhs, rhs)); +END_CPP11 +} +// array.cpp std::shared_ptr Array__data(const std::shared_ptr& array); extern "C" SEXP _arrow_Array__data(SEXP array_sexp){ BEGIN_CPP11 @@ -3512,6 +3521,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Array__type_id", (DL_FUNC) &_arrow_Array__type_id, 1}, { "_arrow_Array__Equals", (DL_FUNC) &_arrow_Array__Equals, 2}, { "_arrow_Array__ApproxEquals", (DL_FUNC) &_arrow_Array__ApproxEquals, 2}, + { "_arrow_Array__Diff", (DL_FUNC) &_arrow_Array__Diff, 2}, { "_arrow_Array__data", (DL_FUNC) &_arrow_Array__data, 1}, { "_arrow_Array__RangeEquals", (DL_FUNC) &_arrow_Array__RangeEquals, 5}, { "_arrow_Array__View", (DL_FUNC) &_arrow_Array__View, 2}, diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 3bfa8c2b41e..d3cd2eedf6d 100644 --- a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -28,8 +28,14 @@ test_that("Addition", { a8 <- a$cast(int8()) expect_type_equal(a8 + Scalar$create(1, int8()), int8()) + + # int8 will be promoted to int32 when added to int32 expect_type_equal(a8 + 127L, int32()) - expect_type_equal(a8 + 200L, int32()) + expect_equal(a8 + 127L, Array$create(c(128:131, NA_integer_))) + + b <- Array$create(c(4:1, NA_integer_)) + expect_type_equal(a8 + b, int32()) + expect_equal(a8 + b, Array$create(c(5L, 5L, 5L, 5L, NA_integer_))) expect_type_equal(a + 4.1, float64()) expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_))) @@ -38,11 +44,17 @@ test_that("Addition", { test_that("Subtraction", { a <- Array$create(c(1:4, NA_integer_)) expect_equal(a - 3L, Array$create(c(-2:1, NA_integer_))) + + expect_equal(Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_)) - a, + Array$create(c(4.1, 4.1, 4.1, 4.1, NA_real_))) }) test_that("Multiplication", { a <- Array$create(c(1:4, NA_integer_)) expect_equal(a * 2L, Array$create(c(1:4 * 2L, NA_integer_))) + + expect_equal((a * 0.5) * 3L, + Array$create(c(1.5, 3, 4.5, 6, NA_real_))) }) test_that("Division", { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index e0143bcf66b..990f024212e 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -554,7 +554,7 @@ test_that("filter() on timestamp columns", { ) # Now with bare string date - skip("Implement more aggressive implicit casting for scalars") + skip("Implement more aggressive implicit casting for scalars (ARROW-11402)") expect_equivalent( ds %>% filter(ts >= "2015-05-04") %>% From ff9cde2c95ee6131f9a99e78e229f192a2209804 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 27 Jan 2021 14:55:40 -0500 Subject: [PATCH 16/22] unskip implicit casting comparison test --- r/R/expression.R | 1 - 1 file changed, 1 deletion(-) diff --git a/r/R/expression.R b/r/R/expression.R index 5475f7a44bc..ffc19f3547e 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -97,7 +97,6 @@ cast_array_expression <- function(x, to_type, safe = TRUE, ...) { .wrap_arrow <- function(arg, fun) { if (!inherits(arg, c("ArrowObject", "array_expression"))) { # TODO: Array$create if lengths are equal? - # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. float) if (fun == "%in%") { arg <- Array$create(arg) } else { From c761233210f21491ede794e07715eedd7106d14b Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 27 Jan 2021 15:59:47 -0500 Subject: [PATCH 17/22] Revert "unskip implicit casting comparison test" This reverts commit 2d26bf199d36b807d378c7ecc321b86f9922867b. --- r/R/expression.R | 1 + 1 file changed, 1 insertion(+) diff --git a/r/R/expression.R b/r/R/expression.R index ffc19f3547e..5475f7a44bc 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -97,6 +97,7 @@ cast_array_expression <- function(x, to_type, safe = TRUE, ...) { .wrap_arrow <- function(arg, fun) { if (!inherits(arg, c("ArrowObject", "array_expression"))) { # TODO: Array$create if lengths are equal? + # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. float) if (fun == "%in%") { arg <- Array$create(arg) } else { From dd683420b0e0dbd16e3d74f6db97c7725a2844bc Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 5 Feb 2021 14:37:42 -0500 Subject: [PATCH 18/22] review comments --- cpp/src/arrow/compute/cast.cc | 16 +++++++++ cpp/src/arrow/compute/cast.h | 12 +++++++ cpp/src/arrow/compute/function.cc | 33 ++++--------------- cpp/src/arrow/compute/function.h | 8 +++++ .../arrow/compute/kernels/codegen_internal.cc | 4 +++ cpp/src/arrow/compute/kernels/common.h | 12 ------- .../compute/kernels/scalar_arithmetic_test.cc | 16 ++++----- .../arrow/compute/kernels/scalar_cast_test.cc | 14 +++----- .../compute/kernels/scalar_compare_test.cc | 10 +++--- cpp/src/arrow/compute/kernels/test_util.cc | 17 ++-------- cpp/src/arrow/dataset/expression_internal.h | 4 +++ cpp/src/arrow/type_traits.h | 6 ++++ r/tests/testthat/test-compute-vector.R | 1 + 13 files changed, 75 insertions(+), 78 deletions(-) diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 4462256164f..8a091f2355d 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -225,5 +225,21 @@ bool CanCast(const DataType& from_type, const DataType& to_type) { return false; } +Result> Cast(std::vector datums, std::vector descrs, + ExecContext* ctx) { + for (size_t i = 0; i != datums.size(); ++i) { + if (descrs[i] != datums[i].descr()) { + if (descrs[i].shape != datums[i].shape()) { + return Status::NotImplemented("casting between Datum shapes"); + } + + ARROW_ASSIGN_OR_RAISE(datums[i], + Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx)); + } + } + + return datums; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index 60111df7bf2..c0405fb93e9 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -155,5 +155,17 @@ Result Cast(const Datum& value, std::shared_ptr to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); +/// \brief Cast several values simultaneously. Safe cast options are used. +/// \param[in] values datums to cast +/// \param[in] descrs ValueDescrs to cast to +/// \param[in] ctx the function execution context, optional +/// \return the resulting datums +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> Cast(std::vector values, std::vector descrs, + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index ade4ae769e4..776c14d5ca0 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -145,31 +145,19 @@ Result Function::DispatchExact( } Result Function::DispatchBest(std::vector* values) const { - if (kind_ == Function::META) { - return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); - } - RETURN_NOT_OK(CheckArity(*values)); - - // first try for an exact match - if (auto kernel = detail::DispatchExactImpl(this, *values)) { - return kernel; - } - - // XXX permit generic conversions here, for example dict -> decoded, null -> any? + // TODO(ARROW-11508) permit generic conversions here return DispatchExact(*values); } -Result Function::Execute(const std::vector& original_args, +Result Function::Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { if (options == nullptr) { options = default_options(); } if (ctx == nullptr) { ExecContext default_ctx; - return Execute(original_args, options, &default_ctx); + return Execute(args, options, &default_ctx); } - // make a local copy to accommodate implicit casts - auto args = original_args; // type-check Datum arguments here. Really we'd like to avoid this as much as // possible @@ -180,16 +168,7 @@ Result Function::Execute(const std::vector& original_args, } ARROW_ASSIGN_OR_RAISE(auto kernel, DispatchBest(&inputs)); - for (size_t i = 0; i != args.size(); ++i) { - if (inputs[i] != args[i].descr()) { - if (inputs[i].shape != args[i].shape()) { - return Status::NotImplemented( - "Automatic broadcasting of scalars to arrays for function ", name()); - } - - ARROW_ASSIGN_OR_RAISE(args[i], Cast(args[i], inputs[i].type)); - } - } + ARROW_ASSIGN_OR_RAISE(auto implicitly_cast_args, Cast(args, inputs, ctx)); std::unique_ptr state; @@ -211,8 +190,8 @@ Result Function::Execute(const std::vector& original_args, RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); auto listener = std::make_shared(); - RETURN_NOT_OK(executor->Execute(args, listener.get())); - return executor->WrapResults(args, listener->values()); + RETURN_NOT_OK(executor->Execute(implicitly_cast_args, listener.get())); + return executor->WrapResults(implicitly_cast_args, listener->values()); } Status Function::Validate() const { diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index 0ae4d22e200..af5d81a30ec 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -233,6 +233,14 @@ class FunctionImpl : public Function { std::vector kernels_; }; +/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. +ARROW_EXPORT +const Kernel* DispatchExactImpl(const Function* func, const std::vector&); + +/// \brief Return an error message if no Kernel is found. +ARROW_EXPORT +Status NoMatchingKernel(const Function* func, const std::vector&); + } // namespace detail /// \brief A function that executes elementwise operations on arrays or diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 2df44ee8f9a..4381feacd32 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -215,6 +215,10 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { // a common numeric type is only possible if all types are numeric return nullptr; } + if (id == Type::HALF_FLOAT) { + // float16 arithmetic is not currently supported + return nullptr; + } } for (const auto& descr : descrs) { if (descr.type->id() == Type::DOUBLE) return float64(); diff --git a/cpp/src/arrow/compute/kernels/common.h b/cpp/src/arrow/compute/kernels/common.h index 6566555e7f1..21244320f38 100644 --- a/cpp/src/arrow/compute/kernels/common.h +++ b/cpp/src/arrow/compute/kernels/common.h @@ -51,16 +51,4 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -namespace compute { -namespace detail { - -/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. -ARROW_EXPORT -const Kernel* DispatchExactImpl(const Function* func, const std::vector&); - -ARROW_EXPORT -Status NoMatchingKernel(const Function* func, const std::vector&); - -} // namespace detail -} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index f14a28aaa6c..de955bbf1e2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -674,20 +674,18 @@ TEST(TestBinaryArithmetic, AddWithImplicitCasts) { ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), ArrayFromJSON(int32(), "[-13, 4, 21, null]")); - CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"), - ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), - ArrayFromJSON(int32(), "[3, 5, 7, null]")); + CheckScalarBinary("add", + ArrayFromJSON(dictionary(int32(), int32()), "[8, 6, 3, null, 2]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7, 0]"), + ArrayFromJSON(int32(), "[11, 10, 8, null, 2]")); CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"), std::make_shared(4), ArrayFromJSON(int32(), "[null, null, null, null]")); - // Not currently implemented since it would invoke a double implicit cast: - // dictionary(int32, int8) -> int8 -> int32 - // CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, - // null]"), - // ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), - // ArrayFromJSON(int32(), "[3, 5, 7, null]")); + CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(int32(), "[3, 5, 7, null]")); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index bc7110d1c10..2a0f44187b2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -468,13 +468,11 @@ TEST_F(TestCast, CanCast) { ExpectCannotCast(fixed_size_binary(3), {fixed_size_binary(3)}); // FIXME missing identity cast - auto smallint = std::make_shared(); - ASSERT_OK(RegisterExtensionType(smallint)); - ExpectCanCast(smallint, {int16()}); // cast storage - ExpectCanCast(smallint, + ExtensionTypeGuard smallint_guard(smallint()); + ExpectCanCast(smallint(), {int16()}); // cast storage + ExpectCanCast(smallint(), kNumericTypes); // any cast which is valid for storage is supported - ExpectCannotCast(null(), {smallint}); // FIXME missing common cast from null - ASSERT_OK(UnregisterExtensionType("smallint")); + ExpectCannotCast(null(), {smallint()}); // FIXME missing common cast from null } TEST_F(TestCast, SameTypeZeroCopy) { @@ -1929,7 +1927,7 @@ std::shared_ptr SmallintArrayFromJSON(const std::string& json_data) { TEST_F(TestCast, ExtensionTypeToIntDowncast) { auto smallint = std::make_shared(); - ASSERT_OK(RegisterExtensionType(smallint)); + ExtensionTypeGuard smallint_guard(smallint); CastOptions options; options.allow_int_overflow = false; @@ -1965,8 +1963,6 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { // disallow overflow options.allow_int_overflow = false; ASSERT_RAISES(Invalid, Cast(*v3, uint8(), options)); - - ASSERT_OK(UnregisterExtensionType("smallint")); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 42419ae1b1c..d54a5cbf8ee 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -511,12 +511,10 @@ TEST(TestCompareKernel, GreaterWithImplicitCasts) { ArrayFromJSON(date64(), "[86400000, 0, 86400000]"), ArrayFromJSON(boolean(), "[false, true, false]")); - // Not currently implemented since it would invoke a double implicit cast: - // dictionary(int32, int8) -> int8 -> int32 - // CheckScalarBinary("greater", - // ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), - // ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), - // ArrayFromJSON(boolean(), "[false, false, false, null]")); + CheckScalarBinary("greater", + ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), + ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), + ArrayFromJSON(boolean(), "[false, false, false, null]")); } class TestStringCompareKernel : public ::testing::Test {}; diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index c8dd00250dc..2fc817120d9 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -185,23 +185,10 @@ void CheckDispatchBest(std::string func_name, std::vector original_v ASSERT_OK_AND_ASSIGN(auto expected_kernel, function->DispatchExact(expected_equivalent_values)); - auto Format = [](const std::vector& descrs) { - std::stringstream ss; - ss << "("; - for (size_t i = 0; i < descrs.size(); ++i) { - if (i > 0) { - ss << ", "; - } - ss << descrs[i].ToString(); - } - ss << ")"; - return ss.str(); - }; - EXPECT_EQ(actual_kernel, expected_kernel) - << "DispatchBest" << Format(original_values) << " => " + << "DispatchBest" << ValueDescr::ToString(original_values) << " => " << actual_kernel->signature->ToString() << "\n" - << "DispatchExact" << Format(expected_equivalent_values) << " => " + << "DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => " << expected_kernel->signature->ToString(); } diff --git a/cpp/src/arrow/dataset/expression_internal.h b/cpp/src/arrow/dataset/expression_internal.h index f17dc2f2af1..24e60377f5a 100644 --- a/cpp/src/arrow/dataset/expression_internal.h +++ b/cpp/src/arrow/dataset/expression_internal.h @@ -105,6 +105,10 @@ struct Comparison { return less.scalar_as().value ? LESS : GREATER; } + // Given an Expression wrapped in casts which preserve ordering + // (for example, cast(field_ref("i16"), to_type=int32())), unwrap the inner Expression. + // This is used to destructure implicitly cast field_refs during Expression + // simplification. static const Expression& StripOrderPreservingCasts(const Expression& expr) { auto call = expr.call(); if (!call) return expr; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 48dce38e87a..e872a31f31d 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -963,6 +963,12 @@ static inline int bit_width(Type::type type_id) { return 32; case Type::INTERVAL_DAY_TIME: return 64; + + case Type::DECIMAL128: + return 128; + case Type::DECIMAL256: + return 256; + default: break; } diff --git a/r/tests/testthat/test-compute-vector.R b/r/tests/testthat/test-compute-vector.R index 4fe7fed4d1c..0b184889bee 100644 --- a/r/tests/testthat/test-compute-vector.R +++ b/r/tests/testthat/test-compute-vector.R @@ -43,6 +43,7 @@ test_that("compare ops with Array", { expect_array_compares(Array$create(c(NA, 1:5)), 4) expect_array_compares(Array$create(as.numeric(c(NA, 1:5))), 4) expect_array_compares(Array$create(c(NA, 1:5)), Array$create(rev(c(NA, 1:5)))) + expect_array_compares(Array$create(c(NA, 1:5)), Array$create(rev(c(NA, 1:5)), type=double())) }) test_that("compare ops with ChunkedArray", { From 282dac5a718472b8d0ef6674224b3d6a6b660e60 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Fri, 5 Feb 2021 17:13:17 -0500 Subject: [PATCH 19/22] expand common numeric type when signed/unsigned --- .../arrow/compute/kernels/codegen_internal.cc | 33 +++++-- .../compute/kernels/scalar_arithmetic_test.cc | 37 ++++++-- .../compute/kernels/scalar_compare_test.cc | 40 +++++++-- cpp/src/arrow/compute/kernels/test_util.cc | 4 +- docs/source/cpp/compute.rst | 88 +++++++++++-------- 5 files changed, 138 insertions(+), 64 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 4381feacd32..45211600727 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -209,6 +209,8 @@ void ReplaceTypes(const std::shared_ptr& type, } std::shared_ptr CommonNumeric(const std::vector& descrs) { + DCHECK(!descrs.empty()) << "tried to find CommonNumeric type of an empty set"; + for (const auto& descr : descrs) { auto id = descr.type->id(); if (!is_floating(id) && !is_integer(id)) { @@ -220,27 +222,40 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { return nullptr; } } + for (const auto& descr : descrs) { if (descr.type->id() == Type::DOUBLE) return float64(); } + for (const auto& descr : descrs) { if (descr.type->id() == Type::FLOAT) return float32(); } - bool at_least_one_signed = false; - int max_width = 0; + int max_width_signed = 0, max_width_unsigned = 0; for (const auto& descr : descrs) { auto id = descr.type->id(); - at_least_one_signed |= is_signed_integer(id); - max_width = std::max(bit_width(id), max_width); + auto max_width = is_signed_integer(id) ? &max_width_signed : &max_width_unsigned; + *max_width = std::max(bit_width(id), *max_width); + } + + if (max_width_signed == 0) { + if (max_width_unsigned >= 64) return uint64(); + if (max_width_unsigned == 32) return uint32(); + if (max_width_unsigned == 16) return uint16(); + DCHECK_EQ(max_width_unsigned, 8); + return int8(); + } + + if (max_width_signed <= max_width_unsigned) { + max_width_signed = BitUtil::NextPower2(max_width_unsigned + 1); } - if (max_width == 64) return at_least_one_signed ? int64() : uint64(); - if (max_width == 32) return at_least_one_signed ? int32() : uint32(); - if (max_width == 16) return at_least_one_signed ? int16() : uint16(); - DCHECK_EQ(max_width, 8); - return at_least_one_signed ? int8() : uint8(); + if (max_width_signed >= 64) return int64(); + if (max_width_signed == 32) return int32(); + if (max_width_signed == 16) return int16(); + DCHECK_EQ(max_width_signed, 8); + return int8(); } std::shared_ptr CommonTimestamp(const std::vector& descrs) { diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index de955bbf1e2..4d4f14e1154 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -643,22 +643,28 @@ TEST(TestBinaryArithmetic, DispatchBest) { name += suffix; CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); - CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); - CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()}); CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()}); - CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()}); + CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()}); - CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()}); + CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()}); + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, {float64(), float64()}); - CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, {float64(), float64()}); } @@ -672,12 +678,12 @@ TEST(TestBinaryArithmetic, AddWithImplicitCasts) { CheckScalarBinary("add", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), - ArrayFromJSON(int32(), "[-13, 4, 21, null]")); + ArrayFromJSON(int64(), "[-13, 4, 21, null]")); CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int32()), "[8, 6, 3, null, 2]"), ArrayFromJSON(uint32(), "[3, 4, 5, 7, 0]"), - ArrayFromJSON(int32(), "[11, 10, 8, null, 2]")); + ArrayFromJSON(int64(), "[11, 10, 8, null, 2]")); CheckScalarBinary("add", ArrayFromJSON(int32(), "[0, 1, 2, null]"), std::make_shared(4), @@ -685,7 +691,22 @@ TEST(TestBinaryArithmetic, AddWithImplicitCasts) { CheckScalarBinary("add", ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), - ArrayFromJSON(int32(), "[3, 5, 7, null]")); + ArrayFromJSON(int64(), "[3, 5, 7, null]")); +} + +TEST(TestBinaryArithmetic, AddWithImplicitCastsUint64EdgeCase) { + // int64 is as wide as we can promote + CheckDispatchBest("add", {int8(), uint64()}, {int64(), int64()}); + + // this works sometimes + CheckScalarBinary("add", ArrayFromJSON(int8(), "[-1]"), ArrayFromJSON(uint64(), "[0]"), + ArrayFromJSON(int64(), "[-1]")); + + // ... but it can result in impossible implicit casts in the presence of uint64, since + // some uint64 values cannot be cast to int64: + ASSERT_RAISES(Invalid, + CallFunction("add", {ArrayFromJSON(int64(), "[-1]"), + ArrayFromJSON(uint64(), "[18446744073709551615]")})); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index d54a5cbf8ee..7b0906395d7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -455,27 +455,31 @@ TEST(TestCompareKernel, DispatchBest) { for (std::string name : {"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) { CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); - CheckDispatchBest(name, {int32(), null()}, {int32(), int32()}); - CheckDispatchBest(name, {null(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int8()}, {int32(), int32()}); CheckDispatchBest(name, {int32(), int16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int32()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), int64()}, {int64(), int64()}); - CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {int32(), uint8()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint16()}, {int32(), int32()}); + CheckDispatchBest(name, {int32(), uint32()}, {int64(), int64()}); + CheckDispatchBest(name, {int32(), uint64()}, {int64(), int64()}); - CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); + CheckDispatchBest(name, {uint8(), uint8()}, {uint8(), uint8()}); + CheckDispatchBest(name, {uint8(), uint16()}, {uint16(), uint16()}); + CheckDispatchBest(name, {int32(), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {float32(), int64()}, {float32(), float32()}); CheckDispatchBest(name, {float64(), int32()}, {float64(), float64()}); CheckDispatchBest(name, {dictionary(int8(), float64()), float64()}, {float64(), float64()}); - CheckDispatchBest(name, {dictionary(int8(), float64()), int16()}, {float64(), float64()}); - CheckDispatchBest(name, {dictionary(int8(), utf8()), utf8()}, {utf8(), utf8()}); - CheckDispatchBest(name, {timestamp(TimeUnit::MICRO), date64()}, {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}); @@ -496,6 +500,10 @@ TEST(TestCompareKernel, GreaterWithImplicitCasts) { ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), ArrayFromJSON(boolean(), "[false, false, true, null]")); + CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-16, 0, 16, null]"), + ArrayFromJSON(uint8(), "[255, 254, 1, 0]"), + ArrayFromJSON(boolean(), "[false, false, true, null]")); + CheckScalarBinary("greater", ArrayFromJSON(dictionary(int32(), int32()), "[0, 1, 2, null]"), ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), @@ -512,11 +520,27 @@ TEST(TestCompareKernel, GreaterWithImplicitCasts) { ArrayFromJSON(boolean(), "[false, true, false]")); CheckScalarBinary("greater", - ArrayFromJSON(dictionary(int32(), int8()), "[0, 1, 2, null]"), + ArrayFromJSON(dictionary(int32(), int8()), "[3, -3, -28, null]"), ArrayFromJSON(uint32(), "[3, 4, 5, 7]"), ArrayFromJSON(boolean(), "[false, false, false, null]")); } +TEST(TestCompareKernel, GreaterWithImplicitCastsUint64EdgeCase) { + // int64 is as wide as we can promote + CheckDispatchBest("greater", {int8(), uint64()}, {int64(), int64()}); + + // this works sometimes + CheckScalarBinary("greater", ArrayFromJSON(int8(), "[-1]"), + ArrayFromJSON(uint64(), "[0]"), ArrayFromJSON(boolean(), "[false]")); + + // ... but it can result in impossible implicit casts in the presence of uint64, since + // some uint64 values cannot be cast to int64: + ASSERT_RAISES( + Invalid, + CallFunction("greater", {ArrayFromJSON(int64(), "[-1]"), + ArrayFromJSON(uint64(), "[18446744073709551615]")})); +} + class TestStringCompareKernel : public ::testing::Test {}; TEST_F(TestStringCompareKernel, SimpleCompareArrayScalar) { diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 2fc817120d9..73e900351fb 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -186,9 +186,9 @@ void CheckDispatchBest(std::string func_name, std::vector original_v function->DispatchExact(expected_equivalent_values)); EXPECT_EQ(actual_kernel, expected_kernel) - << "DispatchBest" << ValueDescr::ToString(original_values) << " => " + << " DispatchBest" << ValueDescr::ToString(original_values) << " => " << actual_kernel->signature->ToString() << "\n" - << "DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => " + << " DispatchExact" << ValueDescr::ToString(expected_equivalent_values) << " => " << expected_kernel->signature->ToString(); } diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 79e25294ebb..c344ddb8648 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -104,6 +104,57 @@ exact semantics of the function:: .. seealso:: :doc:`Compute API reference ` +Implicit casts +============== + +Functions may require conversion of their arguments before execution if a +kernel does not match the argument types precisely. For example comparison +of dictionary encoded arrays is not directly supported by any kernel, but an +implicit cast can be made allowing comparison against the decoded array. + +Each function may define implicit cast behaviour as appropriate. For example +comparison and arithmetic kernels require identically typed arguments, and +support execution against differing numeric types by promoting their arguments +to numeric type which can accommodate any value from either input. + +.. _common-numeric-type: + +Common numeric type +~~~~~~~~~~~~~~~~~~~ + +The common numeric type of a set of input numeric types is the smallest numeric +type which can accommodate any value of any input. If any input is a floating +point type the common numeric type is the widest floating point type among the +inputs. Otherwise the common numeric type is integral and is signed if any input +is signed. For example: + ++-------------------+----------------------+------------------------------------------------+ +| Input types | Common numeric type | Notes | ++===================+======================+================================================+ +| int32, int32 | int32 | | ++-------------------+----------------------+------------------------------------------------+ +| int16, int32 | int32 | Max width is 32, promote LHS to int32 | ++-------------------+----------------------+------------------------------------------------+ +| uint16, int32 | int32 | One input signed, override unsigned | ++-------------------+----------------------+------------------------------------------------+ +| uint32, int32 | int64 | Widen to accommodate range of uint32 | ++-------------------+----------------------+------------------------------------------------+ +| uint16, uint32 | uint32 | All inputs unsigned, maintain unsigned | ++-------------------+----------------------+------------------------------------------------+ +| int16, uint32 | int64 | | ++-------------------+----------------------+------------------------------------------------+ +| uint64, int16 | int64 | NB: int64 cannot accommodate all uint64 values | ++-------------------+----------------------+------------------------------------------------+ +| float32, int32 | float32 | Promote RHS to float32 | ++-------------------+----------------------+------------------------------------------------+ +| float32, float64 | float64 | | ++-------------------+----------------------+------------------------------------------------+ +| float32, int64 | float32 | int64 is wider, still promotes to float32 | ++-------------------+----------------------+------------------------------------------------+ + +In particulary, note that comparing a `uint64` column to an `int16` column may +emit an error if one of the LHS' values cannot be expressed as the common type +`int64` (for example, `2 ** 63`). .. _compute-function-list: @@ -750,40 +801,3 @@ Structural transforms in the list array is appended to the output. Nulls in the parent list array are discarded. -.. _common-numeric-type: - -Common numeric type -~~~~~~~~~~~~~~~~~~~ - -The common numeric type of a set of input numeric types is the smallest numeric -type which can accommodate any value of any input. If any input is a floating -point type the common numeric type is the widest floating point type among the -inputs. Otherwise the common numeric type is integral, is signed if any input -is signed, and its width is the maximum width of any input. For example: - -+-------------------+----------------------+-------------------------------------------+ -| Input types | Common numeric type | Notes | -+===================+======================+===========================================+ -| int32, int32 | int32 | | -+-------------------+----------------------+-------------------------------------------+ -| uint32, int32 | int32 | One input signed, override unsigned | -+-------------------+----------------------+-------------------------------------------+ -| int16, int32 | int32 | Max width is 32, promote LHS to int32 | -+-------------------+----------------------+-------------------------------------------+ -| uint16, uint32 | uint32 | All inputs unsigned, maintain unsigned | -+-------------------+----------------------+-------------------------------------------+ -| uint16, int32 | int32 | One input signed, override unsigned | -+-------------------+----------------------+-------------------------------------------+ -| int16, uint32 | int32 | | -+-------------------+----------------------+-------------------------------------------+ -| float32, int32 | float32 | Promote RHS to float32 | -+-------------------+----------------------+-------------------------------------------+ -| float32, float64 | float64 | | -+-------------------+----------------------+-------------------------------------------+ -| float32, int64 | float32 | int64 is wider, still promotes to float32 | -+-------------------+----------------------+-------------------------------------------+ - -In particulary, note that comparing a `uint32` column to an `int32` column may -emit an error if one of the LHS' values cannot be expressed as the common type -`int32` (for example, `2 ** 31`). This tradeoff is made to keep the results of -arithmetic operations narrow. From 66aa801acb968324843ad078429bb887b083844e Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 8 Feb 2021 17:02:48 -0500 Subject: [PATCH 20/22] add test case for stripping casts from uint32 to signed integer types --- cpp/src/arrow/dataset/expression_test.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/dataset/expression_test.cc b/cpp/src/arrow/dataset/expression_test.cc index 9fc5774f3a6..ae62283b1d7 100644 --- a/cpp/src/arrow/dataset/expression_test.cc +++ b/cpp/src/arrow/dataset/expression_test.cc @@ -128,6 +128,11 @@ TEST(ExpressionUtils, StripOrderPreservingCasts) { Expect(cast(field_ref("u32"), uint16()), no_change); Expect(cast(field_ref("u32"), uint8()), no_change); + Expect(cast(field_ref("u32"), int64()), field_ref("u32")); + Expect(cast(field_ref("u32"), int32()), field_ref("u32")); + Expect(cast(field_ref("u32"), int16()), no_change); + Expect(cast(field_ref("u32"), int8()), no_change); + // Casting float to int can affect ordering. // For example, let // a = 3.5, b = 3.0, assert(a > b) @@ -411,7 +416,7 @@ TEST(Expression, BindWithImplicitCasts) { cmp(cast(field_ref("i32"), int64()), field_ref("i64"))); ExpectBindsTo(cmp(field_ref("i8"), field_ref("u32")), - cmp(cast(field_ref("i8"), int32()), cast(field_ref("u32"), int32()))); + cmp(cast(field_ref("i8"), int64()), cast(field_ref("u32"), int64()))); // cast dictionary to value type ExpectBindsTo(cmp(field_ref("dict_str"), field_ref("str")), From 62a6b5ef4dec9a05254b511692c664cf490a6fe8 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 10 Feb 2021 15:19:30 +0100 Subject: [PATCH 21/22] Nits + fix compile error (hopefully) --- cpp/src/arrow/compute/cast.h | 2 +- cpp/src/arrow/compute/function.cc | 4 ++-- cpp/src/arrow/compute/kernels/codegen_internal.cc | 2 +- docs/source/cpp/compute.rst | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index c0405fb93e9..818f2ef9182 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -161,7 +161,7 @@ Result Cast(const Datum& value, std::shared_ptr to_type, /// \param[in] ctx the function execution context, optional /// \return the resulting datums /// -/// \since 1.0.0 +/// \since 4.0.0 /// \note API not yet finalized ARROW_EXPORT Result> Cast(std::vector values, std::vector descrs, diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 776c14d5ca0..70d7d998e9c 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -38,8 +38,8 @@ static const FunctionDoc kEmptyFunctionDoc{}; const FunctionDoc& FunctionDoc::Empty() { return kEmptyFunctionDoc; } -Status CheckArityImpl(const Function* function, int passed_num_args, - const char* passed_num_args_label) { +static Status CheckArityImpl(const Function* function, int passed_num_args, + const char* passed_num_args_label) { if (function->arity().is_varargs && passed_num_args < function->arity().num_args) { return Status::Invalid("VarArgs function ", function->name(), " needs at least ", function->arity().num_args, " arguments but ", diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 45211600727..b321ff3fc8b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -248,7 +248,7 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { } if (max_width_signed <= max_width_unsigned) { - max_width_signed = BitUtil::NextPower2(max_width_unsigned + 1); + max_width_signed = static_cast(BitUtil::NextPower2(max_width_unsigned + 1)); } if (max_width_signed >= 64) return int64(); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index c344ddb8648..4101c36ef8f 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -143,7 +143,7 @@ is signed. For example: +-------------------+----------------------+------------------------------------------------+ | int16, uint32 | int64 | | +-------------------+----------------------+------------------------------------------------+ -| uint64, int16 | int64 | NB: int64 cannot accommodate all uint64 values | +| uint64, int16 | int64 | int64 cannot accommodate all uint64 values | +-------------------+----------------------+------------------------------------------------+ | float32, int32 | float32 | Promote RHS to float32 | +-------------------+----------------------+------------------------------------------------+ @@ -152,9 +152,9 @@ is signed. For example: | float32, int64 | float32 | int64 is wider, still promotes to float32 | +-------------------+----------------------+------------------------------------------------+ -In particulary, note that comparing a `uint64` column to an `int16` column may -emit an error if one of the LHS' values cannot be expressed as the common type -`int64` (for example, `2 ** 63`). +In particulary, note that comparing a ``uint64`` column to an ``int16`` column +may emit an error if one of the ``uint64`` values cannot be expressed as the +common type ``int64`` (for example, ``2 ** 63``). .. _compute-function-list: From 6ded65f0b7beeaeaa89403d6651872a74633f832 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 10 Feb 2021 14:52:12 -0500 Subject: [PATCH 22/22] inline InitKernelState, ensure KernelInitArgs::inputs is bound to a non temporary for msvc --- cpp/src/arrow/dataset/expression.cc | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/dataset/expression.cc b/cpp/src/arrow/dataset/expression.cc index 69fefc1b4d0..56339430ee9 100644 --- a/cpp/src/arrow/dataset/expression.cc +++ b/cpp/src/arrow/dataset/expression.cc @@ -379,19 +379,6 @@ bool Expression::IsSatisfiable() const { namespace { -Result> InitKernelState( - const Expression::Call& call, compute::ExecContext* exec_context) { - if (!call.kernel->init) return nullptr; - - compute::KernelContext kernel_context(exec_context); - compute::KernelInitArgs kernel_init_args{call.kernel, GetDescriptors(call.arguments), - call.options.get()}; - - auto kernel_state = call.kernel->init(&kernel_context, kernel_init_args); - RETURN_NOT_OK(kernel_context.status()); - return std::move(kernel_state); -} - // Produce a bound Expression from unbound Call and bound arguments. Result BindNonRecursive(Expression::Call call, bool insert_implicit_casts, compute::ExecContext* exec_context) { @@ -436,8 +423,13 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ } compute::KernelContext kernel_context(exec_context); - ARROW_ASSIGN_OR_RAISE(call.kernel_state, InitKernelState(call, exec_context)); - kernel_context.SetState(call.kernel_state.get()); + if (call.kernel->init) { + call.kernel_state = + call.kernel->init(&kernel_context, {call.kernel, descrs, call.options.get()}); + + RETURN_NOT_OK(kernel_context.status()); + kernel_context.SetState(call.kernel_state.get()); + } ARROW_ASSIGN_OR_RAISE( call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs));