Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions cpp/src/arrow/compute/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,150 @@ struct CastFunctor<
}
};

// ----------------------------------------------------------------------
// Decimals

// Decimal to Integer

template <typename O>
struct CastFunctor<O, Decimal128Type, enable_if_t<is_integer_type<O>::value>> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
using out_type = typename O::c_type;
const auto& in_type_inst = checked_cast<const Decimal128Type&>(*input.type);
auto in_scale = in_type_inst.scale();

auto out_data = output->GetMutableValues<out_type>(1);

constexpr auto min_value = std::numeric_limits<out_type>::min();
constexpr auto max_value = std::numeric_limits<out_type>::max();
constexpr auto zero = out_type{};

if (options.allow_decimal_truncate) {
if (in_scale < 0) {
// Unsafe upscale
auto convert_value = [&](util::optional<util::string_view> v) {
*out_data = zero;
if (v.has_value()) {
auto dec_value = Decimal128(reinterpret_cast<const uint8_t*>(v->data()));
auto converted = dec_value.IncreaseScaleBy(-in_scale);
if (!options.allow_int_overflow &&
ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) {
ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
} else {
*out_data = static_cast<out_type>(converted.low_bits());
}
}
++out_data;
};
VisitArrayDataInline<Decimal128Type>(input, std::move(convert_value));
} else {
// Unsafe downscale
auto convert_value = [&](util::optional<util::string_view> v) {
*out_data = zero;
if (v.has_value()) {
auto dec_value = Decimal128(reinterpret_cast<const uint8_t*>(v->data()));
auto converted = dec_value.ReduceScaleBy(in_scale, false);
if (!options.allow_int_overflow &&
ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) {
ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
} else {
*out_data = static_cast<out_type>(converted.low_bits());
}
}
++out_data;
};
VisitArrayDataInline<Decimal128Type>(input, std::move(convert_value));
}
} else {
// Safe rescale
auto convert_value = [&](util::optional<util::string_view> v) {
*out_data = zero;
if (v.has_value()) {
auto dec_value = Decimal128(reinterpret_cast<const uint8_t*>(v->data()));
auto result = dec_value.Rescale(in_scale, 0);
if (ARROW_PREDICT_FALSE(!result.ok())) {
ctx->SetStatus(result.status());
} else {
auto converted = *std::move(result);
if (!options.allow_int_overflow &&
ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) {
ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
} else {
*out_data = static_cast<out_type>(converted.low_bits());
}
}
}
++out_data;
};
VisitArrayDataInline<Decimal128Type>(input, std::move(convert_value));
}
}
};

// Decimal to Decimal

template <>
struct CastFunctor<Decimal128Type, Decimal128Type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
const auto& in_type_inst = checked_cast<const Decimal128Type&>(*input.type);
const auto& out_type_inst = checked_cast<const Decimal128Type&>(*output->type);
auto in_scale = in_type_inst.scale();
auto out_scale = out_type_inst.scale();

auto out_data = output->GetMutableValues<uint8_t>(1);

const auto write_zero = [](uint8_t* out_data) { memset(out_data, 0, 16); };

if (options.allow_decimal_truncate) {
if (in_scale < out_scale) {
// Unsafe upscale
auto convert_value = [&](util::optional<util::string_view> v) {
if (v.has_value()) {
auto dec_value = Decimal128(reinterpret_cast<const uint8_t*>(v->data()));
dec_value.IncreaseScaleBy(out_scale - in_scale).ToBytes(out_data);
} else {
write_zero(out_data);
}
out_data += 16;
};
VisitArrayDataInline<Decimal128Type>(input, std::move(convert_value));
} else {
// Unsafe downscale
auto convert_value = [&](util::optional<util::string_view> v) {
if (v.has_value()) {
auto dec_value = Decimal128(reinterpret_cast<const uint8_t*>(v->data()));
dec_value.ReduceScaleBy(in_scale - out_scale, false).ToBytes(out_data);
} else {
write_zero(out_data);
}
out_data += 16;
};
VisitArrayDataInline<Decimal128Type>(input, std::move(convert_value));
}
} else {
// Safe rescale
auto convert_value = [&](util::optional<util::string_view> v) {
if (v.has_value()) {
auto dec_value = Decimal128(reinterpret_cast<const uint8_t*>(v->data()));
auto result = dec_value.Rescale(in_scale, out_scale);
if (ARROW_PREDICT_FALSE(!result.ok())) {
ctx->SetStatus(result.status());
write_zero(out_data);
} else {
(*std::move(result)).ToBytes(out_data);
}
} else {
write_zero(out_data);
}
out_data += 16;
};
VisitArrayDataInline<Decimal128Type>(input, std::move(convert_value));
}
}
};

// ----------------------------------------------------------------------
// From one timestamp to another

