Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "arrow/compute/kernels/codegen_internal.h"

#include <cmath>
#include <functional>
#include <memory>
#include <mutex>
Expand Down Expand Up @@ -341,6 +342,91 @@ std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs) {
return large_binary();
}

Status CastBinaryDecimalArgs(DecimalPromotion promotion,
std::vector<ValueDescr>* descrs) {
auto& left_type = (*descrs)[0].type;
auto& right_type = (*descrs)[1].type;
DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));

// decimal + float = float
if (is_floating(left_type->id())) {
right_type = left_type;
return Status::OK();
} else if (is_floating(right_type->id())) {
left_type = right_type;
return Status::OK();
}

// precision, scale of left and right args
int32_t p1, s1, p2, s2;

// decimal + integer = decimal
if (is_decimal(left_type->id())) {
auto decimal = checked_cast<const DecimalType*>(left_type.get());
p1 = decimal->precision();
s1 = decimal->scale();
} else {
DCHECK(is_integer(left_type->id()));
ARROW_ASSIGN_OR_RAISE(p1, MaxDecimalDigitsForInteger(left_type->id()));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this came up in a different PR or discussion somewhere, but this makes it so that decimal-integer addition works 'properly' now if the decimal precision wasn't big enough to hold the integer (now it will promote both to a larger decimal type).

s1 = 0;
}
if (is_decimal(right_type->id())) {
auto decimal = checked_cast<const DecimalType*>(right_type.get());
p2 = decimal->precision();
s2 = decimal->scale();
} else {
DCHECK(is_integer(right_type->id()));
ARROW_ASSIGN_OR_RAISE(p2, MaxDecimalDigitsForInteger(right_type->id()));
s2 = 0;
}
if (s1 < 0 || s2 < 0) {
return Status::NotImplemented("Decimals with negative scales not supported");
}

// decimal128 + decimal256 = decimal256
Type::type casted_type_id = Type::DECIMAL128;
if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
casted_type_id = Type::DECIMAL256;
}

// decimal promotion rules compatible with amazon redshift
// https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
int32_t left_scaleup, right_scaleup;

switch (promotion) {
case DecimalPromotion::kAdd: {
left_scaleup = std::max(s1, s2) - s1;
right_scaleup = std::max(s1, s2) - s2;
break;
}
case DecimalPromotion::kMultiply: {
left_scaleup = right_scaleup = 0;
break;
}
case DecimalPromotion::kDivide: {
left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
right_scaleup = 0;
break;
}
default:
DCHECK(false) << "Invalid DecimalPromotion value " << static_cast<int>(promotion);
}
ARROW_ASSIGN_OR_RAISE(
left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
s2 + right_scaleup));
return Status::OK();
}

bool HasDecimal(const std::vector<ValueDescr>& descrs) {
for (const auto& descr : descrs) {
if (is_decimal(descr.type->id())) {
return true;
}
}
return false;
}

} // namespace internal
} // namespace compute
} // namespace arrow
13 changes: 13 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,19 @@ std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& descrs)
ARROW_EXPORT
std::shared_ptr<DataType> CommonBinary(const std::vector<ValueDescr>& descrs);

/// How to promote decimal precision/scale in CastBinaryDecimalArgs.
enum class DecimalPromotion : uint8_t {
kAdd,
kMultiply,
kDivide,
};

ARROW_EXPORT
Status CastBinaryDecimalArgs(DecimalPromotion promotion, std::vector<ValueDescr>* descrs);

ARROW_EXPORT
bool HasDecimal(const std::vector<ValueDescr>& descrs);

} // namespace internal
} // namespace compute
} // namespace arrow
94 changes: 13 additions & 81 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -965,78 +965,6 @@ ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) {
}
}

