diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 0d5fa147727..f8d15952e73 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -321,6 +321,9 @@ class ARROW_EXPORT OutputType { this->resolver_ = other.resolver_; } + OutputType& operator=(const OutputType&) = default; + OutputType& operator=(OutputType&&) = default; + /// \brief Return the shape and type of the expected output value of the /// kernel given the value descriptors (shapes and types) of the input /// arguments. The resolver may make use of state information kept in the diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 891f90a97d4..6a5cee124c0 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -149,6 +149,8 @@ struct GetViewType { static T LogicalValue(PhysicalType value) { return Decimal128(reinterpret_cast(value.data())); } + + static T LogicalValue(T value) { return value; } }; template <> @@ -159,6 +161,8 @@ struct GetViewType { static T LogicalValue(PhysicalType value) { return Decimal256(reinterpret_cast(value.data())); } + + static T LogicalValue(T value) { return value; } }; template @@ -243,6 +247,18 @@ struct ArrayIterator> { } }; +template +struct ArrayIterator> { + using T = typename TypeTraits::ScalarType::ValueType; + using endian_agnostic = std::array; + const endian_agnostic* values; + + explicit ArrayIterator(const ArrayData& data) + : values(data.GetValues(1)) {} + + T operator()() { return T{values++->data()}; } +}; + // Iterator over various output array types, taking a GetOutputType template @@ -262,6 +278,20 @@ struct OutputArrayWriter> { void WriteNull() { *values++ = T{}; } }; +template +struct OutputArrayWriter> { + using T = typename TypeTraits::ScalarType::ValueType; + using endian_agnostic = std::array; + endian_agnostic* values; + + explicit OutputArrayWriter(ArrayData* data) + : values(data->GetMutableValues(1)) {} + + void Write(T value) { value.ToBytes(values++->data()); } + + void WriteNull() { T{}.ToBytes(values++->data()); } +}; + // (Un)box Scalar to / from C++ value template @@ -538,6 +568,22 @@ struct OutputAdapter> { } }; +template +struct OutputAdapter> { + using T = typename TypeTraits::ScalarType::ValueType; + using endian_agnostic = std::array; + + template + static Status Write(KernelContext*, Datum* out, Generator&& generator) { + ArrayData* out_arr = out->mutable_array(); + auto out_data = out_arr->GetMutableValues(1); + for (int64_t i = 0; i < out_arr->length; ++i) { + generator().ToBytes(out_data++->data()); + } + return Status::OK(); + } +}; + // A kernel exec generator for unary functions that addresses both array and // scalar inputs and dispatches input iteration and output writing to other // templates diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 743d2e3fc0e..f51484e53ff 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +#include #include #include +#include #include "arrow/compute/kernels/common.h" #include "arrow/type_traits.h" +#include "arrow/util/decimal.h" #include "arrow/util/int_util_internal.h" #include "arrow/util/macros.h" @@ -62,6 +65,11 @@ using enable_if_integer = template using enable_if_floating_point = enable_if_t::value, T>; +template +using enable_if_decimal = + enable_if_t::value || std::is_same::value, + T>; + template ::type> constexpr Unsigned to_unsigned(T signed_) { return static_cast(signed_); @@ -126,11 +134,16 @@ struct Add { Status*) { return arrow::internal::SafeSignedAdd(left, right); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + right; + } }; struct AddChecked { template - enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + static enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(AddWithOverflow(left, right, &result))) { @@ -140,10 +153,16 @@ struct AddChecked { } template - enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { static_assert(std::is_same::value && std::is_same::value, ""); return left + right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + right; + } }; struct Subtract { @@ -164,11 +183,16 @@ struct Subtract { Status*) { return arrow::internal::SafeSignedSubtract(left, right); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + (-right); + } }; struct SubtractChecked { template - enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + static enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, right, &result))) { @@ -178,10 +202,16 @@ struct SubtractChecked { } template - enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { static_assert(std::is_same::value && std::is_same::value, ""); return left - right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left + (-right); + } }; struct Multiply { @@ -224,11 +254,16 @@ struct Multiply { static constexpr uint16_t Call(KernelContext*, uint16_t left, uint16_t right, Status*) { return static_cast(left) * static_cast(right); } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left * right; + } }; struct MultiplyChecked { template - enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + static enable_if_integer Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { static_assert(std::is_same::value && std::is_same::value, ""); T result = 0; if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, right, &result))) { @@ -238,10 +273,16 @@ struct MultiplyChecked { } template - enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, + Status*) { static_assert(std::is_same::value && std::is_same::value, ""); return left * right; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + return left * right; + } }; struct Divide { @@ -263,6 +304,16 @@ struct Divide { } return result; } + + template + static enable_if_decimal Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + if (right == Arg1()) { + *st = Status::Invalid("Divide by zero"); + return T(); + } else { + return left / right; + } + } }; struct DivideChecked { @@ -290,6 +341,12 @@ struct DivideChecked { } return left / right; } + + template + static enable_if_decimal Call(KernelContext* ctx, Arg0 left, Arg1 right, + Status* st) { + return Divide::Call(ctx, left, right, st); + } }; struct Negate { @@ -304,7 +361,7 @@ struct Negate { } template - static constexpr enable_if_signed_integer Call(KernelContext*, Arg arg, Status* st) { + static constexpr enable_if_signed_integer Call(KernelContext*, Arg arg, Status*) { return arrow::internal::SafeSignedNegate(arg); } }; @@ -428,12 +485,157 @@ ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) { } } +Status CastBinaryDecimalArgs(const std::string& func_name, + std::vector* values) { + auto& left_type = (*values)[0].type; + auto& right_type = (*values)[1].type; + DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id())); + + // decimal + float = float + if (is_floating(left_type->id())) { + right_type = left_type; + return Status::OK(); + } else if (is_floating(right_type->id())) { + left_type = right_type; + return Status::OK(); + } + + // precision, scale of left and right args + int32_t p1, s1, p2, s2; + + // decimal + integer = decimal + if (is_decimal(left_type->id())) { + auto decimal = checked_cast(left_type.get()); + p1 = decimal->precision(); + s1 = decimal->scale(); + } else { + DCHECK(is_integer(left_type->id())); + p1 = static_cast(std::ceil(std::log10(bit_width(left_type->id())))); + s1 = 0; + } + if (is_decimal(right_type->id())) { + auto decimal = checked_cast(right_type.get()); + p2 = decimal->precision(); + s2 = decimal->scale(); + } else { + DCHECK(is_integer(right_type->id())); + p2 = static_cast(std::ceil(std::log10(bit_width(right_type->id())))); + s2 = 0; + } + if (s1 < 0 || s2 < 0) { + return Status::NotImplemented("Decimals with negative scales not supported"); + } + + // decimal128 + decimal256 = decimal256 + Type::type casted_type_id = Type::DECIMAL128; + if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) { + casted_type_id = Type::DECIMAL256; + } + + // decimal promotion rules compatible with amazon redshift + // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html + int32_t left_scaleup, right_scaleup; + + // "add_checked" -> "add" + const std::string op = func_name.substr(0, func_name.find("_")); + if (op == "add" || op == "subtract") { + left_scaleup = std::max(s1, s2) - s1; + right_scaleup = std::max(s1, s2) - s2; + } else if (op == "multiply") { + left_scaleup = right_scaleup = 0; + } else if (op == "divide") { + left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1; + right_scaleup = 0; + } else { + return Status::Invalid("Invalid decimal function: ", func_name); + } + + ARROW_ASSIGN_OR_RAISE( + left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup)); + ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup, + s2 + right_scaleup)); + return Status::OK(); +} + +// resolve decimal binary operation output type per *casted* args +template +Result ResolveDecimalBinaryOperationOutput( + const std::vector& args, OutputGetter&& getter) { + // casted args should be same size decimals + auto left_type = checked_cast(args[0].type.get()); + auto right_type = checked_cast(args[1].type.get()); + DCHECK_EQ(left_type->id(), right_type->id()); + + int32_t precision, scale; + std::tie(precision, scale) = getter(left_type->precision(), left_type->scale(), + right_type->precision(), right_type->scale()); + ARROW_ASSIGN_OR_RAISE(auto type, DecimalType::Make(left_type->id(), precision, scale)); + return ValueDescr(std::move(type), GetBroadcastShape(args)); +} + +Result ResolveDecimalAdditionOrSubtractionOutput( + KernelContext*, const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + DCHECK_EQ(s1, s2); + const int32_t scale = s1; + const int32_t precision = std::max(p1 - s1, p2 - s2) + scale + 1; + return std::make_pair(precision, scale); + }); +} + +Result ResolveDecimalMultiplicationOutput( + KernelContext*, const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + const int32_t scale = s1 + s2; + const int32_t precision = p1 + p2 + 1; + return std::make_pair(precision, scale); + }); +} + +Result ResolveDecimalDivisionOutput(KernelContext*, + const std::vector& args) { + return ResolveDecimalBinaryOperationOutput( + args, [](int32_t p1, int32_t s1, int32_t p2, int32_t s2) { + DCHECK_GE(s1, s2); + const int32_t scale = s1 - s2; + const int32_t precision = p1; + return std::make_pair(precision, scale); + }); +} + +template +void AddDecimalBinaryKernels(const std::string& name, + std::shared_ptr* func) { + OutputType out_type(null()); + const std::string op = name.substr(0, name.find("_")); + if (op == "add" || op == "subtract") { + out_type = OutputType(ResolveDecimalAdditionOrSubtractionOutput); + } else if (op == "multiply") { + out_type = OutputType(ResolveDecimalMultiplicationOutput); + } else if (op == "divide") { + out_type = OutputType(ResolveDecimalDivisionOutput); + } else { + DCHECK(false); + } + + auto in_type128 = InputType(Type::DECIMAL128); + auto in_type256 = InputType(Type::DECIMAL256); + auto exec128 = ScalarBinaryNotNullEqualTypes::Exec; + auto exec256 = ScalarBinaryNotNullEqualTypes::Exec; + DCHECK_OK((*func)->AddKernel({in_type128, in_type128}, out_type, exec128)); + DCHECK_OK((*func)->AddKernel({in_type256, in_type256}, out_type, exec256)); +} + struct ArithmeticFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { RETURN_NOT_OK(CheckArity(*values)); + RETURN_NOT_OK(CheckDecimals(values)); + using arrow::compute::detail::DispatchExactImpl; if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -451,6 +653,22 @@ struct ArithmeticFunction : ScalarFunction { if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); } + + Status CheckDecimals(std::vector* values) const { + bool has_decimal = false; + for (const auto& value : *values) { + if (is_decimal(value.type->id())) { + has_decimal = true; + break; + } + } + if (!has_decimal) return Status::OK(); + + if (values->size() == 2) { + return CastBinaryDecimalArgs(name(), values); + } + return Status::OK(); + } }; template @@ -617,16 +835,19 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto add = MakeArithmeticFunction("add", &add_doc); + AddDecimalBinaryKernels("add", &add); DCHECK_OK(registry->AddFunction(std::move(add))); // ---------------------------------------------------------------------- auto add_checked = MakeArithmeticFunctionNotNull("add_checked", &add_checked_doc); + AddDecimalBinaryKernels("add_checked", &add_checked); DCHECK_OK(registry->AddFunction(std::move(add_checked))); // ---------------------------------------------------------------------- // subtract auto subtract = MakeArithmeticFunction("subtract", &sub_doc); + AddDecimalBinaryKernels("subtract", &subtract); // Add subtract(timestamp, timestamp) -> duration for (auto unit : AllTimeUnits()) { @@ -640,24 +861,29 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // ---------------------------------------------------------------------- auto subtract_checked = MakeArithmeticFunctionNotNull( "subtract_checked", &sub_checked_doc); + AddDecimalBinaryKernels("subtract_checked", &subtract_checked); DCHECK_OK(registry->AddFunction(std::move(subtract_checked))); // ---------------------------------------------------------------------- auto multiply = MakeArithmeticFunction("multiply", &mul_doc); + AddDecimalBinaryKernels("multiply", &multiply); DCHECK_OK(registry->AddFunction(std::move(multiply))); // ---------------------------------------------------------------------- auto multiply_checked = MakeArithmeticFunctionNotNull( "multiply_checked", &mul_checked_doc); + AddDecimalBinaryKernels("multiply_checked", &multiply_checked); DCHECK_OK(registry->AddFunction(std::move(multiply_checked))); // ---------------------------------------------------------------------- auto divide = MakeArithmeticFunctionNotNull("divide", &div_doc); + AddDecimalBinaryKernels("divide", ÷); DCHECK_OK(registry->AddFunction(std::move(divide))); // ---------------------------------------------------------------------- auto divide_checked = MakeArithmeticFunctionNotNull("divide_checked", &div_checked_doc); + AddDecimalBinaryKernels("divide_checked", ÷_checked); DCHECK_OK(registry->AddFunction(std::move(divide_checked))); // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index c4bfac459dc..3ee862c834e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -1148,5 +1148,368 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) { } } +TEST(TestBinaryDecimalArithmetic, DispatchBest) { + // decimal, floating point + for (std::string name : {"add", "subtract", "multiply", "divide"}) { + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(1, 0), float32()}, {float32(), float32()}); + CheckDispatchBest(name, {decimal256(1, 0), float64()}, {float64(), float64()}); + CheckDispatchBest(name, {float32(), decimal256(1, 0)}, {float32(), float32()}); + CheckDispatchBest(name, {float64(), decimal128(1, 0)}, {float64(), float64()}); + } + } + + // decimal, decimal + for (std::string name : {"add", "subtract"}) { + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(3, 1), decimal128(3, 1)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(3, 1), decimal256(3, 1)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, + {decimal256(3, 1), decimal256(3, 1)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(3, 1), decimal256(3, 1)}); + } + } + { + std::string name = "multiply"; + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(5, 2), decimal128(5, 2)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(5, 2), decimal256(5, 2)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, + {decimal256(5, 2), decimal256(5, 2)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(5, 2), decimal256(5, 2)}); + } + } + { + std::string name = "divide"; + for (std::string suffix : {"", "_checked"}) { + name += suffix; + + CheckDispatchBest(name, {decimal128(2, 1), decimal128(2, 1)}, + {decimal128(6, 4), decimal128(6, 4)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal256(2, 1)}, + {decimal256(6, 4), decimal256(6, 4)}); + CheckDispatchBest(name, {decimal128(2, 1), decimal256(2, 1)}, + {decimal256(6, 4), decimal256(6, 4)}); + CheckDispatchBest(name, {decimal256(2, 1), decimal128(2, 1)}, + {decimal256(6, 4), decimal256(6, 4)}); + } + } + + // TODO(ARROW-13067): add 'integer, decimal' tests +} + +// reference result from bc (precsion=100, scale=40) +TEST(TestBinaryArithmeticDecimal, AddSubtract) { + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(30, 3), + R"([ + "1.000", + "-123456789012345678901234567.890", + "98765432109876543210.987", + "-999999999999999999999999999.999" + ])"); + auto right = ArrayFromJSON(decimal128(20, 9), + R"([ + "-1.000000000", + "12345678901.234567890", + "98765.432101234", + "-99999999999.999999999" + ])"); + auto added = ArrayFromJSON(decimal128(37, 9), + R"([ + "0.000000000", + "-123456789012345666555555666.655432110", + "98765432109876641976.419101234", + "-1000000000000000099999999999.998999999" + ])"); + auto subtracted = ArrayFromJSON(decimal128(37, 9), + R"([ + "2.000000000", + "-123456789012345691246913469.124567890", + "98765432109876444445.554898766", + "-999999999999999899999999999.999000001" + ])"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("subtract", left, right, subtracted); + } + + // array array, decimal256 + { + auto left = ArrayFromJSON(decimal256(30, 20), + R"([ + "-1.00000000000000000001", + "1234567890.12345678900000000000", + "-9876543210.09876543210987654321", + "9999999999.99999999999999999999" + ])"); + auto right = ArrayFromJSON(decimal256(30, 10), + R"([ + "1.0000000000", + "-1234567890.1234567890", + "6789.5432101234", + "99999999999999999999.9999999999" + ])"); + auto added = ArrayFromJSON(decimal256(41, 20), + R"([ + "-0.00000000000000000001", + "0.00000000000000000000", + "-9876536420.55555530870987654321", + "100000000009999999999.99999999989999999999" + ])"); + auto subtracted = ArrayFromJSON(decimal256(41, 20), + R"([ + "-2.00000000000000000001", + "2469135780.24691357800000000000", + "-9876549999.64197555550987654321", + "-99999999989999999999.99999999990000000001" + ])"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("subtract", left, right, subtracted); + } + + // scalar array + { + auto left = ScalarFromJSON(decimal128(6, 1), R"("12345.6")"); + auto right = ArrayFromJSON(decimal128(10, 3), + R"(["1.234", "1234.000", "-9876.543", "666.888"])"); + auto added = ArrayFromJSON(decimal128(11, 3), + R"(["12346.834", "13579.600", "2469.057", "13012.488"])"); + auto left_sub_right = ArrayFromJSON( + decimal128(11, 3), R"(["12344.366", "11111.600", "22222.143", "11678.712"])"); + auto right_sub_left = ArrayFromJSON( + decimal128(11, 3), R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("add", right, left, added); + CheckScalarBinary("subtract", left, right, left_sub_right); + CheckScalarBinary("subtract", right, left, right_sub_left); + } + + // scalar scalar + { + auto left = ScalarFromJSON(decimal256(3, 0), R"("666")"); + auto right = ScalarFromJSON(decimal256(3, 0), R"("888")"); + auto added = ScalarFromJSON(decimal256(4, 0), R"("1554")"); + auto subtracted = ScalarFromJSON(decimal256(4, 0), R"("-222")"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("subtract", left, right, subtracted); + } + + // decimal128 decimal256 + { + auto left = ScalarFromJSON(decimal128(3, 0), R"("666")"); + auto right = ScalarFromJSON(decimal256(3, 0), R"("888")"); + auto added = ScalarFromJSON(decimal256(4, 0), R"("1554")"); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("add", right, left, added); + } + + // decimal float + { + auto left = ScalarFromJSON(decimal128(3, 0), R"("666")"); + ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888)); + ASSIGN_OR_ABORT(auto added, arrow::MakeScalar(float64(), 1554)); + CheckScalarBinary("add", left, right, added); + CheckScalarBinary("add", right, left, added); + } + + // TODO: decimal integer + + // failed case: result maybe overflow + { + std::shared_ptr left, right; + + left = ScalarFromJSON(decimal128(21, 20), R"("0.12345678901234567890")"); + right = ScalarFromJSON(decimal128(21, 1), R"("1.0")"); + ASSERT_RAISES(Invalid, CallFunction("add", {left, right})); + ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right})); + + left = ScalarFromJSON(decimal256(75, 0), R"("0")"); + right = ScalarFromJSON(decimal256(2, 1), R"("0.0")"); + ASSERT_RAISES(Invalid, CallFunction("add", {left, right})); + ASSERT_RAISES(Invalid, CallFunction("subtract", {left, right})); + } +} + +TEST(TestBinaryArithmeticDecimal, Multiply) { + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(20, 10), + R"([ + "1234567890.1234567890", + "-0.0000000001", + "-9999999999.9999999999" + ])"); + auto right = ArrayFromJSON(decimal128(13, 3), + R"([ + "1234567890.123", + "0.001", + "-9999999999.999" + ])"); + auto expected = ArrayFromJSON(decimal128(34, 13), + R"([ + "1524157875323319737.9870903950470", + "-0.0000000000001", + "99999999999989999999.0000000000001" + ])"); + CheckScalarBinary("multiply", left, right, expected); + } + + // array array, decimal26 + { + auto left = ArrayFromJSON(decimal256(30, 3), + R"([ + "123456789012345678901234567.890", + "0.000" + ])"); + auto right = ArrayFromJSON(decimal256(20, 9), + R"([ + "-12345678901.234567890", + "99999999999.999999999" + ])"); + auto expected = ArrayFromJSON(decimal256(51, 12), + R"([ + "-1524157875323883675034293577501905199.875019052100", + "0.000000000000" + ])"); + CheckScalarBinary("multiply", left, right, expected); + } + + // scalar array + { + auto left = ScalarFromJSON(decimal128(3, 2), R"("3.14")"); + auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"); + auto expected = + ArrayFromJSON(decimal128(5, 2), R"(["3.14", "6.28", "9.42", "12.56", "15.70"])"); + CheckScalarBinary("multiply", left, right, expected); + CheckScalarBinary("multiply", right, left, expected); + } + + // scalar scalar + { + auto left = ScalarFromJSON(decimal128(1, 0), R"("1")"); + auto right = ScalarFromJSON(decimal128(1, 0), R"("1")"); + auto expected = ScalarFromJSON(decimal128(3, 0), R"("1")"); + CheckScalarBinary("multiply", left, right, expected); + } + + // decimal128 decimal256 + { + auto left = ScalarFromJSON(decimal128(3, 2), R"("6.66")"); + auto right = ScalarFromJSON(decimal256(3, 1), R"("88.8")"); + auto expected = ScalarFromJSON(decimal256(7, 3), R"("591.408")"); + CheckScalarBinary("multiply", left, right, expected); + CheckScalarBinary("multiply", right, left, expected); + } + + // decimal float + { + auto left = ScalarFromJSON(decimal128(3, 0), R"("666")"); + ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 888)); + ASSIGN_OR_ABORT(auto expected, arrow::MakeScalar(float64(), 591408)); + CheckScalarBinary("multiply", left, right, expected); + CheckScalarBinary("multiply", right, left, expected); + } + + // TODO: decimal integer + + // failed case: result maybe overflow + { + auto left = ScalarFromJSON(decimal128(20, 0), R"("1")"); + auto right = ScalarFromJSON(decimal128(18, 1), R"("1.0")"); + ASSERT_RAISES(Invalid, CallFunction("multiply", {left, right})); + } +} + +TEST(TestBinaryArithmeticDecimal, Divide) { + // array array, decimal128 + { + auto left = ArrayFromJSON(decimal128(13, 3), R"(["1234567890.123", "0.001"])"); + auto right = ArrayFromJSON(decimal128(3, 0), R"(["-987", "999"])"); + auto expected = + ArrayFromJSON(decimal128(17, 7), R"(["-1250828.6627386", "0.0000010"])"); + CheckScalarBinary("divide", left, right, expected); + } + + // array array, decimal256 + { + auto left = ArrayFromJSON(decimal256(20, 10), + R"(["1234567890.1234567890", "9999999999.9999999999"])"); + auto right = ArrayFromJSON(decimal256(13, 3), R"(["1234567890.123", "0.001"])"); + auto expected = ArrayFromJSON( + decimal256(34, 21), + R"(["1.000000000000369999093", "9999999999999.999999900000000000000"])"); + CheckScalarBinary("divide", left, right, expected); + } + + // scalar array + { + auto left = ScalarFromJSON(decimal128(1, 0), R"("1")"); + auto right = ArrayFromJSON(decimal128(1, 0), R"(["1", "2", "3", "4"])"); + auto left_div_right = + ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "0.5000", "0.3333", "0.2500"])"); + auto right_div_left = + ArrayFromJSON(decimal128(5, 4), R"(["1.0000", "2.0000", "3.0000", "4.0000"])"); + CheckScalarBinary("divide", left, right, left_div_right); + CheckScalarBinary("divide", right, left, right_div_left); + } + + // scalar scalar + { + auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")"); + auto right = ScalarFromJSON(decimal256(6, 5), R"("3.14159")"); + auto expected = ScalarFromJSON(decimal256(13, 7), R"("0.8652561")"); + CheckScalarBinary("divide", left, right, expected); + } + + // decimal128 decimal256 + { + auto left = ScalarFromJSON(decimal256(6, 5), R"("2.71828")"); + auto right = ScalarFromJSON(decimal128(6, 5), R"("3.14159")"); + auto left_div_right = ScalarFromJSON(decimal256(13, 7), R"("0.8652561")"); + auto right_div_left = ScalarFromJSON(decimal256(13, 7), R"("1.1557271")"); + CheckScalarBinary("divide", left, right, left_div_right); + CheckScalarBinary("divide", right, left, right_div_left); + } + + // decimal float + { + auto left = ScalarFromJSON(decimal128(3, 0), R"("100")"); + ASSIGN_OR_ABORT(auto right, arrow::MakeScalar(float64(), 50)); + ASSIGN_OR_ABORT(auto left_div_right, arrow::MakeScalar(float64(), 2)); + ASSIGN_OR_ABORT(auto right_div_left, arrow::MakeScalar(float64(), 0.5)); + CheckScalarBinary("divide", left, right, left_div_right); + CheckScalarBinary("divide", right, left, right_div_left); + } + + // TODO: decimal integer + + // failed case: result maybe overflow + { + auto left = ScalarFromJSON(decimal128(20, 20), R"("0.12345678901234567890")"); + auto right = ScalarFromJSON(decimal128(20, 0), R"("12345678901234567890")"); + ASSERT_RAISES(Invalid, CallFunction("divide", {left, right})); + } + + // failed case: divide by 0 + { + auto left = ScalarFromJSON(decimal256(1, 0), R"("1")"); + auto right = ScalarFromJSON(decimal256(1, 0), R"("0")"); + ASSERT_RAISES(Invalid, CallFunction("divide", {left, right})); + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 344585446fc..65c783ce847 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -771,6 +771,17 @@ std::vector> StructType::GetAllFieldsByName( return result; } +Result> DecimalType::Make(Type::type type_id, int32_t precision, + int32_t scale) { + if (type_id == Type::DECIMAL128) { + return Decimal128Type::Make(precision, scale); + } else if (type_id == Type::DECIMAL256) { + return Decimal256Type::Make(precision, scale); + } else { + return Status::Invalid("Not a decimal type_id: ", type_id); + } +} + // Taken from the Apache Impala codebase. The comments next // to the return values are the maximum value that can be represented in 2's // complement with the returned number of bytes. diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 1d3d1e27f92..b933da66089 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -880,6 +880,10 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType { int32_t scale) : FixedSizeBinaryType(byte_width, type_id), precision_(precision), scale_(scale) {} + /// Constructs concrete decimal types + static Result> Make(Type::type type_id, int32_t precision, + int32_t scale); + int32_t precision() const { return precision_; } int32_t scale() const { return scale_; } diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index b74aa3b0adb..86664bbb162 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -845,6 +845,17 @@ static inline bool is_floating(Type::type type_id) { return false; } +static inline bool is_decimal(Type::type type_id) { + switch (type_id) { + case Type::DECIMAL128: + case Type::DECIMAL256: + return true; + default: + break; + } + return false; +} + static inline bool is_primitive(Type::type type_id) { switch (type_id) { case Type::BOOL: diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index dfdd64d19c6..147885560f5 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -266,17 +266,17 @@ an ``Invalid`` :class:`Status` when overflow is detected. +--------------------------+------------+--------------------+---------------------+ | abs_checked | Unary | Numeric | Numeric | +--------------------------+------------+--------------------+---------------------+ -| add | Binary | Numeric | Numeric | +| add | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| add_checked | Binary | Numeric | Numeric | +| add_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| divide | Binary | Numeric | Numeric | +| divide | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| divide_checked | Binary | Numeric | Numeric | +| divide_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| multiply | Binary | Numeric | Numeric | +| multiply | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| multiply_checked | Binary | Numeric | Numeric | +| multiply_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ | negate | Unary | Numeric | Numeric | +--------------------------+------------+--------------------+---------------------+ @@ -286,11 +286,32 @@ an ``Invalid`` :class:`Status` when overflow is detected. +--------------------------+------------+--------------------+---------------------+ | power_checked | Binary | Numeric | Numeric | +--------------------------+------------+--------------------+---------------------+ -| subtract | Binary | Numeric | Numeric | +| subtract | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ -| subtract_checked | Binary | Numeric | Numeric | +| subtract_checked | Binary | Numeric | Numeric (1) | +--------------------------+------------+--------------------+---------------------+ +* \(1) Precision and scale of computed DECIMAL results + ++------------+---------------------------------------------+ +| Operation | Result precision and scale | ++============+=============================================+ +| | add | | scale = max(s1, s2) | +| | subtract | | precision = max(p1-s1, p2-s2) + 1 + scale | ++------------+---------------------------------------------+ +| multiply | | scale = s1 + s2 | +| | | precision = p1 + p2 + 1 | ++------------+---------------------------------------------+ +| divide | | scale = max(4, s1 + p2 - s2 + 1) | +| | | precision = p1 - s1 + s2 + scale | ++------------+---------------------------------------------+ + +It's compatible with Redshift's decimal promotion rules. All decimal digits +are preserved for `add`, `subtract` and `multiply` operations. The result +precision of `divide` is at least the sum of precisions of both operands with +enough scale kept. Error is returned if the result precision is beyond the +decimal value range. + Comparisons ~~~~~~~~~~~