Expand Down Expand Up @@ -1203,6 +1347,7 @@ GET_CAST_FUNCTION(UINT64_CASES, UInt64Type, CastKernel)
GET_CAST_FUNCTION(INT64_CASES, Int64Type, CastKernel)
GET_CAST_FUNCTION(FLOAT_CASES, FloatType, CastKernel)
GET_CAST_FUNCTION(DOUBLE_CASES, DoubleType, CastKernel)
GET_CAST_FUNCTION(DECIMAL128_CASES, Decimal128Type, CastKernel)
GET_CAST_FUNCTION(DATE32_CASES, Date32Type, CastKernel)
GET_CAST_FUNCTION(DATE64_CASES, Date64Type, CastKernel)
GET_CAST_FUNCTION(TIME32_CASES, Time32Type, CastKernel)
Expand Down Expand Up @@ -1293,6 +1438,7 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
CAST_FUNCTION_CASE(Int64Type);
CAST_FUNCTION_CASE(FloatType);
CAST_FUNCTION_CASE(DoubleType);
CAST_FUNCTION_CASE(Decimal128Type);
CAST_FUNCTION_CASE(Date32Type);
CAST_FUNCTION_CASE(Date64Type);
CAST_FUNCTION_CASE(Time32Type);
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernels/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ struct ARROW_EXPORT CastOptions {
: allow_int_overflow(false),
allow_time_truncate(false),
allow_time_overflow(false),
allow_decimal_truncate(false),
allow_float_truncate(false),
allow_invalid_utf8(false) {}

explicit CastOptions(bool safe)
: allow_int_overflow(!safe),
allow_time_truncate(!safe),
allow_time_overflow(!safe),
allow_decimal_truncate(!safe),
allow_float_truncate(!safe),
allow_invalid_utf8(!safe) {}

Expand All @@ -55,6 +57,7 @@ struct ARROW_EXPORT CastOptions {
bool allow_int_overflow;
bool allow_time_truncate;
bool allow_time_overflow;
bool allow_decimal_truncate;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't add this if this isn't used anywhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it is used now

bool allow_float_truncate;
// Indicate if conversions from Binary/FixedSizeBinary to string must
// validate the utf8 payload.
Expand Down
165 changes: 165 additions & 0 deletions cpp/src/arrow/compute/kernels/cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,171 @@ TEST_F(TestCast, IntToFloatingPoint) {
}
#endif

TEST_F(TestCast, DecimalToInt) {
CastOptions options;
std::vector<bool> is_valid2 = {true, true};
std::vector<bool> is_valid3 = {true, true, false};

// no overflow no truncation
std::vector<Decimal128> v12 = {Decimal128("02.0000000000"),
Decimal128("-11.0000000000")};
std::vector<Decimal128> v13 = {Decimal128("02.0000000000"),
Decimal128("-11.0000000000"),
Decimal128("-12.0000000000")};
std::vector<int64_t> e12 = {2, -11};
std::vector<int64_t> e13 = {2, -11, 0};

for (bool allow_int_overflow : {false, true}) {
for (bool allow_decimal_truncate : {false, true}) {
options.allow_int_overflow = allow_int_overflow;
options.allow_decimal_truncate = allow_decimal_truncate;
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v12, is_valid2, int64(), e12, options);
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v13, is_valid3, int64(), e13, options);
}
}

// truncation, no overflow
std::vector<Decimal128> v22 = {Decimal128("02.1000000000"),
Decimal128("-11.0000004500")};
std::vector<Decimal128> v23 = {Decimal128("02.1000000000"),
Decimal128("-11.0000004500"),
Decimal128("-12.0000004500")};
std::vector<int64_t> e22 = {2, -11};
std::vector<int64_t> e23 = {2, -11, 0};

for (bool allow_int_overflow : {false, true}) {
options.allow_int_overflow = allow_int_overflow;
options.allow_decimal_truncate = true;
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v22, is_valid2, int64(), e22, options);
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v23, is_valid3, int64(), e23, options);
options.allow_decimal_truncate = false;
CheckFails<Decimal128Type>(decimal(38, 10), v22, is_valid2, int64(), options);
CheckFails<Decimal128Type>(decimal(38, 10), v23, is_valid3, int64(), options);
}

// overflow, no truncation
std::vector<Decimal128> v32 = {Decimal128("12345678901234567890000.0000000000"),
Decimal128("99999999999999999999999.0000000000")};
std::vector<Decimal128> v33 = {Decimal128("12345678901234567890000.0000000000"),
Decimal128("99999999999999999999999.0000000000"),
Decimal128("99999999999999999999999.0000000000")};
// 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64
std::vector<int64_t> e32 = {4807115922877858896, 200376420520689663};
std::vector<int64_t> e33 = {4807115922877858896, 200376420520689663, -2};