Status CastBinaryDecimalArgs(const std::string& func_name,
std::vector<ValueDescr>* values) {
auto& left_type = (*values)[0].type;
auto& right_type = (*values)[1].type;
DCHECK(is_decimal(left_type->id()) || is_decimal(right_type->id()));

// decimal + float = float
if (is_floating(left_type->id())) {
right_type = left_type;
return Status::OK();
} else if (is_floating(right_type->id())) {
left_type = right_type;
return Status::OK();
}

// precision, scale of left and right args
int32_t p1, s1, p2, s2;

// decimal + integer = decimal
if (is_decimal(left_type->id())) {
auto decimal = checked_cast<const DecimalType*>(left_type.get());
p1 = decimal->precision();
s1 = decimal->scale();
} else {
DCHECK(is_integer(left_type->id()));
p1 = static_cast<int32_t>(std::ceil(std::log10(bit_width(left_type->id()))));
s1 = 0;
}
if (is_decimal(right_type->id())) {
auto decimal = checked_cast<const DecimalType*>(right_type.get());
p2 = decimal->precision();
s2 = decimal->scale();
} else {
DCHECK(is_integer(right_type->id()));
p2 = static_cast<int32_t>(std::ceil(std::log10(bit_width(right_type->id()))));
s2 = 0;
}
if (s1 < 0 || s2 < 0) {
return Status::NotImplemented("Decimals with negative scales not supported");
}

// decimal128 + decimal256 = decimal256
Type::type casted_type_id = Type::DECIMAL128;
if (left_type->id() == Type::DECIMAL256 || right_type->id() == Type::DECIMAL256) {
casted_type_id = Type::DECIMAL256;
}

// decimal promotion rules compatible with amazon redshift
// https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
int32_t left_scaleup, right_scaleup;

// "add_checked" -> "add"
const std::string op = func_name.substr(0, func_name.find("_"));
if (op == "add" || op == "subtract") {
left_scaleup = std::max(s1, s2) - s1;
right_scaleup = std::max(s1, s2) - s2;
} else if (op == "multiply") {
left_scaleup = right_scaleup = 0;
} else if (op == "divide") {
left_scaleup = std::max(4, s1 + p2 - s2 + 1) + s2 - s1;
right_scaleup = 0;
} else {
return Status::Invalid("Invalid decimal function: ", func_name);
}

ARROW_ASSIGN_OR_RAISE(
left_type, DecimalType::Make(casted_type_id, p1 + left_scaleup, s1 + left_scaleup));
ARROW_ASSIGN_OR_RAISE(right_type, DecimalType::Make(casted_type_id, p2 + right_scaleup,
s2 + right_scaleup));
return Status::OK();
}

// resolve decimal binary operation output type per *casted* args
template <typename OutputGetter>
Result<ValueDescr> ResolveDecimalBinaryOperationOutput(
Expand Down Expand Up @@ -1166,17 +1094,21 @@ struct ArithmeticFunction : ScalarFunction {
}

Status CheckDecimals(std::vector<ValueDescr>* values) const {
bool has_decimal = false;
for (const auto& value : *values) {
if (is_decimal(value.type->id())) {
has_decimal = true;
break;
}
}
if (!has_decimal) return Status::OK();
if (!HasDecimal(*values)) return Status::OK();

if (values->size() == 2) {
return CastBinaryDecimalArgs(name(), values);
// "add_checked" -> "add"
const auto func_name = name();
const std::string op = func_name.substr(0, func_name.find("_"));
if (op == "add" || op == "subtract") {
return CastBinaryDecimalArgs(DecimalPromotion::kAdd, values);
} else if (op == "multiply") {
return CastBinaryDecimalArgs(DecimalPromotion::kMultiply, values);
} else if (op == "divide") {
return CastBinaryDecimalArgs(DecimalPromotion::kDivide, values);
} else {
return Status::Invalid("Invalid decimal function: ", func_name);
}
}
return Status::OK();
}
Expand Down
7 changes: 2 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,8 @@ struct CastFunctor<O, I,
if (out_scale < 0) {
return Status::Invalid("Scale must be non-negative");
}
// maximal number of decimal digits for int8/16/32/64
constexpr std::array<int, 4> decimal_digits{3, 5, 10, 19};
using ctype = typename I::c_type;
static_assert(sizeof(ctype) <= 8, "");
const int precision = decimal_digits[BitUtil::Log2(sizeof(ctype))] + out_scale;
ARROW_ASSIGN_OR_RAISE(int32_t precision, MaxDecimalDigitsForInteger(I::type_id));
precision += out_scale;
if (out_precision < precision) {
return Status::Invalid(
"Precision is not great enough for the result. "
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ struct CompareFunction : ScalarFunction {
ReplaceTypes(type, values);
} else if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
} else if (HasDecimal(*values)) {
RETURN_NOT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, values));
}

if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
Expand Down Expand Up @@ -259,6 +261,12 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name,
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
}

for (const auto id : DecimalTypeIds()) {
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
}

return func;
}

Expand Down
114 changes: 114 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,109 @@ TEST(TestCompareTimestamps, Basics) {
CheckArrayCase(seconds_utc, CompareOperator::EQUAL, "[false, false, true]");
}

template <typename ArrowType>
class TestCompareDecimal : public ::testing::Test {};
TYPED_TEST_SUITE(TestCompareDecimal, DecimalArrowTypes);

TYPED_TEST(TestCompareDecimal, ArrayScalar) {
auto ty = std::make_shared<TypeParam>(3, 2);

std::vector<std::pair<std::string, std::string>> cases = {
std::make_pair("equal", "[1, 0, 0, null]"),
std::make_pair("not_equal", "[0, 1, 1, null]"),
std::make_pair("less", "[0, 0, 1, null]"),
std::make_pair("less_equal", "[1, 0, 1, null]"),
std::make_pair("greater", "[0, 1, 0, null]"),
std::make_pair("greater_equal", "[1, 1, 0, null]"),
};

auto lhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])");
auto lhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]");
auto lhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])");
auto rhs = ScalarFromJSON(ty, R"("1.23")");
auto rhs_float = ScalarFromJSON(float64(), "1.23");
auto rhs_int = ScalarFromJSON(int64(), "1");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected));
}
}

