Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
7ac9235
Add casts for float16 <-> float32
ClifHouck Feb 13, 2024
f9568f0
Better template definitions for cast to/from halffloat
ClifHouck Feb 16, 2024
b626ee4
Fold casting from half float to float32/64 into GetCastToFloating
ClifHouck Feb 20, 2024
8f1aad0
Comment on CastPrimitive specializations from half float to float32/64
ClifHouck Feb 20, 2024
b1c7edf
string to float16 casts
ClifHouck Feb 26, 2024
d21d8bc
Add casts from int to half float
ClifHouck Feb 28, 2024
8f56a6f
Add HalfFloatType support to ConvertNumber
ClifHouck Feb 28, 2024
0ba90e8
Fix casting half float to int
ClifHouck Mar 5, 2024
465093b
Fix to CastPrimitve for other floating point types to half float
ClifHouck Mar 6, 2024
fb72199
Remove fixme comment
ClifHouck Mar 8, 2024
cad4936
Remove debug std::cout
ClifHouck Mar 8, 2024
9b2e53f
clang format run
ClifHouck Mar 8, 2024
9d1a979
Update notes about data types available in different arrow libraries
ClifHouck Mar 8, 2024
6f3e299
Possible fix for TestBatchToTensorTest failure
ClifHouck Mar 18, 2024
d9d3f76
'fix' TestIntegers failure
ClifHouck Mar 18, 2024
66750d3
clang-format
ClifHouck Mar 18, 2024
a3eb4f8
use Float16 operator for comparing equality
ClifHouck Mar 19, 2024
6a646a4
Revert "use Float16 operator for comparing equality"
ClifHouck Mar 20, 2024
90ab8a0
Change c_glib test-half-float-scalar string test to match current beh…
ClifHouck Mar 20, 2024
c9348e8
Refactor FloatingToString
ClifHouck Mar 22, 2024
87b244c
Make float type the outer loop in FloatingToString test
ClifHouck Mar 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c_glib/test/test-half-float-scalar.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_equal
end

def test_to_s
assert_equal("[\n #{@half_float}\n]", @scalar.to_s)
assert_equal("1.0009765625", @scalar.to_s)
end

def test_value
Expand Down
30 changes: 30 additions & 0 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/float16.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
Expand All @@ -59,6 +60,7 @@ using internal::BitmapReader;
using internal::BitmapUInt64Reader;
using internal::checked_cast;
using internal::OptionalBitmapEquals;
using util::Float16;

// ----------------------------------------------------------------------
// Public method implementations
Expand Down Expand Up @@ -95,6 +97,30 @@ struct FloatingEquality {
const T epsilon;
};

// For half-float equality.
template <typename Flags>
struct FloatingEquality<uint16_t, Flags> {
explicit FloatingEquality(const EqualOptions& options)
: epsilon(static_cast<float>(options.atol())) {}

bool operator()(uint16_t x, uint16_t y) const {
Float16 f_x = Float16::FromBits(x);
Float16 f_y = Float16::FromBits(y);
if (x == y) {
return Flags::signed_zeros_equal || (f_x.signbit() == f_y.signbit());
}
if (Flags::nans_equal && f_x.is_nan() && f_y.is_nan()) {
return true;
}
if (Flags::approximate && (fabs(f_x.ToFloat() - f_y.ToFloat()) <= epsilon)) {
return true;
}
return false;
}

const float epsilon;
};

