From 76cfec0f834c079fa858238835988c5e9c373891 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Mon, 17 May 2021 02:27:02 +0000 Subject: [PATCH 1/4] ARROW-12074: [C++][Compute] Add scalar arithmetic kernels for decimal --- .../arrow/compute/kernels/codegen_internal.h | 46 +++ .../compute/kernels/scalar_arithmetic.cc | 208 +++++++++++- .../compute/kernels/scalar_arithmetic_test.cc | 301 ++++++++++++++++++ cpp/src/arrow/compute/kernels/test_util.cc | 16 + cpp/src/arrow/compute/kernels/test_util.h | 3 + docs/source/cpp/compute.rst | 34 +- 6 files changed, 593 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 891f90a97d4..8ff3b104191 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -149,6 +149,8 @@ struct GetViewType { static T LogicalValue(PhysicalType value) { return Decimal128(reinterpret_cast(value.data())); } + + static T LogicalValue(T value) { return value; } }; template <> @@ -159,6 +161,8 @@ struct GetViewType { static T LogicalValue(PhysicalType value) { return Decimal256(reinterpret_cast(value.data())); } + + static T LogicalValue(T value) { return value; } }; template @@ -243,6 +247,19 @@ struct ArrayIterator> { } }; +template +struct ArrayIterator> { + using T = typename TypeTraits::ScalarType::ValueType; + using Bytes = std::array; + const Bytes* values; + + explicit ArrayIterator(const ArrayData& data) : values(data.GetValues(1)) { + DCHECK_EQ(sizeof(T) * 8, bit_width(data.type->id())); + } + + T operator()() { return T{values++->data()}; } +}; + // Iterator over various output array types, taking a GetOutputType template @@ -262,6 +279,21 @@ struct OutputArrayWriter> { void WriteNull() { *values++ = T{}; } }; +template +struct OutputArrayWriter> { + using T = typename TypeTraits::ScalarType::ValueType; + using Bytes = std::array; + Bytes* values; + + explicit OutputArrayWriter(ArrayData* 
data) : values(data->GetMutableValues(1)) { + DCHECK_EQ(sizeof(T) * 8, bit_width(data->type->id())); + } + + void Write(T value) { value.ToBytes(values++->data()); } + + void WriteNull() { T{}.ToBytes(values++->data()); } +}; + // (Un)box Scalar to / from C++ value template @@ -538,6 +570,20 @@ struct OutputAdapter> { } }; +template +struct OutputAdapter> { + using T = typename TypeTraits::ScalarType::ValueType; + template + static Status Write(KernelContext*, Datum* out, Generator&& generator) { + ArrayData* out_arr = out->mutable_array(); + T* out_data = out_arr->GetMutableValues(1); + for (int64_t i = 0; i < out_arr->length; ++i) { + *out_data++ = generator(); + } + return Status::OK(); + } +}; + // A kernel exec generator for unary functions that addresses both array and // scalar inputs and dispatches input iteration and output writing to other // templates diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 743d2e3fc0e..c437029d7e7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +#include #include #include #include "arrow/compute/kernels/common.h" #include "arrow/type_traits.h" +#include "arrow/util/decimal.h" #include "arrow/util/int_util_internal.h" #include "arrow/util/macros.h" @@ -62,6 +64,11 @@ using enable_if_integer = template using enable_if_floating_point = enable_if_t::value, T>; +template +using enable_if_decimal = + enable_if_t::value || std::is_same::value, + T>; + template ::type> constexpr Unsigned to_unsigned(T signed_) { return static_cast(signed_); @@ -82,6 +89,12 @@ struct AbsoluteValue { static constexpr enable_if_signed_integer Call(KernelContext*, T arg, Status* st) { return (arg < 0) ? 
arrow::internal::SafeSignedNegate(arg) : arg; } + + template + static enable_if_decimal Call(KernelContext*, T arg, Status* st) { + *st = Status::NotImplemented("NYI"); + return T(); + } }; struct AbsoluteValueChecked { @@ -106,6 +119,12 @@ struct AbsoluteValueChecked { static_assert(std::is_same::value, ""); return std::fabs(arg); } + + template + static enable_if_decimal Call(KernelContext*, Arg arg, Status* st) { + *st = Status::NotImplemented("NYI"); + return T(); + } }; struct Add { @@ -126,11 +145,16 @@ struct Add { Status*) { return arrow::internal::SafeSignedAdd(left, right); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + right; + } }; struct AddChecked { template - enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + static enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) { @@ -140,10 +164,16 @@ struct AddChecked { } template - enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { static_assert(std::is_same::value && std::is_same::value, ""); return left + right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + right; + } }; struct Subtract { @@ -164,11 +194,16 @@ struct Subtract { Status*) { return arrow::internal::SafeSignedSubtract(left, right); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + (-right); + } }; struct SubtractChecked { template - enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + static enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { static_assert(std::is_same::value && 
std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { @@ -178,10 +213,16 @@ struct SubtractChecked { } template - enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { static_assert(std::is_same::value && std::is_same::value, ""); return left - right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + (-right); + } }; struct Multiply { @@ -224,11 +265,16 @@ struct Multiply { static constexpr uint16_t Call(KernelContext*, uint16_t left, uint16_t right, Status*) { return static_cast(left) * static_cast(right); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left * right; + } }; struct MultiplyChecked { template - enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + static enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) { @@ -238,10 +284,16 @@ struct MultiplyChecked { } template - enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { static_assert(std::is_same::value && std::is_same::value, ""); return left * right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left * right; + } }; struct Divide { @@ -263,6 +315,16 @@ struct Divide { } return result; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + if (right == Arg1()) { + *st = Status::Invalid("Divide by zero"); + return T(); + } else { + return left / right; + } + } }; struct DivideChecked { @@ -290,6 
+352,16 @@ struct DivideChecked { } return left / right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + if (right == Arg1()) { + *st = Status::Invalid("Divide by zero"); + return T(); + } else { + return left / right; + } + } }; struct Negate { @@ -304,9 +376,15 @@ struct Negate { } template - static constexpr enable_if_signed_integer Call(KernelContext*, Arg arg, Status* st) { + static constexpr enable_if_signed_integer Call(KernelContext*, Arg arg, Status*) { return arrow::internal::SafeSignedNegate(arg); } + + template + static enable_if_decimal Call(KernelContext*, Arg arg, Status* st) { + *st = Status::NotImplemented("NYI"); + return T(); + } }; struct NegateChecked { @@ -333,6 +411,12 @@ struct NegateChecked { static_assert(std::is_same::value, ""); return -arg; } + + template + static enable_if_decimal Call(KernelContext*, Arg arg, Status* st) { + *st = Status::NotImplemented("NYI"); + return T(); + } }; struct Power { @@ -361,6 +445,12 @@ struct Power { static enable_if_floating_point Call(KernelContext*, T base, T exp, Status*) { return std::pow(base, exp); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 base, Arg1 exp, Status* st) { + *st = Status::NotImplemented("NYI"); + return T(); + } }; struct PowerChecked { @@ -395,6 +485,12 @@ struct PowerChecked { static_assert(std::is_same::value && std::is_same::value, ""); return std::pow(base, exp); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 base, Arg1 exp, Status* st) { + *st = Status::NotImplemented("NYI"); + return T(); + } }; // Generate a kernel given an arithmetic functor @@ -428,12 +524,69 @@ ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) { } } +// calculate output precision/scale and args rescaling per operation type +Result> GetDecimalBinaryOutput( + const std::string& op, const std::vector& values, + std::vector>* replaced = nullptr) { + const auto& left_type = 
checked_pointer_cast(values[0].type); + const auto& right_type = checked_pointer_cast(values[1].type); + + const int32_t p1 = left_type->precision(), s1 = left_type->scale(); + const int32_t p2 = right_type->precision(), s2 = right_type->scale(); + if (s1 < 0 || s2 < 0) { + return Status::NotImplemented("Decimals with negative scales not supported"); + } + + int32_t out_prec, out_scale; + int32_t left_scaleup = 0, right_scaleup = 0; + + // decimal upscaling behaviour references amazon redshift + // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html + if (op.find("add") == 0 || op.find("subtract") == 0) { + out_scale = std::max(s1, s2); + out_prec = std::max(p1 - s1, p2 - s2) + 1 + out_scale; + left_scaleup = out_scale - s1; + right_scaleup = out_scale - s2; + } else if (op.find("multiply") == 0) { + out_scale = s1 + s2; + out_prec = p1 + p2 + 1; + } else if (op.find("divide") == 0) { + out_scale = std::max(4, s1 + p2 - s2 + 1); + out_prec = p1 - s1 + s2 + out_scale; // >= p1 + p2 + 1 + left_scaleup = out_prec - p1; + } else { + return Status::Invalid("Invalid decimal operation: ", op); + } + + const auto id = left_type->id(); + auto make = [id](int32_t precision, int32_t scale) { + if (id == Type::DECIMAL128) { + return Decimal128Type::Make(precision, scale); + } else { + return Decimal256Type::Make(precision, scale); + } + }; + + if (replaced) { + replaced->resize(2); + ARROW_ASSIGN_OR_RAISE((*replaced)[0], make(p1 + left_scaleup, s1 + left_scaleup)); + ARROW_ASSIGN_OR_RAISE((*replaced)[1], make(p2 + right_scaleup, s2 + right_scaleup)); + } + + return make(out_prec, out_scale); +} + struct ArithmeticFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); + const auto type_id = (*values)[0].type->id(); + if (type_id == Type::DECIMAL128 || type_id == Type::DECIMAL256) { + return DispatchDecimal(values); + } + using 
arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -451,8 +604,45 @@ struct ArithmeticFunction : ScalarFunction { if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); } + + Result DispatchDecimal(std::vector* values) const { + if (values->size() == 2) { + std::vector> replaced; + RETURN_NOT_OK(GetDecimalBinaryOutput(name(), *values, &replaced)); + (*values)[0].type = std::move(replaced[0]); + (*values)[1].type = std::move(replaced[1]); + } + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +// resolve decimal operation output type +struct DecimalBinaryOutputResolver { + std::string func_name; + + DecimalBinaryOutputResolver(std::string func_name) : func_name(std::move(func_name)) {} + + Result operator()(KernelContext*, const std::vector& args) { + ARROW_ASSIGN_OR_RAISE(auto out_type, GetDecimalBinaryOutput(func_name, args)); + return ValueDescr(std::move(out_type)); + } }; +template +void AddDecimalBinaryKernels(const std::string& name, + std::shared_ptr* func) { + auto out_type = OutputType(DecimalBinaryOutputResolver(name)); + auto in_type128 = InputType(Type::DECIMAL128); + auto in_type256 = InputType(Type::DECIMAL256); + auto exec128 = ScalarBinaryNotNullEqualTypes::Exec; + auto exec256 = ScalarBinaryNotNullEqualTypes::Exec; + DCHECK_OK((*func)->AddKernel({in_type128, in_type128}, out_type, exec128)); + DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256)); +} + template std::shared_ptr MakeArithmeticFunction(std::string name, const FunctionDoc* doc) { @@ -461,6 +651,8 @@ std::shared_ptr MakeArithmeticFunction(std::string name, auto exec = ArithmeticExecFromOp(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); } + AddDecimalBinaryKernels(name, &func); + 
return func; } @@ -474,6 +666,8 @@ std::shared_ptr MakeArithmeticFunctionNotNull(std::string name, auto exec = ArithmeticExecFromOp(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); } + AddDecimalBinaryKernels(name, &func); + return func; } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index c4bfac459dc..2475f785891 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1148,5 +1148,306 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) { } } +class TestBinaryArithmeticDecimal : public TestBase { + protected: + struct Arg { + std::shared_ptr type; + std::string value; + }; + + std::shared_ptr GetOutType(const std::string& op, + const std::shared_ptr& left_type, + const std::shared_ptr& right_type) { + auto left_decimal_type = std::static_pointer_cast(left_type); + auto right_decimal_type = std::static_pointer_cast(right_type); + + const int32_t p1 = left_decimal_type->precision(), s1 = left_decimal_type->scale(); + const int32_t p2 = right_decimal_type->precision(), s2 = right_decimal_type->scale(); + + // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html + int32_t precision, scale; + if (op == "add" || op == "subtract") { + scale = std::max(s1, s2); + precision = std::max(p1 - s1, p2 - s2) + 1 + scale; + } else if (op == "multiply") { + scale = s1 + s2; + precision = p1 + p2 + 1; + } else if (op == "divide") { + scale = std::max(4, s1 + p2 - s2 + 1); + precision = p1 - s1 + s2 + scale; + } else { + ABORT_NOT_OK(Status::Invalid("invalid binary operator: ", op)); + } + + std::shared_ptr type; + if (left_type->id() == Type::DECIMAL128) { + ASSIGN_OR_ABORT(type, Decimal128Type::Make(precision, scale)); + } else { + ASSIGN_OR_ABORT(type, Decimal256Type::Make(precision, scale)); + } + return type; + } + + std::shared_ptr MakeScalar(const std::shared_ptr& type, + const 
std::string& str) { + std::shared_ptr scalar; + if (type->id() == Type::DECIMAL128) { + Decimal128 value; + int32_t dummy; + ABORT_NOT_OK(Decimal128::FromString(str, &value, &dummy)); + ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value)); + } else { + Decimal256 value; + int32_t dummy; + ABORT_NOT_OK(Decimal256::FromString(str, &value, &dummy)); + ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value)); + } + return scalar; + } + + Datum ToDatum(const std::shared_ptr& type, const std::string& value) { + if (value.find("[") == std::string::npos) { + return Datum(MakeScalar(type, value)); + } else { + return Datum(ArrayFromJSON(type, value)); + } + } + + void Assert(const std::string& op, const Arg& left, const Arg& right, + const std::string& expected) { + const Datum arg0 = ToDatum(left.type, left.value); + const Datum arg1 = ToDatum(right.type, right.value); + + auto out_type = GetOutType(op, left.type, right.type); + CheckScalar(op, {arg0, arg1}, ToDatum(out_type, expected), &options_); + + // commutative operations + if (op == "add" || op == "multiply") { + CheckScalar(op, {arg1, arg0}, ToDatum(out_type, expected), &options_); + } + } + + void AssertFail(const std::string& op, const Arg& left, const Arg& right) { + const Datum arg0 = ToDatum(left.type, left.value); + const Datum arg1 = ToDatum(right.type, right.value); + + ASSERT_NOT_OK(CallFunction(op, {arg0, arg1}, &options_)); + if (op == "add" || op == "multiply") { + ASSERT_NOT_OK(CallFunction(op, {arg1, arg0}, &options_)); + } + } + + ArithmeticOptions options_ = ArithmeticOptions(); +}; + +// reference result from bc (precsion=100, scale=40) +TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { + Arg left, right; + std::string added, subtracted; + + // array array, decimal128 + left = { + decimal128(30, 3), + R"([ + "1.000", + "-123456789012345678901234567.890", + "98765432109876543210.987", + "-999999999999999999999999999.999" + ])", + }; + right = { + decimal128(20, 9), + R"([ + "-1.000000000", + 
"12345678901.234567890", + "98765.432101234", + "-99999999999.999999999" + ])", + }; + added = R"([ + "0.000000000", + "-123456789012345666555555666.655432110", + "98765432109876641976.419101234", + "-1000000000000000099999999999.998999999" + ])"; + subtracted = R"([ + "2.000000000", + "-123456789012345691246913469.124567890", + "98765432109876444445.554898766", + "-999999999999999899999999999.999000001" + ])"; + this->Assert("add", left, right, added); + this->Assert("subtract", left, right, subtracted); + + // array array, decimal256 + left = { + decimal256(30, 20), + R"([ + "-1.00000000000000000001", + "1234567890.12345678900000000000", + "-9876543210.09876543210987654321", + "9999999999.99999999999999999999" + ])", + }; + right = { + decimal256(30, 10), + R"([ + "1.0000000000", + "-1234567890.1234567890", + "6789.5432101234", + "99999999999999999999.9999999999" + ])", + }; + added = R"([ + "-0.00000000000000000001", + "0.00000000000000000000", + "-9876536420.55555530870987654321", + "100000000009999999999.99999999989999999999" + ])"; + subtracted = R"([ + "-2.00000000000000000001", + "2469135780.24691357800000000000", + "-9876549999.64197555550987654321", + "-99999999989999999999.99999999990000000001" + ])"; + this->Assert("add", left, right, added); + this->Assert("subtract", left, right, subtracted); + + // scalar array + left = {decimal128(6, 1), "12345.6"}; + right = {decimal128(10, 3), R"(["1.234", "1234.000", "-9876.543", "666.888"])"}; + added = R"(["12346.834", "13579.600", "2469.057", "13012.488"])"; + subtracted = R"(["12344.366", "11111.600", "22222.143", "11678.712"])"; + this->Assert("add", left, right, added); + this->Assert("subtract", left, right, subtracted); + // right - left + subtracted = R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])"; + this->Assert("subtract", right, left, subtracted); + + // scalar scalar + left = {decimal256(3, 0), "666"}; + right = {decimal256(3, 0), "888"}; + this->Assert("add", left, right, "1554"); + 
this->Assert("subtract", left, right, "-222"); + + // failed case: result *maybe* overflow + left = {decimal128(21, 20), "0.12345678901234567890"}; + right = {decimal128(21, 1), "1.0"}; + this->AssertFail("add", left, right); + this->AssertFail("subtract", left, right); + + left = {decimal256(75, 0), "0"}; + right = {decimal256(2, 1), "0.0"}; + this->AssertFail("add", left, right); + this->AssertFail("subtract", left, right); +} + +TEST_F(TestBinaryArithmeticDecimal, Multiply) { + Arg left, right; + std::string expected; + + // array array + left = { + decimal128(20, 10), + R"([ + "1234567890.1234567890", + "-0.0000000001", + "-9999999999.9999999999" + ])", + }; + right = { + decimal128(13, 3), + R"([ + "1234567890.123", + "0.001", + "-9999999999.999" + ])", + }; + expected = R"([ + "1524157875323319737.9870903950470", + "-0.0000000000001", + "99999999999989999999.0000000000001" + ])"; + this->Assert("multiply", left, right, expected); + + left = { + decimal256(30, 3), + R"([ + "123456789012345678901234567.890", + "0.000" + ])", + }; + right = { + decimal256(20, 9), + R"([ + "-12345678901.234567890", + "99999999999.999999999" + ])", + }; + expected = R"([ + "-1524157875323883675034293577501905199.875019052100", + "0.000000000000" + ])"; + this->Assert("multiply", left, right, expected); + + // scalar array + left = {decimal128(3, 2), "3.14"}; + right = {decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"}; + expected = R"(["3.14", "6.28", "9.42", "12.56", "15.70"])"; + this->Assert("multiply", left, right, expected); + + // scalar scalar + left = {decimal128(1, 0), "1"}; + right = {decimal128(1, 0), "1"}; + this->Assert("multiply", left, right, "1"); + + // failed case: result *maybe* overflow + left = {decimal128(20, 0), "1"}; + right = {decimal128(18, 1), "1.0"}; + this->AssertFail("multiply", left, right); +} + +TEST_F(TestBinaryArithmeticDecimal, Divide) { + Arg left, right; + std::string expected; + + // array array + left = {decimal128(13, 3), 
R"(["1234567890.123", "0.001"])"}; + right = {decimal128(3, 0), R"(["-987", "999"])"}; + // scale = 7 + expected = R"(["-1250828.6627386", "0.0000010"])"; + this->Assert("divide", left, right, expected); + + left = {decimal256(20, 10), R"(["1234567890.1234567890", "9999999999.9999999999"])"}; + right = {decimal256(13, 3), R"(["1234567890.123", "0.001"])"}; + // scale = 21 + expected = R"(["1.000000000000369999093", "9999999999999.999999900000000000000"])"; + this->Assert("divide", left, right, expected); + + // scalar array + left = {decimal128(1, 0), "1"}; + right = {decimal128(1, 0), R"(["1", "2", "3", "4"])"}; + // scale = 4 + expected = R"(["1.0000", "0.5000", "0.3333", "0.2500"])"; + this->Assert("divide", left, right, expected); + // right / left + expected = R"(["1.0000", "2.0000", "3.0000", "4.0000"])"; + this->Assert("divide", right, left, expected); + + // scalar scalar + left = {decimal256(6, 5), "2.71828"}; + right = {decimal256(6, 5), "3.14159"}; + // scale = 7 + this->Assert("divide", left, right, "0.8652561"); + + // failed case: result *maybe* overflow + left = {decimal128(20, 20), "0.12345678901234567890"}; + right = {decimal128(20, 0), "12345678901234567890"}; + this->AssertFail("divide", left, right); + + // failed case: divide by 0 + this->AssertFail("divide", {decimal256(1, 0), "0"}, {decimal256(1, 0), "0"}); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 18257973150..cf37b33c006 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -169,6 +169,22 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, } } +void CheckScalar(std::string func_name, const std::vector& inputs, + const Datum& expected, const FunctionOptions* options) { + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); + ASSERT_EQ(out.kind(), expected.kind()); + if (out.kind() == 
Datum::ARRAY) { + std::shared_ptr actual = out.make_array(); + ASSERT_OK(actual->ValidateFull()); + AssertArraysEqual(*expected.make_array(), *actual, /*verbose=*/true); + } else if (out.kind() == Datum::SCALAR) { + std::shared_ptr actual = out.scalar(); + AssertScalarsEqual(*expected.scalar(), *actual, /*verbose=*/true); + } else { + ASSERT_EQ(out, expected); + } +} + void CheckScalarUnary(std::string func_name, std::shared_ptr input, std::shared_ptr expected, const FunctionOptions* options) { ArrayVector input_vector = {input}; diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 85ed04c183a..4bfce3d360f 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -99,6 +99,9 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); +void CheckScalar(std::string func_name, const std::vector& inputs, + const Datum& expected, const FunctionOptions* options = nullptr); + void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index dfdd64d19c6..0a8c685e917 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -266,17 +266,17 @@ an ``Invalid`` :class:`Status` when overflow is detected. 
+--------------------------+------------+--------------------+---------------------+ | abs_checked | Unary | Numeric | Numeric | +--------------------------+------------+--------------------+---------------------+ -| add | Binary | Numeric | Numeric | +| add | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| add_checked | Binary | Numeric | Numeric | +| add_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| divide | Binary | Numeric | Numeric | +| divide | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| divide_checked | Binary | Numeric | Numeric | +| divide_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| multiply | Binary | Numeric | Numeric | +| multiply | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| multiply_checked | Binary | Numeric | Numeric | +| multiply_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ | negate | Unary | Numeric | Numeric | +--------------------------+------------+--------------------+---------------------+ @@ -286,11 +286,29 @@ an ``Invalid`` :class:`Status` when overflow is detected. 
+--------------------------+------------+--------------------+---------------------+ | power_checked | Binary | Numeric | Numeric | +--------------------------+------------+--------------------+---------------------+ -| subtract | Binary | Numeric | Numeric | +| subtract | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| subtract_checked | Binary | Numeric | Numeric | +| subtract_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ +* \(1) Precision and scale of computed DECIMAL results + ++------------+---------------------------------------------+ +| Operation | Result precision and scale | ++============+=============================================+ +| | add | | scale = max(s1, s2) | +| | subtract | | precision = max(p1-s1, p2-s2) + 1 + scale | ++------------+---------------------------------------------+ +| multiply | | scale = s1 + s2 | +| | | precision = p1 + p2 + 1 | ++------------+---------------------------------------------+ +| divide | | scale = max(4, s1 + p2 - s2 + 1) | +| | | precision = p1 - s1 + s2 + scale | ++------------+---------------------------------------------+ + +Decimal overflow is checked before calculation. Error is returned if the result +precision is beyond the decimal range. 
+ Comparisons ~~~~~~~~~~~ From ae3ef781cdc1ab08302ffe7a7a0391ebe8b59dc0 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Wed, 9 Jun 2021 09:45:41 +0000 Subject: [PATCH 2/4] address review comments --- .../arrow/compute/kernels/codegen_internal.h | 24 +- .../compute/kernels/scalar_arithmetic.cc | 169 ++++-- .../compute/kernels/scalar_arithmetic_test.cc | 520 +++++++++--------- cpp/src/arrow/compute/kernels/test_util.cc | 16 - cpp/src/arrow/compute/kernels/test_util.h | 3 - cpp/src/arrow/type.cc | 11 + cpp/src/arrow/type.h | 4 + cpp/src/arrow/type_traits.h | 11 + docs/source/cpp/compute.rst | 7 +- 9 files changed, 425 insertions(+), 340 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 8ff3b104191..6a5cee124c0 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -250,12 +250,11 @@ struct ArrayIterator> { template struct ArrayIterator> { using T = typename TypeTraits::ScalarType::ValueType; - using Bytes = std::array; - const Bytes* values; + using endian_agnostic = std::array; + const endian_agnostic* values; - explicit ArrayIterator(const ArrayData& data) : values(data.GetValues(1)) { - DCHECK_EQ(sizeof(T) * 8, bit_width(data.type->id())); - } + explicit ArrayIterator(const ArrayData& data) + : values(data.GetValues(1)) {} T operator()() { return T{values++->data()}; } }; @@ -282,12 +281,11 @@ struct OutputArrayWriter> { template struct OutputArrayWriter> { using T = typename TypeTraits::ScalarType::ValueType; - using Bytes = std::array; - Bytes* values; + using endian_agnostic = std::array; + endian_agnostic* values; - explicit OutputArrayWriter(ArrayData* data) : values(data->GetMutableValues(1)) { - DCHECK_EQ(sizeof(T) * 8, bit_width(data->type->id())); - } + explicit OutputArrayWriter(ArrayData* data) + : values(data->GetMutableValues(1)) {} void Write(T value) { value.ToBytes(values++->data()); } @@ -573,12 +571,14 @@ 
struct OutputAdapter> { template struct OutputAdapter> { using T = typename TypeTraits::ScalarType::ValueType; + using endian_agnostic = std::array; + template static Status Write(KernelContext*, Datum* out, Generator&& generator) { ArrayData* out_arr = out->mutable_array(); - T* out_data = out_arr->GetMutableValues(1); + auto out_data = out_arr->GetMutableValues(1); for (int64_t i = 0; i < out_arr->length; ++i) { - *out_data++ = generator(); + generator().ToBytes(out_data++->data()); } return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index c437029d7e7..a5199f680ac 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -524,56 +524,110 @@ ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) { } } -// calculate output precision/scale and args rescaling per operation type -Result> GetDecimalBinaryOutput( - const std::string& op, const std::vector& values, - std::vector>* replaced = nullptr) { - const auto& left_type = checked_pointer_cast(values[0].type); - const auto& right_type = checked_pointer_cast(values[1].type); - - const int32_t p1 = left_type->precision(), s1 = left_type->scale(); - const int32_t p2 = right_type->precision(), s2 = right_type->scale(); +Status CastBinaryDecimalArgs(const std::string& func_name, + std::vector* values) { + auto& left_type = (*values)[0].type; + auto& right_type = (*values)[1].type; + DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); + + // decimal + float = float + if (is_floating(left_type->id())) { + right_type = left_type; + return Status::OK(); + } else if (is_floating(right_type->id())) { + left_type = right_type; + return Status::OK(); + } + + // precision, scale of left and right args + int32_t p1, s1, p2, s2; + + // decimal + integer = decimal + if (is_decimal(left_type->id())) { + auto decimal = checked_cast(left_type.get()); + p1 = 
decimal->precision(); + s1 = decimal->scale(); + } else { + DCHECK(is_integer(left_type->id())); + p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id())))); + s1 = 0; + } + if (is_decimal(right_type->id())) { + auto decimal = checked_cast(right_type.get()); + p2 = decimal->precision(); + s2 = decimal->scale(); + } else { + DCHECK(is_integer(right_type->id())); + p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id())))); + s2 = 0; + } if (s1 < 0 || s2 < 0) { return Status::NotImplemented("Decimals with negative scales not supported"); } - int32_t out_prec, out_scale; - int32_t left_scaleup = 0, right_scaleup = 0; + // decimal128 + decimal256 = decimal256 + Type::type casted_type_id = Type::DECIMAL128; + if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { + casted_type_id = Type::DECIMAL256; + } - // decimal upscaling behaviour references amazon redshift + // decimal promotion rules compatible with amazon redshift // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html - if (op.find("add") == 0 || op.find("subtract") == 0) { - out_scale = std::max(s1, s2); - out_prec = std::max(p1 - s1, p2 - s2) + 1 + out_scale; - left_scaleup = out_scale - s1; - right_scaleup = out_scale - s2; - } else if (op.find("multiply") == 0) { - out_scale = s1 + s2; - out_prec = p1 + p2 + 1; - } else if (op.find("divide") == 0) { - out_scale = std::max(4, s1 + p2 - s2 + 1); - out_prec = p1 - s1 + s2 + out_scale; // >= p1 + p2 + 1 - left_scaleup = out_prec - p1; + int32_t left_scaleup, right_scaleup; + + // "add_checked" -> "add" + const std::string op = func_name.substr(0, func_name.find("_")); + if (op == "add" || op == "subtract") { + left_scaleup = std::max(s1, s2) - s1; + right_scaleup = std::max(s1, s2) - s2; + } else if (op == "multiply") { + left_scaleup = right_scaleup = 0; + } else if (op == "divide") { + left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1; + right_scaleup = 0; } else { - return 
Status::Invalid("Invalid decimal operation: ", op); + return Status::Invalid("Invalid decimal function: ", func_name); } - const auto id = left_type->id(); - auto make = [id](int32_t precision, int32_t scale) { - if (id == Type::DECIMAL128) { - return Decimal128Type::Make(precision, scale); - } else { - return Decimal256Type::Make(precision, scale); - } - }; + ARROW_ASSIGN_OR_RAISE( + left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup)); + ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup, + s2 + right_scaleup)); + return Status::OK(); +} + +// resolve output decimal type per *casted* args +Result> ResolveBinaryDecimalOutput( + const std::string& func_name, const std::vector& values) { + // casted args should be same size decimals + auto left_type = checked_cast(values[0].type.get()); + auto right_type = checked_cast(values[1].type.get()); + DCHECK_EQ(left_type->id(), right_type->id()); + const Type::type out_type_id = left_type->id(); + + const int32_t p1 = left_type->precision(), s1 = left_type->scale(); + const int32_t p2 = right_type->precision(), s2 = right_type->scale(); + DCHECK(s1 >= 0 && s2 >= 0); - if (replaced) { - replaced->resize(2); - ARROW_ASSIGN_OR_RAISE((*replaced)[0], make(p1 + left_scaleup, s1 + left_scaleup)); - ARROW_ASSIGN_OR_RAISE((*replaced)[1], make(p2 + right_scaleup, s2 + right_scaleup)); + int32_t out_precision, out_scale; + + const std::string op = func_name.substr(0, func_name.find("_")); + if (op == "add" || op == "subtract") { + DCHECK_EQ(s1, s2); + out_scale = s1; + out_precision = std::max(p1 - s1, p2 - s2) + 1 + out_scale; + } else if (op == "multiply") { + out_scale = s1 + s2; + out_precision = p1 + p2 + 1; + } else if (op == "divide") { + DCHECK_GE(s1, s2); + out_scale = s1 - s2; + out_precision = p1; + } else { + return Status::Invalid("Invalid decimal function: ", func_name); } - return make(out_prec, out_scale); + return DecimalType::Make(out_type_id, 
out_precision, out_scale); } struct ArithmeticFunction : ScalarFunction { @@ -582,10 +636,7 @@ struct ArithmeticFunction : ScalarFunction { Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); - const auto type_id = (*values)[0].type->id(); - if (type_id == Type::DECIMAL128 || type_id == Type::DECIMAL256) { - return DispatchDecimal(values); - } + RETURN_NOT_OK(CheckDecimals(values)); using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -605,36 +656,40 @@ struct ArithmeticFunction : ScalarFunction { return arrow::compute::detail::NoMatchingKernel(this, *values); } - Result DispatchDecimal(std::vector* values) const { - if (values->size() == 2) { - std::vector> replaced; - RETURN_NOT_OK(GetDecimalBinaryOutput(name(), *values, &replaced)); - (*values)[0].type = std::move(replaced[0]); - (*values)[1].type = std::move(replaced[1]); + Status CheckDecimals(std::vector* values) const { + bool has_decimal = false; + for (const auto& value : *values) { + if (is_decimal(value.type->id())) { + has_decimal = true; + break; + } } + if (!has_decimal) return Status::OK(); - using arrow::compute::detail::DispatchExactImpl; - if (auto kernel = DispatchExactImpl(this, *values)) return kernel; - return arrow::compute::detail::NoMatchingKernel(this, *values); + if (values->size() == 2) { + return CastBinaryDecimalArgs(name(), values); + } + return Status::OK(); } }; // resolve decimal operation output type -struct DecimalBinaryOutputResolver { - std::string func_name; +struct BinaryDecimalOutputResolver { + const std::string func_name; - DecimalBinaryOutputResolver(std::string func_name) : func_name(std::move(func_name)) {} + explicit BinaryDecimalOutputResolver(std::string func_name) + : func_name(std::move(func_name)) {} Result operator()(KernelContext*, const std::vector& args) { - ARROW_ASSIGN_OR_RAISE(auto out_type, GetDecimalBinaryOutput(func_name, args)); - return 
ValueDescr(std::move(out_type)); + ARROW_ASSIGN_OR_RAISE(auto type, ResolveBinaryDecimalOutput(func_name, args)); + return ValueDescr(std::move(type), GetBroadcastShape(args)); } }; template void AddDecimalBinaryKernels(const std::string& name, std::shared_ptr* func) { - auto out_type = OutputType(DecimalBinaryOutputResolver(name)); + auto out_type = OutputType(BinaryDecimalOutputResolver(name)); auto in_type128 = InputType(Type::DECIMAL128); auto in_type256 = InputType(Type::DECIMAL256); auto exec128 = ScalarBinaryNotNullEqualTypes::Exec; diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 2475f785891..96f6a6bcd4c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1150,44 +1150,6 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) { class TestBinaryArithmeticDecimal : public TestBase { protected: - struct Arg { - std::shared_ptr type; - std::string value; - }; - - std::shared_ptr GetOutType(const std::string& op, - const std::shared_ptr& left_type, - const std::shared_ptr& right_type) { - auto left_decimal_type = std::static_pointer_cast(left_type); - auto right_decimal_type = std::static_pointer_cast(right_type); - - const int32_t p1 = left_decimal_type->precision(), s1 = left_decimal_type->scale(); - const int32_t p2 = right_decimal_type->precision(), s2 = right_decimal_type->scale(); - - // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html - int32_t precision, scale; - if (op == "add" || op == "subtract") { - scale = std::max(s1, s2); - precision = std::max(p1 - s1, p2 - s2) + 1 + scale; - } else if (op == "multiply") { - scale = s1 + s2; - precision = p1 + p2 + 1; - } else if (op == "divide") { - scale = std::max(4, s1 + p2 - s2 + 1); - precision = p1 - s1 + s2 + scale; - } else { - ABORT_NOT_OK(Status::Invalid("invalid binary operator: ", op)); - } - - 
std::shared_ptr type; - if (left_type->id() == Type::DECIMAL128) { - ASSIGN_OR_ABORT(type, Decimal128Type::Make(precision, scale)); - } else { - ASSIGN_OR_ABORT(type, Decimal256Type::Make(precision, scale)); - } - return type; - } - std::shared_ptr MakeScalar(const std::shared_ptr& type, const std::string& str) { std::shared_ptr scalar; @@ -1204,249 +1166,307 @@ class TestBinaryArithmeticDecimal : public TestBase { } return scalar; } +}; - Datum ToDatum(const std::shared_ptr& type, const std::string& value) { - if (value.find("[") == std::string::npos) { - return Datum(MakeScalar(type, value)); - } else { - return Datum(ArrayFromJSON(type, value)); - } +// reference result from bc (precsion=100, scale=40) +TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(30, 3), + R"([ + "1.000", + "-123456789012345678901234567.890", + "98765432109876543210.987", + "-999999999999999999999999999.999" + ])"); + auto right = ArrayFromJSON(decimal128(20, 9), + R"([ + "-1.000000000", + "12345678901.234567890", + "98765.432101234", + "-99999999999.999999999" + ])"); + auto added = ArrayFromJSON(decimal128(37, 9), + R"([ + "0.000000000", + "-123456789012345666555555666.655432110", + "98765432109876641976.419101234", + "-1000000000000000099999999999.998999999" + ])"); + auto subtracted = ArrayFromJSON(decimal128(37, 9), + R"([ + "2.000000000", + "-123456789012345691246913469.124567890", + "98765432109876444445.554898766", + "-999999999999999899999999999.999000001" + ])"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("subtract", left, right, subtracted); } - void Assert(const std::string& op, const Arg& left, const Arg& right, - const std::string& expected) { - const Datum arg0 = ToDatum(left.type, left.value); - const Datum arg1 = ToDatum(right.type, right.value); + // array array, decimal256 + { + auto left = ArrayFromJSON(decimal256(30, 20), + R"([ + "-1.00000000000000000001", + 
"1234567890.12345678900000000000", + "-9876543210.09876543210987654321", + "9999999999.99999999999999999999" + ])"); + auto right = ArrayFromJSON(decimal256(30, 10), + R"([ + "1.0000000000", + "-1234567890.1234567890", + "6789.5432101234", + "99999999999999999999.9999999999" + ])"); + auto added = ArrayFromJSON(decimal256(41, 20), + R"([ + "-0.00000000000000000001", + "0.00000000000000000000", + "-9876536420.55555530870987654321", + "100000000009999999999.99999999989999999999" + ])"); + auto subtracted = ArrayFromJSON(decimal256(41, 20), + R"([ + "-2.00000000000000000001", + "2469135780.24691357800000000000", + "-9876549999.64197555550987654321", + "-99999999989999999999.99999999990000000001" + ])"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("subtract", left, right, subtracted); + } - auto out_type = GetOutType(op, left.type, right.type); - CheckScalar(op, {arg0, arg1}, ToDatum(out_type, expected), &options_); + // scalar array + { + auto left = this->MakeScalar(decimal128(6, 1), "12345.6"); + auto right = ArrayFromJSON(decimal128(10, 3), + R"(["1.234", "1234.000", "-9876.543", "666.888"])"); + auto added = ArrayFromJSON(decimal128(11, 3), + R"(["12346.834", "13579.600", "2469.057", "13012.488"])"); + auto left_sub_right = ArrayFromJSON( + decimal128(11, 3), R"(["12344.366", "11111.600", "22222.143", "11678.712"])"); + auto right_sub_left = ArrayFromJSON( + decimal128(11, 3), R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("add", right, left, added); + CheckScalarBinary("subtract", left, right, left_sub_right); + CheckScalarBinary("subtract", right, left, right_sub_left); + } - // commutative operations - if (op == "add" || op == "multiply") { - CheckScalar(op, {arg1, arg0}, ToDatum(out_type, expected), &options_); - } + // scalar scalar + { + auto left = this->MakeScalar(decimal256(3, 0), "666"); + auto right = this->MakeScalar(decimal256(3, 0), "888"); + 
auto added = this->MakeScalar(decimal256(4, 0), "1554"); + auto subtracted = this->MakeScalar(decimal256(4, 0), "-222"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("subtract", left, right, subtracted); } - void AssertFail(const std::string& op, const Arg& left, const Arg& right) { - const Datum arg0 = ToDatum(left.type, left.value); - const Datum arg1 = ToDatum(right.type, right.value); + // decimal128 decimal256 + { + auto left = this->MakeScalar(decimal128(3, 0), "666"); + auto right = this->MakeScalar(decimal256(3, 0), "888"); + auto added = this->MakeScalar(decimal256(4, 0), "1554"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("add", right, left, added); + } - ASSERT_NOT_OK(CallFunction(op, {arg0, arg1}, &options_)); - if (op == "add" || op == "multiply") { - ASSERT_NOT_OK(CallFunction(op, {arg1, arg0}, &options_)); - } + // decimal float + { + auto left = this->MakeScalar(decimal128(3, 0), "666"); + ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888)); + ASSIGN_OR_ABORT(auto added, arrow::MakeScalar(float64(), 1554)); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("add", right, left, added); } - ArithmeticOptions options_ = ArithmeticOptions(); -}; + // TODO: decimal integer -// reference result from bc (precsion=100, scale=40) -TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { - Arg left, right; - std::string added, subtracted; + // failed case: result maybe overflow + { + std::shared_ptr left, right; + + left = this->MakeScalar(decimal128(21, 20), "0.12345678901234567890"); + right = this->MakeScalar(decimal128(21, 1), "1.0"); + ASSERT_RAISES(Invalid, CallFunction("add", {left, right})); + ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right})); + + left = this->MakeScalar(decimal256(75, 0), "0"); + right = this->MakeScalar(decimal256(2, 1), "0.0"); + ASSERT_RAISES(Invalid, CallFunction("add", {left, right})); + ASSERT_RAISES(Invalid, CallFunction("subtract", {left, 
right})); + } +} +TEST_F(TestBinaryArithmeticDecimal, Multiply) { // array array, decimal128 - left = { - decimal128(30, 3), - R"([ - "1.000", - "-123456789012345678901234567.890", - "98765432109876543210.987", - "-999999999999999999999999999.999" - ])", - }; - right = { - decimal128(20, 9), - R"([ - "-1.000000000", - "12345678901.234567890", - "98765.432101234", - "-99999999999.999999999" - ])", - }; - added = R"([ - "0.000000000", - "-123456789012345666555555666.655432110", - "98765432109876641976.419101234", - "-1000000000000000099999999999.998999999" - ])"; - subtracted = R"([ - "2.000000000", - "-123456789012345691246913469.124567890", - "98765432109876444445.554898766", - "-999999999999999899999999999.999000001" - ])"; - this->Assert("add", left, right, added); - this->Assert("subtract", left, right, subtracted); + { + auto left = ArrayFromJSON(decimal128(20, 10), + R"([ + "1234567890.1234567890", + "-0.0000000001", + "-9999999999.9999999999" + ])"); + auto right = ArrayFromJSON(decimal128(13, 3), + R"([ + "1234567890.123", + "0.001", + "-9999999999.999" + ])"); + auto expected = ArrayFromJSON(decimal128(34, 13), + R"([ + "1524157875323319737.9870903950470", + "-0.0000000000001", + "99999999999989999999.0000000000001" + ])"); + CheckScalarBinary("multiply", left, right, expected); + } - // array array, decimal256 - left = { - decimal256(30, 20), - R"([ - "-1.00000000000000000001", - "1234567890.12345678900000000000", - "-9876543210.09876543210987654321", - "9999999999.99999999999999999999" - ])", - }; - right = { - decimal256(30, 10), - R"([ - "1.0000000000", - "-1234567890.1234567890", - "6789.5432101234", - "99999999999999999999.9999999999" - ])", - }; - added = R"([ - "-0.00000000000000000001", - "0.00000000000000000000", - "-9876536420.55555530870987654321", - "100000000009999999999.99999999989999999999" - ])"; - subtracted = R"([ - "-2.00000000000000000001", - "2469135780.24691357800000000000", - "-9876549999.64197555550987654321", - 
"-99999999989999999999.99999999990000000001" - ])"; - this->Assert("add", left, right, added); - this->Assert("subtract", left, right, subtracted); + // array array, decimal26 + { + auto left = ArrayFromJSON(decimal256(30, 3), + R"([ + "123456789012345678901234567.890", + "0.000" + ])"); + auto right = ArrayFromJSON(decimal256(20, 9), + R"([ + "-12345678901.234567890", + "99999999999.999999999" + ])"); + auto expected = ArrayFromJSON(decimal256(51, 12), + R"([ + "-1524157875323883675034293577501905199.875019052100", + "0.000000000000" + ])"); + CheckScalarBinary("multiply", left, right, expected); + } // scalar array - left = {decimal128(6, 1), "12345.6"}; - right = {decimal128(10, 3), R"(["1.234", "1234.000", "-9876.543", "666.888"])"}; - added = R"(["12346.834", "13579.600", "2469.057", "13012.488"])"; - subtracted = R"(["12344.366", "11111.600", "22222.143", "11678.712"])"; - this->Assert("add", left, right, added); - this->Assert("subtract", left, right, subtracted); - // right - left - subtracted = R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])"; - this->Assert("subtract", right, left, subtracted); + { + auto left = this->MakeScalar(decimal128(3, 2), "3.14"); + auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"); + auto expected = + ArrayFromJSON(decimal128(5, 2), R"(["3.14", "6.28", "9.42", "12.56", "15.70"])"); + CheckScalarBinary("multiply", left, right, expected); + CheckScalarBinary("multiply", right, left, expected); + } // scalar scalar - left = {decimal256(3, 0), "666"}; - right = {decimal256(3, 0), "888"}; - this->Assert("add", left, right, "1554"); - this->Assert("subtract", left, right, "-222"); - - // failed case: result *maybe* overflow - left = {decimal128(21, 20), "0.12345678901234567890"}; - right = {decimal128(21, 1), "1.0"}; - this->AssertFail("add", left, right); - this->AssertFail("subtract", left, right); - - left = {decimal256(75, 0), "0"}; - right = {decimal256(2, 1), "0.0"}; - 
this->AssertFail("add", left, right); - this->AssertFail("subtract", left, right); -} + { + auto left = this->MakeScalar(decimal128(1, 0), "1"); + auto right = this->MakeScalar(decimal128(1, 0), "1"); + auto expected = this->MakeScalar(decimal128(3, 0), "1"); + CheckScalarBinary("multiply", left, right, expected); + } -TEST_F(TestBinaryArithmeticDecimal, Multiply) { - Arg left, right; - std::string expected; - - // array array - left = { - decimal128(20, 10), - R"([ - "1234567890.1234567890", - "-0.0000000001", - "-9999999999.9999999999" - ])", - }; - right = { - decimal128(13, 3), - R"([ - "1234567890.123", - "0.001", - "-9999999999.999" - ])", - }; - expected = R"([ - "1524157875323319737.9870903950470", - "-0.0000000000001", - "99999999999989999999.0000000000001" - ])"; - this->Assert("multiply", left, right, expected); - - left = { - decimal256(30, 3), - R"([ - "123456789012345678901234567.890", - "0.000" - ])", - }; - right = { - decimal256(20, 9), - R"([ - "-12345678901.234567890", - "99999999999.999999999" - ])", - }; - expected = R"([ - "-1524157875323883675034293577501905199.875019052100", - "0.000000000000" - ])"; - this->Assert("multiply", left, right, expected); + // decimal128 decimal256 + { + auto left = this->MakeScalar(decimal128(3, 2), "6.66"); + auto right = this->MakeScalar(decimal256(3, 1), "88.8"); + auto expected = this->MakeScalar(decimal256(7, 3), "591.408"); + CheckScalarBinary("multiply", left, right, expected); + CheckScalarBinary("multiply", right, left, expected); + } - // scalar array - left = {decimal128(3, 2), "3.14"}; - right = {decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"}; - expected = R"(["3.14", "6.28", "9.42", "12.56", "15.70"])"; - this->Assert("multiply", left, right, expected); + // decimal float + { + auto left = this->MakeScalar(decimal128(3, 0), "666"); + ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888)); + ASSIGN_OR_ABORT(auto expected, arrow::MakeScalar(float64(), 591408)); + 
CheckScalarBinary("multiply", left, right, expected); + CheckScalarBinary("multiply", right, left, expected); + } - // scalar scalar - left = {decimal128(1, 0), "1"}; - right = {decimal128(1, 0), "1"}; - this->Assert("multiply", left, right, "1"); - - // failed case: result *maybe* overflow - left = {decimal128(20, 0), "1"}; - right = {decimal128(18, 1), "1.0"}; - this->AssertFail("multiply", left, right); + // TODO: decimal integer + + // failed case: result maybe overflow + { + auto left = this->MakeScalar(decimal128(20, 0), "1"); + auto right = this->MakeScalar(decimal128(18, 1), "1.0"); + ASSERT_RAISES(Invalid, CallFunction("multiply", {left, right})); + } } TEST_F(TestBinaryArithmeticDecimal, Divide) { - Arg left, right; - std::string expected; - - // array array - left = {decimal128(13, 3), R"(["1234567890.123", "0.001"])"}; - right = {decimal128(3, 0), R"(["-987", "999"])"}; - // scale = 7 - expected = R"(["-1250828.6627386", "0.0000010"])"; - this->Assert("divide", left, right, expected); - - left = {decimal256(20, 10), R"(["1234567890.1234567890", "9999999999.9999999999"])"}; - right = {decimal256(13, 3), R"(["1234567890.123", "0.001"])"}; - // scale = 21 - expected = R"(["1.000000000000369999093", "9999999999999.999999900000000000000"])"; - this->Assert("divide", left, right, expected); + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(13, 3), R"(["1234567890.123", "0.001"])"); + auto right = ArrayFromJSON(decimal128(3, 0), R"(["-987", "999"])"); + auto expected = + ArrayFromJSON(decimal128(17, 7), R"(["-1250828.6627386", "0.0000010"])"); + CheckScalarBinary("divide", left, right, expected); + } + + // array array, decimal256 + { + auto left = ArrayFromJSON(decimal256(20, 10), + R"(["1234567890.1234567890", "9999999999.9999999999"])"); + auto right = ArrayFromJSON(decimal256(13, 3), R"(["1234567890.123", "0.001"])"); + auto expected = ArrayFromJSON( + decimal256(34, 21), + R"(["1.000000000000369999093", 
"9999999999999.999999900000000000000"])"); + CheckScalarBinary("divide", left, right, expected); + } // scalar array - left = {decimal128(1, 0), "1"}; - right = {decimal128(1, 0), R"(["1", "2", "3", "4"])"}; - // scale = 4 - expected = R"(["1.0000", "0.5000", "0.3333", "0.2500"])"; - this->Assert("divide", left, right, expected); - // right / left - expected = R"(["1.0000", "2.0000", "3.0000", "4.0000"])"; - this->Assert("divide", right, left, expected); + { + auto left = this->MakeScalar(decimal128(1, 0), "1"); + auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4"])"); + auto left_div_right = + ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "0.5000", "0.3333", "0.2500"])"); + auto right_div_left = + ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "2.0000", "3.0000", "4.0000"])"); + CheckScalarBinary("divide", left, right, left_div_right); + CheckScalarBinary("divide", right, left, right_div_left); + } // scalar scalar - left = {decimal256(6, 5), "2.71828"}; - right = {decimal256(6, 5), "3.14159"}; - // scale = 7 - this->Assert("divide", left, right, "0.8652561"); + { + auto left = this->MakeScalar(decimal256(6, 5), "2.71828"); + auto right = this->MakeScalar(decimal256(6, 5), "3.14159"); + auto expected = this->MakeScalar(decimal256(13, 7), "0.8652561"); + CheckScalarBinary("divide", left, right, expected); + } - // failed case: result *maybe* overflow - left = {decimal128(20, 20), "0.12345678901234567890"}; - right = {decimal128(20, 0), "12345678901234567890"}; - this->AssertFail("divide", left, right); + // decimal128 decimal256 + { + auto left = this->MakeScalar(decimal256(6, 5), "2.71828"); + auto right = this->MakeScalar(decimal128(6, 5), "3.14159"); + auto left_div_right = this->MakeScalar(decimal256(13, 7), "0.8652561"); + auto right_div_left = this->MakeScalar(decimal256(13, 7), "1.1557271"); + CheckScalarBinary("divide", left, right, left_div_right); + CheckScalarBinary("divide", right, left, right_div_left); + } + + // decimal float + { + 
auto left = this->MakeScalar(decimal128(3, 0), "100"); + ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 50)); + ASSIGN_OR_ABORT(auto left_div_right, arrow::MakeScalar(float64(), 2)); + ASSIGN_OR_ABORT(auto right_div_left, arrow::MakeScalar(float64(), 0.5)); + CheckScalarBinary("divide", left, right, left_div_right); + CheckScalarBinary("divide", right, left, right_div_left); + } + + // TODO: decimal integer + + // failed case: result maybe overflow + { + auto left = this->MakeScalar(decimal128(20, 20), "0.12345678901234567890"); + auto right = this->MakeScalar(decimal128(20, 0), "12345678901234567890"); + ASSERT_RAISES(Invalid, CallFunction("divide", {left, right})); + } // failed case: divide by 0 - this->AssertFail("divide", {decimal256(1, 0), "0"}, {decimal256(1, 0), "0"}); + { + auto left = this->MakeScalar(decimal256(1, 0), "1"); + auto right = this->MakeScalar(decimal256(1, 0), "0"); + ASSERT_RAISES(Invalid, CallFunction("divide", {left, right})); + } } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index cf37b33c006..18257973150 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -169,22 +169,6 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, } } -void CheckScalar(std::string func_name, const std::vector& inputs, - const Datum& expected, const FunctionOptions* options) { - ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); - ASSERT_EQ(out.kind(), expected.kind()); - if (out.kind() == Datum::ARRAY) { - std::shared_ptr actual = out.make_array(); - ASSERT_OK(actual->ValidateFull()); - AssertArraysEqual(*expected.make_array(), *actual, /*verbose=*/true); - } else if (out.kind() == Datum::SCALAR) { - std::shared_ptr actual = out.scalar(); - AssertScalarsEqual(*expected.scalar(), *actual, /*verbose=*/true); - } else { - ASSERT_EQ(out, expected); - } -} - void 
CheckScalarUnary(std::string func_name, std::shared_ptr input, std::shared_ptr expected, const FunctionOptions* options) { ArrayVector input_vector = {input}; diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 4bfce3d360f..85ed04c183a 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -99,9 +99,6 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); -void CheckScalar(std::string func_name, const std::vector& inputs, - const Datum& expected, const FunctionOptions* options = nullptr); - void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 344585446fc..65c783ce847 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -771,6 +771,17 @@ std::vector> StructType::GetAllFieldsByName( return result; } +Result> DecimalType::Make(Type::type type_id, int32_t precision, + int32_t scale) { + if (type_id == Type::DECIMAL128) { + return Decimal128Type::Make(precision, scale); + } else if (type_id == Type::DECIMAL256) { + return Decimal256Type::Make(precision, scale); + } else { + return Status::Invalid("Not a decimal type_id: ", type_id); + } +} + // Taken from the Apache Impala codebase. The comments next // to the return values are the maximum value that can be represented in 2's // complement with the returned number of bytes. 
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 1d3d1e27f92..b933da66089 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -880,6 +880,10 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { int32_t scale) : FixedSizeBinaryType(byte_width, type_id), precision_(precision), scale_(scale) {} + /// Constructs concrete decimal types + static Result> Make(Type::type type_id, int32_t precision, + int32_t scale); + int32_t precision() const { return precision_; } int32_t scale() const { return scale_; } diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index b74aa3b0adb..86664bbb162 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -845,6 +845,17 @@ static inline bool is_floating(Type::type type_id) { return false; } +static inline bool is_decimal(Type::type type_id) { + switch (type_id) { + case Type::DECIMAL128: + case Type::DECIMAL256: + return true; + default: + break; + } + return false; +} + static inline bool is_primitive(Type::type type_id) { switch (type_id) { case Type::BOOL: diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 0a8c685e917..147885560f5 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -306,8 +306,11 @@ an ``Invalid`` :class:`Status` when overflow is detected. | | | precision = p1 - s1 + s2 + scale | +------------+---------------------------------------------+ -Decimal overflow is checked before calculation. Error is returned if the result -precision is beyond the decimal range. +It's compatible with Redshift's decimal promotion rules. All decimal digits +are preserved for `add`, `subtract` and `multiply` operations. The result +precision of `divide` is at least the sum of precisions of both operands with +enough scale kept. Error is returned if the result precision is beyond the +decimal value range. 
Comparisons ~~~~~~~~~~~ From 9fd0cca552c9b2aaa0c73209792deb72dbbf269b Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Fri, 18 Jun 2021 06:24:09 +0000 Subject: [PATCH 3/4] address review comments #2 --- cpp/src/arrow/compute/kernel.h | 3 + .../compute/kernels/scalar_arithmetic.cc | 93 ++++++----- .../compute/kernels/scalar_arithmetic_test.cc | 154 +++++++++++------- 3 files changed, 154 insertions(+), 96 deletions(-) diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 0d5fa147727..f8d15952e73 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -321,6 +321,9 @@ class ARROW_EXPORT OutputType { this->resolver_ = other.resolver_; } + OutputType& operator=(const OutputType&) = default; + OutputType& operator=(OutputType&&) = default; + /// \brief Return the shape and type of the expected output value of the /// kernel given the value descriptors (shapes and types) of the input /// arguments. The resolver may make use of state information kept in the diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index a5199f680ac..dac1ae1aec0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "arrow/compute/kernels/common.h" #include "arrow/type_traits.h" @@ -596,38 +597,52 @@ Status CastBinaryDecimalArgs(const std::string& func_name, return Status::OK(); } -// resolve output decimal type per *casted* args -Result> ResolveBinaryDecimalOutput( - const std::string& func_name, const std::vector& values) { +// resolve decimal binary operation output type per *casted* args +template +Result ResolveDecimalBinaryOperationOutput( + const std::vector& args, OutputGetter&& getter) { // casted args should be same size decimals - auto left_type = checked_cast(values[0].type.get()); - auto right_type = checked_cast(values[1].type.get()); + 
auto left_type = checked_cast(args[0].type.get()); + auto right_type = checked_cast(args[1].type.get()); DCHECK_EQ(left_type->id(), right_type->id()); - const Type::type out_type_id = left_type->id(); - const int32_t p1 = left_type->precision(), s1 = left_type->scale(); - const int32_t p2 = right_type->precision(), s2 = right_type->scale(); - DCHECK(s1 >= 0 && s2 >= 0); + int32_t precision, scale; + std::tie(precision, scale) = getter(left_type->precision(), left_type->scale(), + right_type->precision(), right_type->scale()); + ARROW_ASSIGN_OR_RAISE(auto type, DecimalType::Make(left_type->id(), precision, scale)); + return ValueDescr(std::move(type), GetBroadcastShape(args)); +} - int32_t out_precision, out_scale; +Result ResolveDecimalAdditionOrSubtractionOutput( + KernelContext*, const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + DCHECK_EQ(s1, s2); + const int32_t scale = s1; + const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1; + return std::make_pair(precision, scale); + }); +} - const std::string op = func_name.substr(0, func_name.find("_")); - if (op == "add" || op == "subtract") { - DCHECK_EQ(s1, s2); - out_scale = s1; - out_precision = std::max(p1 - s1, p2 - s2) + 1 + out_scale; - } else if (op == "multiply") { - out_scale = s1 + s2; - out_precision = p1 + p2 + 1; - } else if (op == "divide") { - DCHECK_GE(s1, s2); - out_scale = s1 - s2; - out_precision = p1; - } else { - return Status::Invalid("Invalid decimal function: ", func_name); - } +Result ResolveDecimalMultiplicationOutput( + KernelContext*, const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + const int32_t scale = s1 + s2; + const int32_t precision = p1 + p2 + 1; + return std::make_pair(precision, scale); + }); +} - return DecimalType::Make(out_type_id, out_precision, out_scale); +Result 
ResolveDecimalDivisionOutput(KernelContext*, + const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + DCHECK_GE(s1, s2); + const int32_t scale = s1 - s2; + const int32_t precision = p1; + return std::make_pair(precision, scale); + }); } struct ArithmeticFunction : ScalarFunction { @@ -673,23 +688,21 @@ struct ArithmeticFunction : ScalarFunction { } }; -// resolve decimal operation output type -struct BinaryDecimalOutputResolver { - const std::string func_name; - - explicit BinaryDecimalOutputResolver(std::string func_name) - : func_name(std::move(func_name)) {} - - Result operator()(KernelContext*, const std::vector& args) { - ARROW_ASSIGN_OR_RAISE(auto type, ResolveBinaryDecimalOutput(func_name, args)); - return ValueDescr(std::move(type), GetBroadcastShape(args)); - } -}; - template void AddDecimalBinaryKernels(const std::string& name, std::shared_ptr* func) { - auto out_type = OutputType(BinaryDecimalOutputResolver(name)); + OutputType out_type(null()); + const std::string op = name.substr(0, name.find("_")); + if (op == "add" || op == "subtract") { + out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput); + } else if (op == "multiply") { + out_type = OutputType(ResolveDecimalMultiplicationOutput); + } else if (op == "divide") { + out_type = OutputType(ResolveDecimalDivisionOutput); + } else { + DCHECK(false); + } + auto in_type128 = InputType(Type::DECIMAL128); auto in_type256 = InputType(Type::DECIMAL256); auto exec128 = ScalarBinaryNotNullEqualTypes::Exec; diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 96f6a6bcd4c..3ee862c834e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1148,28 +1148,70 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) { } } -class TestBinaryArithmeticDecimal : public 
TestBase { - protected: - std::shared_ptr MakeScalar(const std::shared_ptr& type, - const std::string& str) { - std::shared_ptr scalar; - if (type->id() == Type::DECIMAL128) { - Decimal128 value; - int32_t dummy; - ABORT_NOT_OK(Decimal128::FromString(str, &value, &dummy)); - ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value)); - } else { - Decimal256 value; - int32_t dummy; - ABORT_NOT_OK(Decimal256::FromString(str, &value, &dummy)); - ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value)); +TEST(TestBinaryDecimalArithmetic, DispatchBest) { + // decimal, floating point + for (std::string name : {"add", "subtract", "multiply", "divide"}) { + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(1, 0), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {decimal256(1, 0), float64()}, {float64(), float64()}); + CheckDispatchBest(name, {float32(), decimal256(1, 0)}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), decimal128(1, 0)}, {float64(), float64()}); } - return scalar; } -}; + + // decimal, decimal + for (std::string name : {"add", "subtract"}) { + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(3, 1), decimal128(3, 1)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(3, 1), decimal256(3, 1)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, + {decimal256(3, 1), decimal256(3, 1)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(3, 1), decimal256(3, 1)}); + } + } + { + std::string name = "multiply"; + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(5, 2), decimal128(5, 2)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(5, 2), decimal256(5, 2)}); + CheckDispatchBest(name, {decimal128(2, 
1), decimal256(2, 1)}, + {decimal256(5, 2), decimal256(5, 2)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(5, 2), decimal256(5, 2)}); + } + } + { + std::string name = "divide"; + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(6, 4), decimal128(6, 4)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(6, 4), decimal256(6, 4)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, + {decimal256(6, 4), decimal256(6, 4)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(6, 4), decimal256(6, 4)}); + } + } + + // TODO(ARROW-13067): add 'integer, decimal' tests +} // reference result from bc (precsion=100, scale=40) -TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { +TEST(TestBinaryArithmeticDecimal, AddSubtract) { // array array, decimal128 { auto left = ArrayFromJSON(decimal128(30, 3), @@ -1240,7 +1282,7 @@ TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { // scalar array { - auto left = this->MakeScalar(decimal128(6, 1), "12345.6"); + auto left = ScalarFromJSON(decimal128(6, 1), R"("12345.6")"); auto right = ArrayFromJSON(decimal128(10, 3), R"(["1.234", "1234.000", "-9876.543", "666.888"])"); auto added = ArrayFromJSON(decimal128(11, 3), @@ -1257,26 +1299,26 @@ TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { // scalar scalar { - auto left = this->MakeScalar(decimal256(3, 0), "666"); - auto right = this->MakeScalar(decimal256(3, 0), "888"); - auto added = this->MakeScalar(decimal256(4, 0), "1554"); - auto subtracted = this->MakeScalar(decimal256(4, 0), "-222"); + auto left = ScalarFromJSON(decimal256(3, 0), R"("666")"); + auto right = ScalarFromJSON(decimal256(3, 0), R"("888")"); + auto added = ScalarFromJSON(decimal256(4, 0), R"("1554")"); + auto subtracted = ScalarFromJSON(decimal256(4, 0), R"("-222")"); CheckScalarBinary("add", left, right, added); 
CheckScalarBinary("subtract", left, right, subtracted); } // decimal128 decimal256 { - auto left = this->MakeScalar(decimal128(3, 0), "666"); - auto right = this->MakeScalar(decimal256(3, 0), "888"); - auto added = this->MakeScalar(decimal256(4, 0), "1554"); + auto left = ScalarFromJSON(decimal128(3, 0), R"("666")"); + auto right = ScalarFromJSON(decimal256(3, 0), R"("888")"); + auto added = ScalarFromJSON(decimal256(4, 0), R"("1554")"); CheckScalarBinary("add", left, right, added); CheckScalarBinary("add", right, left, added); } // decimal float { - auto left = this->MakeScalar(decimal128(3, 0), "666"); + auto left = ScalarFromJSON(decimal128(3, 0), R"("666")"); ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888)); ASSIGN_OR_ABORT(auto added, arrow::MakeScalar(float64(), 1554)); CheckScalarBinary("add", left, right, added); @@ -1289,19 +1331,19 @@ TEST_F(TestBinaryArithmeticDecimal, AddSubtract) { { std::shared_ptr left, right; - left = this->MakeScalar(decimal128(21, 20), "0.12345678901234567890"); - right = this->MakeScalar(decimal128(21, 1), "1.0"); + left = ScalarFromJSON(decimal128(21, 20), R"("0.12345678901234567890")"); + right = ScalarFromJSON(decimal128(21, 1), R"("1.0")"); ASSERT_RAISES(Invalid, CallFunction("add", {left, right})); ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right})); - left = this->MakeScalar(decimal256(75, 0), "0"); - right = this->MakeScalar(decimal256(2, 1), "0.0"); + left = ScalarFromJSON(decimal256(75, 0), R"("0")"); + right = ScalarFromJSON(decimal256(2, 1), R"("0.0")"); ASSERT_RAISES(Invalid, CallFunction("add", {left, right})); ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right})); } } -TEST_F(TestBinaryArithmeticDecimal, Multiply) { +TEST(TestBinaryArithmeticDecimal, Multiply) { // array array, decimal128 { auto left = ArrayFromJSON(decimal128(20, 10), @@ -1347,7 +1389,7 @@ TEST_F(TestBinaryArithmeticDecimal, Multiply) { // scalar array { - auto left = this->MakeScalar(decimal128(3, 2), "3.14"); + 
auto left = ScalarFromJSON(decimal128(3, 2), R"("3.14")"); auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"); auto expected = ArrayFromJSON(decimal128(5, 2), R"(["3.14", "6.28", "9.42", "12.56", "15.70"])"); @@ -1357,24 +1399,24 @@ TEST_F(TestBinaryArithmeticDecimal, Multiply) { // scalar scalar { - auto left = this->MakeScalar(decimal128(1, 0), "1"); - auto right = this->MakeScalar(decimal128(1, 0), "1"); - auto expected = this->MakeScalar(decimal128(3, 0), "1"); + auto left = ScalarFromJSON(decimal128(1, 0), R"("1")"); + auto right = ScalarFromJSON(decimal128(1, 0), R"("1")"); + auto expected = ScalarFromJSON(decimal128(3, 0), R"("1")"); CheckScalarBinary("multiply", left, right, expected); } // decimal128 decimal256 { - auto left = this->MakeScalar(decimal128(3, 2), "6.66"); - auto right = this->MakeScalar(decimal256(3, 1), "88.8"); - auto expected = this->MakeScalar(decimal256(7, 3), "591.408"); + auto left = ScalarFromJSON(decimal128(3, 2), R"("6.66")"); + auto right = ScalarFromJSON(decimal256(3, 1), R"("88.8")"); + auto expected = ScalarFromJSON(decimal256(7, 3), R"("591.408")"); CheckScalarBinary("multiply", left, right, expected); CheckScalarBinary("multiply", right, left, expected); } // decimal float { - auto left = this->MakeScalar(decimal128(3, 0), "666"); + auto left = ScalarFromJSON(decimal128(3, 0), R"("666")"); ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888)); ASSIGN_OR_ABORT(auto expected, arrow::MakeScalar(float64(), 591408)); CheckScalarBinary("multiply", left, right, expected); @@ -1385,13 +1427,13 @@ TEST_F(TestBinaryArithmeticDecimal, Multiply) { // failed case: result maybe overflow { - auto left = this->MakeScalar(decimal128(20, 0), "1"); - auto right = this->MakeScalar(decimal128(18, 1), "1.0"); + auto left = ScalarFromJSON(decimal128(20, 0), R"("1")"); + auto right = ScalarFromJSON(decimal128(18, 1), R"("1.0")"); ASSERT_RAISES(Invalid, CallFunction("multiply", {left, right})); } } 
-TEST_F(TestBinaryArithmeticDecimal, Divide) { +TEST(TestBinaryArithmeticDecimal, Divide) { // array array, decimal128 { auto left = ArrayFromJSON(decimal128(13, 3), R"(["1234567890.123", "0.001"])"); @@ -1414,7 +1456,7 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) { // scalar array { - auto left = this->MakeScalar(decimal128(1, 0), "1"); + auto left = ScalarFromJSON(decimal128(1, 0), R"("1")"); auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4"])"); auto left_div_right = ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "0.5000", "0.3333", "0.2500"])"); @@ -1426,25 +1468,25 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) { // scalar scalar { - auto left = this->MakeScalar(decimal256(6, 5), "2.71828"); - auto right = this->MakeScalar(decimal256(6, 5), "3.14159"); - auto expected = this->MakeScalar(decimal256(13, 7), "0.8652561"); + auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")"); + auto right = ScalarFromJSON(decimal256(6, 5), R"("3.14159")"); + auto expected = ScalarFromJSON(decimal256(13, 7), R"("0.8652561")"); CheckScalarBinary("divide", left, right, expected); } // decimal128 decimal256 { - auto left = this->MakeScalar(decimal256(6, 5), "2.71828"); - auto right = this->MakeScalar(decimal128(6, 5), "3.14159"); - auto left_div_right = this->MakeScalar(decimal256(13, 7), "0.8652561"); - auto right_div_left = this->MakeScalar(decimal256(13, 7), "1.1557271"); + auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")"); + auto right = ScalarFromJSON(decimal128(6, 5), R"("3.14159")"); + auto left_div_right = ScalarFromJSON(decimal256(13, 7), R"("0.8652561")"); + auto right_div_left = ScalarFromJSON(decimal256(13, 7), R"("1.1557271")"); CheckScalarBinary("divide", left, right, left_div_right); CheckScalarBinary("divide", right, left, right_div_left); } // decimal float { - auto left = this->MakeScalar(decimal128(3, 0), "100"); + auto left = ScalarFromJSON(decimal128(3, 0), R"("100")"); ASSIGN_OR_ABORT(auto right, 
arrow::MakeScalar(float64(), 50)); ASSIGN_OR_ABORT(auto left_div_right, arrow::MakeScalar(float64(), 2)); ASSIGN_OR_ABORT(auto right_div_left, arrow::MakeScalar(float64(), 0.5)); @@ -1456,15 +1498,15 @@ TEST_F(TestBinaryArithmeticDecimal, Divide) { // failed case: result maybe overflow { - auto left = this->MakeScalar(decimal128(20, 20), "0.12345678901234567890"); - auto right = this->MakeScalar(decimal128(20, 0), "12345678901234567890"); + auto left = ScalarFromJSON(decimal128(20, 20), R"("0.12345678901234567890")"); + auto right = ScalarFromJSON(decimal128(20, 0), R"("12345678901234567890")"); ASSERT_RAISES(Invalid, CallFunction("divide", {left, right})); } // failed case: divide by 0 { - auto left = this->MakeScalar(decimal256(1, 0), "1"); - auto right = this->MakeScalar(decimal256(1, 0), "0"); + auto left = ScalarFromJSON(decimal256(1, 0), R"("1")"); + auto right = ScalarFromJSON(decimal256(1, 0), R"("0")"); ASSERT_RAISES(Invalid, CallFunction("divide", {left, right})); } } From dd38397cbedd0b6b543172f519f93768761334a8 Mon Sep 17 00:00:00 2001 From: Yibo Cai Date: Fri, 18 Jun 2021 10:11:38 +0000 Subject: [PATCH 4/4] add only supported decimal kernels --- .../compute/kernels/scalar_arithmetic.cc | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index dac1ae1aec0..f51484e53ff 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -90,12 +90,6 @@ struct AbsoluteValue { static constexpr enable_if_signed_integer Call(KernelContext*, T arg, Status* st) { return (arg < 0) ? 
arrow::internal::SafeSignedNegate(arg) : arg; } - - template - static enable_if_decimal Call(KernelContext*, T arg, Status* st) { - *st = Status::NotImplemented("NYI"); - return T(); - } }; struct AbsoluteValueChecked { @@ -120,12 +114,6 @@ struct AbsoluteValueChecked { static_assert(std::is_same::value, ""); return std::fabs(arg); } - - template - static enable_if_decimal Call(KernelContext*, Arg arg, Status* st) { - *st = Status::NotImplemented("NYI"); - return T(); - } }; struct Add { @@ -355,13 +343,9 @@ struct DivideChecked { } template - static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { - if (right == Arg1()) { - *st = Status::Invalid("Divide by zero"); - return T(); - } else { - return left / right; - } + static enable_if_decimal Call(KernelContext* ctx, Arg0 left, Arg1 right, + Status* st) { + return Divide::Call(ctx, left, right, st); } }; @@ -380,12 +364,6 @@ struct Negate { static constexpr enable_if_signed_integer Call(KernelContext*, Arg arg, Status*) { return arrow::internal::SafeSignedNegate(arg); } - - template - static enable_if_decimal Call(KernelContext*, Arg arg, Status* st) { - *st = Status::NotImplemented("NYI"); - return T(); - } }; struct NegateChecked { @@ -412,12 +390,6 @@ struct NegateChecked { static_assert(std::is_same::value, ""); return -arg; } - - template - static enable_if_decimal Call(KernelContext*, Arg arg, Status* st) { - *st = Status::NotImplemented("NYI"); - return T(); - } }; struct Power { @@ -446,12 +418,6 @@ struct Power { static enable_if_floating_point Call(KernelContext*, T base, T exp, Status*) { return std::pow(base, exp); } - - template - static enable_if_decimal Call(KernelContext*, Arg0 base, Arg1 exp, Status* st) { - *st = Status::NotImplemented("NYI"); - return T(); - } }; struct PowerChecked { @@ -486,12 +452,6 @@ struct PowerChecked { static_assert(std::is_same::value && std::is_same::value, ""); return std::pow(base, exp); } - - template - static enable_if_decimal 
Call(KernelContext*, Arg0 base, Arg1 exp, Status* st) { - *st = Status::NotImplemented("NYI"); - return T(); - } }; // Generate a kernel given an arithmetic functor @@ -645,6 +605,29 @@ Result ResolveDecimalDivisionOutput(KernelContext*, }); } +template +void AddDecimalBinaryKernels(const std::string& name, + std::shared_ptr* func) { + OutputType out_type(null()); + const std::string op = name.substr(0, name.find("_")); + if (op == "add" || op == "subtract") { + out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput); + } else if (op == "multiply") { + out_type = OutputType(ResolveDecimalMultiplicationOutput); + } else if (op == "divide") { + out_type = OutputType(ResolveDecimalDivisionOutput); + } else { + DCHECK(false); + } + + auto in_type128 = InputType(Type::DECIMAL128); + auto in_type256 = InputType(Type::DECIMAL256); + auto exec128 = ScalarBinaryNotNullEqualTypes::Exec; + auto exec256 = ScalarBinaryNotNullEqualTypes::Exec; + DCHECK_OK((*func)->AddKernel({in_type128, in_type128}, out_type, exec128)); + DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256)); +} + struct ArithmeticFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -688,29 +671,6 @@ struct ArithmeticFunction : ScalarFunction { } }; -template -void AddDecimalBinaryKernels(const std::string& name, - std::shared_ptr* func) { - OutputType out_type(null()); - const std::string op = name.substr(0, name.find("_")); - if (op == "add" || op == "subtract") { - out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput); - } else if (op == "multiply") { - out_type = OutputType(ResolveDecimalMultiplicationOutput); - } else if (op == "divide") { - out_type = OutputType(ResolveDecimalDivisionOutput); - } else { - DCHECK(false); - } - - auto in_type128 = InputType(Type::DECIMAL128); - auto in_type256 = InputType(Type::DECIMAL256); - auto exec128 = ScalarBinaryNotNullEqualTypes::Exec; - auto exec256 = ScalarBinaryNotNullEqualTypes::Exec; - 
DCHECK_OK((*func)->AddKernel({in_type128, in_type128}, out_type, exec128)); - DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256)); -} - template std::shared_ptr MakeArithmeticFunction(std::string name, const FunctionDoc* doc) { @@ -719,8 +679,6 @@ std::shared_ptr MakeArithmeticFunction(std::string name, auto exec = ArithmeticExecFromOp(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); } - AddDecimalBinaryKernels(name, &func); - return func; } @@ -734,8 +692,6 @@ std::shared_ptr MakeArithmeticFunctionNotNull(std::string name, auto exec = ArithmeticExecFromOp(ty); DCHECK_OK(func->AddKernel({ty, ty}, ty, exec)); } - AddDecimalBinaryKernels(name, &func); - return func; } @@ -879,16 +835,19 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto add = MakeArithmeticFunction("add", &add_doc); + AddDecimalBinaryKernels("add", &add); DCHECK_OK(registry->AddFunction(std::move(add))); // ---------------------------------------------------------------------- auto add_checked = MakeArithmeticFunctionNotNull("add_checked", &add_checked_doc); + AddDecimalBinaryKernels("add_checked", &add_checked); DCHECK_OK(registry->AddFunction(std::move(add_checked))); // ---------------------------------------------------------------------- // subtract auto subtract = MakeArithmeticFunction("subtract", &sub_doc); + AddDecimalBinaryKernels("subtract", &subtract); // Add subtract(timestamp, timestamp) -> duration for (auto unit : AllTimeUnits()) { @@ -902,24 +861,29 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto subtract_checked = MakeArithmeticFunctionNotNull( "subtract_checked", &sub_checked_doc); + AddDecimalBinaryKernels("subtract_checked", &subtract_checked); DCHECK_OK(registry->AddFunction(std::move(subtract_checked))); // 
---------------------------------------------------------------------- auto multiply = MakeArithmeticFunction("multiply", &mul_doc); + AddDecimalBinaryKernels("multiply", &multiply); DCHECK_OK(registry->AddFunction(std::move(multiply))); // ---------------------------------------------------------------------- auto multiply_checked = MakeArithmeticFunctionNotNull( "multiply_checked", &mul_checked_doc); + AddDecimalBinaryKernels("multiply_checked", &multiply_checked); DCHECK_OK(registry->AddFunction(std::move(multiply_checked))); // ---------------------------------------------------------------------- auto divide = MakeArithmeticFunctionNotNull("divide", &div_doc); + AddDecimalBinaryKernels("divide", &divide); DCHECK_OK(registry->AddFunction(std::move(divide))); // ---------------------------------------------------------------------- auto divide_checked = MakeArithmeticFunctionNotNull("divide_checked", &div_checked_doc); + AddDecimalBinaryKernels("divide_checked", &divide_checked); DCHECK_OK(registry->AddFunction(std::move(divide_checked))); // ----------------------------------------------------------------------