for (bool allow_decimal_truncate : {false, true}) {
options.allow_decimal_truncate = allow_decimal_truncate;
options.allow_int_overflow = true;
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v32, is_valid2, int64(), e32, options);
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v33, is_valid3, int64(), e33, options);
options.allow_int_overflow = false;
CheckFails<Decimal128Type>(decimal(38, 10), v32, is_valid2, int64(), options);
CheckFails<Decimal128Type>(decimal(38, 10), v33, is_valid3, int64(), options);
}

// overflow, truncation
std::vector<Decimal128> v42 = {Decimal128("12345678901234567890000.0045345000"),
Decimal128("99999999999999999999999.0000005430")};
std::vector<Decimal128> v43 = {Decimal128("12345678901234567890000.0005345340"),
Decimal128("99999999999999999999999.0000344300"),
Decimal128("99999999999999999999999.0004354000")};
// 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64
std::vector<int64_t> e42 = {4807115922877858896, 200376420520689663};
std::vector<int64_t> e43 = {4807115922877858896, 200376420520689663, -2};

for (bool allow_int_overflow : {false, true}) {
for (bool allow_decimal_truncate : {false, true}) {
options.allow_int_overflow = allow_int_overflow;
options.allow_decimal_truncate = allow_decimal_truncate;
if (options.allow_int_overflow && options.allow_decimal_truncate) {
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v42, is_valid2, int64(), e42, options);
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, 10), v43, is_valid3, int64(), e43, options);
} else {
CheckFails<Decimal128Type>(decimal(38, 10), v42, is_valid2, int64(), options);
CheckFails<Decimal128Type>(decimal(38, 10), v43, is_valid3, int64(), options);
}
}
}

// negative scale
std::vector<Decimal128> v5 = {Decimal128("1234567890000."), Decimal128("-120000.")};
for (int i = 0; i < 2; i++) v5[i] = v5[i].Rescale(0, -4).ValueOrDie();
std::vector<int64_t> e5 = {1234567890000, -120000};
CheckCase<Decimal128Type, Decimal128, Int64Type, int64_t>(
decimal(38, -4), v5, is_valid2, int64(), e5, options);
}

TEST_F(TestCast, DecimalToDecimal) {
CastOptions options;

std::vector<bool> is_valid2 = {true, true};
std::vector<bool> is_valid3 = {true, true, false};

// simple cases decimal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same remark here: I would expect more tricky conversion cases, failures (overflow) and nulls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more tests: nulls, truncation


std::vector<Decimal128> v12 = {Decimal128("02.0000000000"),
Decimal128("30.0000000000")};
std::vector<Decimal128> e12 = {Decimal128("02."), Decimal128("30.")};
std::vector<Decimal128> v13 = {Decimal128("02.0000000000"), Decimal128("30.0000000000"),
Decimal128("30.0000000000")};
std::vector<Decimal128> e13 = {Decimal128("02."), Decimal128("30."), Decimal128("-1.")};

for (bool allow_decimal_truncate : {false, true}) {
options.allow_decimal_truncate = allow_decimal_truncate;
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(38, 10), v12, is_valid2, decimal(28, 0), e12, options);
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(38, 10), v13, is_valid3, decimal(28, 0), e13, options);
// and back
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(28, 0), e12, is_valid2, decimal(38, 10), v12, options);
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(28, 0), e13, is_valid3, decimal(38, 10), v13, options);
}

std::vector<Decimal128> v22 = {Decimal128("-02.1234567890"),
Decimal128("30.1234567890")};
std::vector<Decimal128> e22 = {Decimal128("-02."), Decimal128("30.")};
std::vector<Decimal128> f22 = {Decimal128("-02.0000000000"),
Decimal128("30.0000000000")};
std::vector<Decimal128> v23 = {Decimal128("-02.1234567890"),
Decimal128("30.1234567890"),
Decimal128("30.1234567890")};
std::vector<Decimal128> e23 = {Decimal128("-02."), Decimal128("30."),
Decimal128("-70.")};
std::vector<Decimal128> f23 = {Decimal128("-02.0000000000"),
Decimal128("30.0000000000"),
Decimal128("80.0000000000")};

options.allow_decimal_truncate = true;
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(38, 10), v22, is_valid2, decimal(28, 0), e22, options);
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(38, 10), v23, is_valid3, decimal(28, 0), e23, options);
// and back
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options);
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options);

options.allow_decimal_truncate = false;
CheckFails<Decimal128Type>(decimal(38, 10), v22, is_valid2, decimal(28, 0), options);
CheckFails<Decimal128Type>(decimal(38, 10), v23, is_valid3, decimal(28, 0), options);
// back case is ok
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options);
CheckCase<Decimal128Type, Decimal128, Decimal128Type, Decimal128>(
decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options);
}

TEST_F(TestCast, TimestampToTimestamp) {
CastOptions options;

Expand Down
Loading