template <typename T, typename Visitor>
struct FloatingEqualityDispatcher {
const EqualOptions& options;
Expand Down Expand Up @@ -259,6 +285,8 @@ class RangeDataEqualsImpl {

Status Visit(const DoubleType& type) { return CompareFloating(type); }

Status Visit(const HalfFloatType& type) { return CompareFloating(type); }

// Also matches StringType
Status Visit(const BinaryType& type) { return CompareBinary(type); }

Expand Down Expand Up @@ -863,6 +891,8 @@ class ScalarEqualsVisitor {

Status Visit(const DoubleScalar& left) { return CompareFloating(left); }

Status Visit(const HalfFloatScalar& left) { return CompareFloating(left); }

template <typename T>
enable_if_t<std::is_base_of<BaseBinaryScalar, T>::value, Status> Visit(const T& left) {
const auto& right = checked_cast<const BaseBinaryScalar&>(right_);
Expand Down
70 changes: 70 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
#include "arrow/compute/cast_internal.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/extension_type.h"
#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/float16.h"

namespace arrow {

using arrow::util::Float16;
using internal::checked_cast;
using internal::PrimitiveScalarBase;

Expand All @@ -47,6 +50,42 @@ struct CastPrimitive {
}
};

// Converting floating types to half float.
template <typename InType>
struct CastPrimitive<HalfFloatType, InType, enable_if_physical_floating_point<InType>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
using InT = typename InType::c_type;
const InT* in_values = arr.GetValues<InT>(1);
uint16_t* out_values = out->GetValues<uint16_t>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = Float16(*in_values++).bits();
}
}
};

// Converting from half float to other floating types.
template <>
struct CastPrimitive<FloatType, HalfFloatType, enable_if_t<true>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
const uint16_t* in_values = arr.GetValues<uint16_t>(1);
float* out_values = out->GetValues<float>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = Float16::FromBits(*in_values++).ToFloat();
}
}
};

template <>
struct CastPrimitive<DoubleType, HalfFloatType, enable_if_t<true>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
const uint16_t* in_values = arr.GetValues<uint16_t>(1);
double* out_values = out->GetValues<double>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = Float16::FromBits(*in_values++).ToDouble();
}
}
};

template <typename OutType, typename InType>
struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>::value>> {
// memcpy output
Expand All @@ -56,6 +95,33 @@ struct CastPrimitive<OutType, InType, enable_if_t<std::is_same<OutType, InType>:
}
};

// Cast int to half float
template <typename InType>
struct CastPrimitive<HalfFloatType, InType, enable_if_integer<InType>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
using InT = typename InType::c_type;
const InT* in_values = arr.GetValues<InT>(1);
uint16_t* out_values = out->GetValues<uint16_t>(1);
for (int64_t i = 0; i < arr.length; ++i) {
float temp = static_cast<float>(*in_values++);
*out_values++ = Float16(temp).bits();
}
}
};

// Cast half float to int
template <typename OutType>
struct CastPrimitive<OutType, HalfFloatType, enable_if_integer<OutType>> {
static void Exec(const ArraySpan& arr, ArraySpan* out) {
using OutT = typename OutType::c_type;
const uint16_t* in_values = arr.GetValues<uint16_t>(1);
OutT* out_values = out->GetValues<OutT>(1);
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = static_cast<OutT>(Float16::FromBits(*in_values++).ToFloat());
}
}
};