TYPED_TEST(TestCompareDecimal, ScalarArray) {
auto ty = std::make_shared<TypeParam>(3, 2);

std::vector<std::pair<std::string, std::string>> cases = {
std::make_pair("equal", "[1, 0, 0, null]"),
std::make_pair("not_equal", "[0, 1, 1, null]"),
std::make_pair("less", "[0, 1, 0, null]"),
std::make_pair("less_equal", "[1, 1, 0, null]"),
std::make_pair("greater", "[0, 0, 1, null]"),
std::make_pair("greater_equal", "[1, 0, 1, null]"),
};

auto lhs = ScalarFromJSON(ty, R"("1.23")");
auto lhs_float = ScalarFromJSON(float64(), "1.23");
auto lhs_int = ScalarFromJSON(int64(), "1");
auto rhs = ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])");
auto rhs_float = ArrayFromJSON(float64(), "[1.23, 2.34, -1.23, null]");
auto rhs_intlike = ArrayFromJSON(ty, R"(["1.00", "2.00", "-1.00", null])");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_int, rhs_intlike, ArrayFromJSON(boolean(), expected));
}
}

TYPED_TEST(TestCompareDecimal, ArrayArray) {
auto ty = std::make_shared<TypeParam>(3, 2);

std::vector<std::pair<std::string, std::string>> cases = {
std::make_pair("equal", "[1, 0, 0, 1, 0, 0, null, null]"),
std::make_pair("not_equal", "[0, 1, 1, 0, 1, 1, null, null]"),
std::make_pair("less", "[0, 1, 0, 0, 1, 0, null, null]"),
std::make_pair("less_equal", "[1, 1, 0, 1, 1, 0, null, null]"),
std::make_pair("greater", "[0, 0, 1, 0, 0, 1, null, null]"),
std::make_pair("greater_equal", "[1, 0, 1, 1, 0, 1, null, null]"),
};

auto lhs = ArrayFromJSON(
ty, R"(["1.23", "1.23", "2.34", "-1.23", "-1.23", "1.23", "1.23", null])");
auto lhs_float =
ArrayFromJSON(float64(), "[1.23, 1.23, 2.34, -1.23, -1.23, 1.23, 1.23, null]");
auto lhs_intlike = ArrayFromJSON(
ty, R"(["1.00", "1.00", "2.00", "-1.00", "-1.00", "1.00", "1.00", null])");
auto rhs = ArrayFromJSON(
ty, R"(["1.23", "2.34", "1.23", "-1.23", "1.23", "-1.23", null, "1.23"])");
auto rhs_float =
ArrayFromJSON(float64(), "[1.23, 2.34, 1.23, -1.23, 1.23, -1.23, null, 1.23]");
auto rhs_int = ArrayFromJSON(int64(), "[1, 2, 1, -1, 1, -1, null, 1]");
for (const auto& op : cases) {
const auto& function = op.first;
const auto& expected = op.second;

SCOPED_TRACE(function);
CheckScalarBinary(function, ArrayFromJSON(ty, R"([])"), ArrayFromJSON(ty, R"([])"),
ArrayFromJSON(boolean(), "[]"));
CheckScalarBinary(function, ArrayFromJSON(ty, R"([null])"),
ArrayFromJSON(ty, R"([null])"), ArrayFromJSON(boolean(), "[null]"));
CheckScalarBinary(function, lhs, rhs, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_float, rhs, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs, rhs_float, ArrayFromJSON(boolean(), expected));
CheckScalarBinary(function, lhs_intlike, rhs_int, ArrayFromJSON(boolean(), expected));
}
}

TEST(TestCompareKernel, DispatchBest) {
for (std::string name :
{"equal", "not_equal", "less", "less_equal", "greater", "greater_equal"}) {
Expand Down Expand Up @@ -490,6 +593,17 @@ TEST(TestCompareKernel, DispatchBest) {

CheckDispatchBest(name, {utf8(), binary()}, {binary(), binary()});
CheckDispatchBest(name, {large_utf8(), binary()}, {large_binary(), large_binary()});

CheckDispatchBest(name, {decimal128(3, 2), decimal128(6, 3)},
{decimal128(4, 3), decimal128(6, 3)});
CheckDispatchBest(name, {decimal128(3, 2), decimal256(3, 2)},
{decimal256(3, 2), decimal256(3, 2)});
CheckDispatchBest(name, {decimal128(3, 2), float64()}, {float64(), float64()});
CheckDispatchBest(name, {float64(), decimal128(3, 2)}, {float64(), float64()});
CheckDispatchBest(name, {decimal128(3, 2), int64()},
{decimal128(3, 2), decimal128(3, 2)});
CheckDispatchBest(name, {int64(), decimal128(3, 2)},
{decimal128(3, 2), decimal128(3, 2)});
}
}

Expand Down
Loading