Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 13 additions & 23 deletions cpp/src/arrow/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,43 +380,33 @@ struct ARROW_EXPORT DurationScalar : public TemporalScalar<DurationType> {
: DurationScalar(std::move(value), duration(unit)) {}
};

struct ARROW_EXPORT Decimal128Scalar : public internal::PrimitiveScalarBase {
template <typename TYPE_CLASS, typename VALUE_TYPE>
struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase {
using internal::PrimitiveScalarBase::PrimitiveScalarBase;
using TypeClass = Decimal128Type;
using ValueType = Decimal128;
using TypeClass = TYPE_CLASS;
using ValueType = VALUE_TYPE;

Decimal128Scalar(Decimal128 value, std::shared_ptr<DataType> type)
DecimalScalar(ValueType value, std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase(std::move(type), true), value(value) {}

void* mutable_data() override {
return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
}

util::string_view view() const override {
return util::string_view(reinterpret_cast<const char*>(value.native_endian_bytes()),
16);
ValueType::kByteWidth);
}

Decimal128 value;
ValueType value;
};

struct ARROW_EXPORT Decimal256Scalar : public internal::PrimitiveScalarBase {
using internal::PrimitiveScalarBase::PrimitiveScalarBase;
using TypeClass = Decimal256Type;
using ValueType = Decimal256;

Decimal256Scalar(Decimal256 value, std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase(std::move(type), true), value(value) {}

void* mutable_data() override {
return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
}
util::string_view view() const override {
const std::array<uint64_t, 4>& bytes = value.native_endian_array();
return util::string_view(reinterpret_cast<const char*>(bytes.data()),
bytes.size() * sizeof(uint64_t));
}
struct ARROW_EXPORT Decimal128Scalar : public DecimalScalar<Decimal128Type, Decimal128> {
using DecimalScalar::DecimalScalar;
};

Decimal256 value;
struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar<Decimal256Type, Decimal256> {
using DecimalScalar::DecimalScalar;
};

struct ARROW_EXPORT BaseListScalar : public Scalar {
Expand Down
136 changes: 48 additions & 88 deletions cpp/src/arrow/util/basic_decimal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@

namespace arrow {

using internal::AddWithOverflow;
using internal::SafeLeftShift;
using internal::SafeSignedAdd;
using internal::SafeSignedSubtract;
using internal::SubtractWithOverflow;

static const BasicDecimal128 ScaleMultipliers[] = {
BasicDecimal128(1LL),
Expand Down Expand Up @@ -368,43 +371,16 @@ static constexpr uint64_t kInt32Mask = 0xFFFFFFFF;
static constexpr BasicDecimal128 kMaxValue =
BasicDecimal128(5421010862427522170LL, 687399551400673280ULL - 1);

#if ARROW_LITTLE_ENDIAN
BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
: BasicDecimal128(reinterpret_cast<const int64_t*>(bytes)[1],
reinterpret_cast<const uint64_t*>(bytes)[0]) {}
#else
BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
: BasicDecimal128(reinterpret_cast<const int64_t*>(bytes)[0],
reinterpret_cast<const uint64_t*>(bytes)[1]) {}
#endif

constexpr int BasicDecimal128::kBitWidth;
constexpr int BasicDecimal128::kMaxPrecision;
constexpr int BasicDecimal128::kMaxScale;

std::array<uint8_t, 16> BasicDecimal128::ToBytes() const {
std::array<uint8_t, 16> out{{0}};
ToBytes(out.data());
return out;
}

void BasicDecimal128::ToBytes(uint8_t* out) const {
DCHECK_NE(out, nullptr);
#if ARROW_LITTLE_ENDIAN
reinterpret_cast<uint64_t*>(out)[0] = low_bits_;
reinterpret_cast<int64_t*>(out)[1] = high_bits_;
#else
reinterpret_cast<int64_t*>(out)[0] = high_bits_;
reinterpret_cast<uint64_t*>(out)[1] = low_bits_;
#endif
}

BasicDecimal128& BasicDecimal128::Negate() {
low_bits_ = ~low_bits_ + 1;
high_bits_ = ~high_bits_;
if (low_bits_ == 0) {
high_bits_ = SafeSignedAdd<int64_t>(high_bits_, 1);
uint64_t result_lo = ~low_bits() + 1;
int64_t result_hi = ~high_bits();
if (result_lo == 0) {
result_hi = SafeSignedAdd<int64_t>(result_hi, 1);
}
*this = BasicDecimal128(result_hi, result_lo);
return *this;
}

Expand All @@ -422,22 +398,18 @@ bool BasicDecimal128::FitsInPrecision(int32_t precision) const {
}

BasicDecimal128& BasicDecimal128::operator+=(const BasicDecimal128& right) {
const uint64_t sum = low_bits_ + right.low_bits_;
high_bits_ = SafeSignedAdd<int64_t>(high_bits_, right.high_bits_);
if (sum < low_bits_) {
high_bits_ = SafeSignedAdd<int64_t>(high_bits_, 1);
}
low_bits_ = sum;
int64_t result_hi = SafeSignedAdd(high_bits(), right.high_bits());
uint64_t result_lo = low_bits() + right.low_bits();
result_hi = SafeSignedAdd<int64_t>(result_hi, result_lo < low_bits());
*this = BasicDecimal128(result_hi, result_lo);
return *this;
}

BasicDecimal128& BasicDecimal128::operator-=(const BasicDecimal128& right) {
const uint64_t diff = low_bits_ - right.low_bits_;
high_bits_ -= right.high_bits_;
if (diff > low_bits_) {
--high_bits_;
}
low_bits_ = diff;
int64_t result_hi = SafeSignedSubtract(high_bits(), right.high_bits());
uint64_t result_lo = low_bits() - right.low_bits();
result_hi = SafeSignedSubtract<int64_t>(result_hi, result_lo > low_bits());
*this = BasicDecimal128(result_hi, result_lo);
return *this;
}

Expand All @@ -449,47 +421,53 @@ BasicDecimal128& BasicDecimal128::operator/=(const BasicDecimal128& right) {
}

BasicDecimal128& BasicDecimal128::operator|=(const BasicDecimal128& right) {
low_bits_ |= right.low_bits_;
high_bits_ |= right.high_bits_;
array_[0] |= right.array_[0];
array_[1] |= right.array_[1];
return *this;
}

BasicDecimal128& BasicDecimal128::operator&=(const BasicDecimal128& right) {
low_bits_ &= right.low_bits_;
high_bits_ &= right.high_bits_;
array_[0] &= right.array_[0];
array_[1] &= right.array_[1];
return *this;
}

BasicDecimal128& BasicDecimal128::operator<<=(uint32_t bits) {
if (bits != 0) {
uint64_t result_lo;
int64_t result_hi;
if (bits < 64) {
high_bits_ = SafeLeftShift(high_bits_, bits);
high_bits_ |= (low_bits_ >> (64 - bits));
low_bits_ <<= bits;
result_hi = SafeLeftShift(high_bits(), bits);
result_hi |= (low_bits() >> (64 - bits));
result_lo = low_bits() << bits;
} else if (bits < 128) {
high_bits_ = static_cast<int64_t>(low_bits_) << (bits - 64);
low_bits_ = 0;
result_hi = static_cast<int64_t>(low_bits() << (bits - 64));
result_lo = 0;
} else {
high_bits_ = 0;
low_bits_ = 0;
result_hi = 0;
result_lo = 0;
}
*this = BasicDecimal128(result_hi, result_lo);
}
return *this;
}

BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) {
if (bits != 0) {
uint64_t result_lo;
int64_t result_hi;
if (bits < 64) {
low_bits_ >>= bits;
low_bits_ |= static_cast<uint64_t>(high_bits_ << (64 - bits));
high_bits_ = static_cast<int64_t>(static_cast<uint64_t>(high_bits_) >> bits);
result_lo = low_bits() >> bits;
result_lo |= static_cast<uint64_t>(high_bits()) << (64 - bits);
result_hi = high_bits() >> bits;
} else if (bits < 128) {
low_bits_ = static_cast<uint64_t>(high_bits_ >> (bits - 64));
high_bits_ = static_cast<int64_t>(high_bits_ >= 0L ? 0L : -1L);
result_lo = static_cast<uint64_t>(high_bits() >> (bits - 64));
result_hi = high_bits() >> 63;
} else {
high_bits_ = static_cast<int64_t>(high_bits_ >= 0L ? 0L : -1L);
low_bits_ = static_cast<uint64_t>(high_bits_);
result_hi = high_bits() >> 63;
result_lo = static_cast<uint64_t>(result_hi);
}
*this = BasicDecimal128(result_hi, result_lo);
}
return *this;
}
Expand Down Expand Up @@ -633,8 +611,7 @@ BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
BasicDecimal128 y = BasicDecimal128::Abs(right);
uint128_t r(x);
r *= uint128_t{y};
high_bits_ = r.hi();
low_bits_ = r.lo();
*this = BasicDecimal128(static_cast<int64_t>(r.hi()), r.lo());
if (negate) {
Negate();
}
Expand Down Expand Up @@ -1158,20 +1135,13 @@ BasicDecimal128 BasicDecimal128::ReduceScaleBy(int32_t reduce_by, bool round) co
int32_t BasicDecimal128::CountLeadingBinaryZeros() const {
DCHECK_GE(*this, BasicDecimal128(0));

if (high_bits_ == 0) {
return bit_util::CountLeadingZeros(low_bits_) + 64;
if (high_bits() == 0) {
return bit_util::CountLeadingZeros(low_bits()) + 64;
} else {
return bit_util::CountLeadingZeros(static_cast<uint64_t>(high_bits_));
return bit_util::CountLeadingZeros(static_cast<uint64_t>(high_bits()));
}
}

BasicDecimal256::BasicDecimal256(const uint8_t* bytes)
: array_({reinterpret_cast<const uint64_t*>(bytes)[0],
reinterpret_cast<const uint64_t*>(bytes)[1],
reinterpret_cast<const uint64_t*>(bytes)[2],
reinterpret_cast<const uint64_t*>(bytes)[3]}) {}

constexpr int BasicDecimal256::kBitWidth;
constexpr int BasicDecimal256::kMaxPrecision;
constexpr int BasicDecimal256::kMaxScale;

Expand Down Expand Up @@ -1243,20 +1213,6 @@ BasicDecimal256& BasicDecimal256::operator<<=(uint32_t bits) {
return *this;
}

std::array<uint8_t, 32> BasicDecimal256::ToBytes() const {
std::array<uint8_t, 32> out{{0}};
ToBytes(out.data());
return out;
}

void BasicDecimal256::ToBytes(uint8_t* out) const {
DCHECK_NE(out, nullptr);
reinterpret_cast<uint64_t*>(out)[0] = array_[0];
reinterpret_cast<uint64_t*>(out)[1] = array_[1];
reinterpret_cast<uint64_t*>(out)[2] = array_[2];
reinterpret_cast<uint64_t*>(out)[3] = array_[3];
}

BasicDecimal256& BasicDecimal256::operator*=(const BasicDecimal256& right) {
// Since the max value of BasicDecimal256 is supposed to be 1e76 - 1 and the
// min the negation taking the absolute values here should always be safe.
Expand Down Expand Up @@ -1391,4 +1347,8 @@ BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& ri
return result;
}

// Explicitly instantiate template base class, for DLL linking on Windows
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this approach avoid putting all template class implementation code inside the header?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should, though that's not the goal here :-)

template class GenericBasicDecimal<BasicDecimal128, 128>;
template class GenericBasicDecimal<BasicDecimal256, 256>;

} // namespace arrow
Loading