template <typename InType>
void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan* out) {
switch (out_type) {
Expand All @@ -79,6 +145,8 @@ void CastNumberImpl(Type::type out_type, const ArraySpan& input, ArraySpan* out)
return CastPrimitive<FloatType, InType>::Exec(input, out);
case Type::DOUBLE:
return CastPrimitive<DoubleType, InType>::Exec(input, out);
case Type::HALF_FLOAT:
return CastPrimitive<HalfFloatType, InType>::Exec(input, out);
default:
break;
}
Expand Down Expand Up @@ -109,6 +177,8 @@ void CastNumberToNumberUnsafe(Type::type in_type, Type::type out_type,
return CastNumberImpl<FloatType>(out_type, input, out);
case Type::DOUBLE:
return CastNumberImpl<DoubleType>(out_type, input, out);
case Type::HALF_FLOAT:
return CastNumberImpl<HalfFloatType>(out_type, input, out);
default:
DCHECK(false);
break;
Expand Down
103 changes: 87 additions & 16 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/scalar.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/float16.h"
#include "arrow/util/int_util.h"
#include "arrow/util/value_parsing.h"

Expand All @@ -34,6 +35,7 @@ using internal::IntegersCanFit;
using internal::OptionalBitBlockCounter;
using internal::ParseValue;
using internal::PrimitiveScalarBase;
using util::Float16;

namespace compute {
namespace internal {
Expand All @@ -56,18 +58,37 @@ Status CastFloatingToFloating(KernelContext*, const ExecSpan& batch, ExecResult*

// ----------------------------------------------------------------------
// Implement fast safe floating point to integer cast
//
template <typename InType, typename OutType, typename InT = typename InType::c_type,
typename OutT = typename OutType::c_type>
struct WasTruncated {
static bool Check(OutT out_val, InT in_val) {
return static_cast<InT>(out_val) != in_val;
}

static bool CheckMaybeNull(OutT out_val, InT in_val, bool is_valid) {
return is_valid && static_cast<InT>(out_val) != in_val;
}
};

// Half float to int
template <typename OutType>
struct WasTruncated<HalfFloatType, OutType> {
using OutT = typename OutType::c_type;
static bool Check(OutT out_val, uint16_t in_val) {
return static_cast<float>(out_val) != Float16::FromBits(in_val).ToFloat();
}

static bool CheckMaybeNull(OutT out_val, uint16_t in_val, bool is_valid) {
return is_valid && static_cast<float>(out_val) != Float16::FromBits(in_val).ToFloat();
}
};

// InType is a floating point type we are planning to cast to integer
template <typename InType, typename OutType, typename InT = typename InType::c_type,
typename OutT = typename OutType::c_type>
ARROW_DISABLE_UBSAN("float-cast-overflow")
Status CheckFloatTruncation(const ArraySpan& input, const ArraySpan& output) {
auto WasTruncated = [&](OutT out_val, InT in_val) -> bool {
return static_cast<InT>(out_val) != in_val;
};
auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool {
return is_valid && static_cast<InT>(out_val) != in_val;
};
auto GetErrorMessage = [&](InT val) {
return Status::Invalid("Float value ", val, " was truncated converting to ",
*output.type);
Expand All @@ -86,26 +107,28 @@ Status CheckFloatTruncation(const ArraySpan& input, const ArraySpan& output) {
if (block.popcount == block.length) {
// Fast path: branchless
for (int64_t i = 0; i < block.length; ++i) {
block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]);
block_out_of_bounds |=
WasTruncated<InType, OutType>::Check(out_data[i], in_data[i]);
}
} else if (block.popcount > 0) {
// Indices have nulls, must only boundscheck non-null values
for (int64_t i = 0; i < block.length; ++i) {
block_out_of_bounds |= WasTruncatedMaybeNull(
block_out_of_bounds |= WasTruncated<InType, OutType>::CheckMaybeNull(
out_data[i], in_data[i], bit_util::GetBit(bitmap, offset_position + i));
}
}
if (ARROW_PREDICT_FALSE(block_out_of_bounds)) {
if (input.GetNullCount() > 0) {
for (int64_t i = 0; i < block.length; ++i) {
if (WasTruncatedMaybeNull(out_data[i], in_data[i],
bit_util::GetBit(bitmap, offset_position + i))) {
if (WasTruncated<InType, OutType>::CheckMaybeNull(
out_data[i], in_data[i],
bit_util::GetBit(bitmap, offset_position + i))) {
return GetErrorMessage(in_data[i]);
}
}
} else {
for (int64_t i = 0; i < block.length; ++i) {
if (WasTruncated(out_data[i], in_data[i])) {
if (WasTruncated<InType, OutType>::Check(out_data[i], in_data[i])) {
return GetErrorMessage(in_data[i]);
}
}
Expand Down Expand Up @@ -151,6 +174,9 @@ Status CheckFloatToIntTruncation(const ExecValue& input, const ExecResult& outpu
return CheckFloatToIntTruncationImpl<FloatType>(input.array, *output.array_span());
case Type::DOUBLE:
return CheckFloatToIntTruncationImpl<DoubleType>(input.array, *output.array_span());
case Type::HALF_FLOAT:
return CheckFloatToIntTruncationImpl<HalfFloatType>(input.array,
*output.array_span());
default:
break;
}
Expand Down Expand Up @@ -293,6 +319,15 @@ struct CastFunctor<
}
};

template <>
struct CastFunctor<HalfFloatType, StringType, enable_if_t<true>> {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return applicator::ScalarUnaryNotNull<HalfFloatType, StringType,
ParseString<HalfFloatType>>::Exec(ctx, batch,
out);
}
};

// ----------------------------------------------------------------------
// Decimal to integer

Expand Down Expand Up @@ -689,6 +724,10 @@ std::shared_ptr<CastFunction> GetCastToInteger(std::string name) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger));
}

// Cast from half-float
DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)}, out_ty,
CastFloatingToInteger));

// From other numbers to integer
AddCommonNumberCasts<OutType>(out_ty, func.get());

Expand All @@ -715,6 +754,10 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating));
}

// From half-float to float/double
DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {InputType(Type::HALF_FLOAT)}, out_ty,
CastFloatingToFloating));

// From other numbers to floating point
AddCommonNumberCasts<OutType>(out_ty, func.get());

Expand All @@ -723,6 +766,7 @@ std::shared_ptr<CastFunction> GetCastToFloating(std::string name) {
CastFunctor<OutType, Decimal128Type>::Exec));
DCHECK_OK(func->AddKernel(Type::DECIMAL256, {InputType(Type::DECIMAL256)}, out_ty,
CastFunctor<OutType, Decimal256Type>::Exec));

return func;
}

Expand Down Expand Up @@ -795,6 +839,32 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() {
return func;
}

std::shared_ptr<CastFunction> GetCastToHalfFloat() {
// HalfFloat is a bit brain-damaged for now
auto func = std::make_shared<CastFunction>("func", Type::HALF_FLOAT);
AddCommonCasts(Type::HALF_FLOAT, float16(), func.get());

// Casts from integer to floating point
for (const std::shared_ptr<DataType>& in_ty : IntTypes()) {
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty},
TypeTraits<HalfFloatType>::type_singleton(),
CastIntegerToFloating));
}

// Cast from other strings to half float.
for (const std::shared_ptr<DataType>& in_ty : BaseBinaryTypes()) {
auto exec = GenerateVarBinaryBase<CastFunctor, HalfFloatType>(*in_ty);
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty},
TypeTraits<HalfFloatType>::type_singleton(), exec));
}

DCHECK_OK(func.get()->AddKernel(Type::FLOAT, {InputType(Type::FLOAT)}, float16(),
CastFloatingToFloating));
DCHECK_OK(func.get()->AddKernel(Type::DOUBLE, {InputType(Type::DOUBLE)}, float16(),
CastFloatingToFloating));
return func;
}

} // namespace

std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
Expand Down Expand Up @@ -830,13 +900,14 @@ std::vector<std::shared_ptr<CastFunction>> GetNumericCasts() {
functions.push_back(GetCastToInteger<UInt64Type>("cast_uint64"));

// HalfFloat is a bit brain-damaged for now
auto cast_half_float =
std::make_shared<CastFunction>("cast_half_float", Type::HALF_FLOAT);
AddCommonCasts(Type::HALF_FLOAT, float16(), cast_half_float.get());
auto cast_half_float = GetCastToHalfFloat();
functions.push_back(cast_half_float);

functions.push_back(GetCastToFloating<FloatType>("cast_float"));
functions.push_back(GetCastToFloating<DoubleType>("cast_double"));
auto cast_float = GetCastToFloating<FloatType>("cast_float");
functions.push_back(cast_float);

auto cast_double = GetCastToFloating<DoubleType>("cast_double");
functions.push_back(cast_double);

functions.push_back(GetCastToDecimal128());
functions.push_back(GetCastToDecimal256());
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_cast_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,10 @@ void AddNumberToStringCasts(CastFunction* func) {
GenerateNumeric<NumericToStringCastFunctor, OutType>(*in_ty),
NullHandling::COMPUTED_NO_PREALLOCATE));
}

DCHECK_OK(func->AddKernel(Type::HALF_FLOAT, {float16()}, out_ty,
NumericToStringCastFunctor<OutType, HalfFloatType>::Exec,
NullHandling::COMPUTED_NO_PREALLOCATE));
}

template <typename OutType>
Expand Down
Loading