Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -740,19 +740,27 @@ bool TensorEquals(const Tensor& left, const Tensor& right) {
are_equal = false;
} else {
const auto& type = static_cast<const FixedWidthType&>(*left.type());
// Type::BOOL strided tensors are currently not supported
DCHECK_GT(type.bit_width() / CHAR_BIT, 0);
are_equal =
StridedTensorContentEquals(0, 0, 0, type.bit_width() / 8, left, right);
StridedTensorContentEquals(0, 0, 0, type.bit_width() / CHAR_BIT, left, right);
}
} else {
const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
const int byte_width = size_meta.bit_width() / CHAR_BIT;
DCHECK_GT(byte_width, 0);

const uint8_t* left_data = left.data()->data();
const uint8_t* right_data = right.data()->data();

are_equal = memcmp(left_data, right_data,
static_cast<size_t>(byte_width * left.size())) == 0;
if (size_meta.bit_width() == 1) {
int64_t bytes = (left.size() + CHAR_BIT - 1) / CHAR_BIT;
are_equal = memcmp(left_data, right_data,
static_cast<size_t>(bytes)) == 0;
} else {
const int byte_width = size_meta.bit_width() / CHAR_BIT;
DCHECK_GT(byte_width, 0);
are_equal = memcmp(left_data, right_data,
static_cast<size_t>(byte_width * left.size())) == 0;
}
}
}
return are_equal;
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/ipc/ipc-read-write-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,14 +728,25 @@ TEST_F(TestTensorRoundTrip, BasicRoundtrip) {

std::vector<int64_t> values;
test::randint<int64_t>(size, 0, 100, &values);
std::vector<bool> bool_values;
test::randbool(size, &bool_values);
std::vector<uint8_t> bool8_values;
test::randint<uint8_t>(size, 0, 1, &bool8_values);

auto data = test::GetBufferFromVector(values);
std::shared_ptr<Buffer> bool_data;
ASSERT_OK(test::GetBitmapFromVector(bool_values, &bool_data));
auto bool8_data = test::GetBufferFromVector(bool8_values);

Tensor t0(int64(), data, shape, strides, dim_names);
Tensor tzero(int64(), data, {}, {}, {});
Tensor tbool(boolean(), bool_data, {}, {}, {});
Tensor tbool8(boolean8(), bool8_data, {}, {}, {});

CheckTensorRoundTrip(t0);
CheckTensorRoundTrip(tzero);
CheckTensorRoundTrip(tbool);
CheckTensorRoundTrip(tbool8);

int64_t serialized_size;
ASSERT_OK(GetTensorSize(t0, &serialized_size));
Expand Down
18 changes: 16 additions & 2 deletions cpp/src/arrow/ipc/metadata-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,15 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
case flatbuf::Type_Utf8:
*out = utf8();
return Status::OK();
case flatbuf::Type_Bool:
*out = boolean();
case flatbuf::Type_Bool: {
auto bool_type = static_cast<const flatbuf::Bool*>(type_data);
if (bool_type->is_byte()) {
*out = boolean8();
} else {
*out = boolean();
}
return Status::OK();
}
case flatbuf::Type_Decimal: {
auto dec_type = static_cast<const flatbuf::Decimal*>(type_data);
*out = decimal(dec_type->precision(), dec_type->scale());
Expand Down Expand Up @@ -458,6 +464,14 @@ static Status TypeToFlatbuffer(FBB& fbb, const DataType& type,
static Status TensorTypeToFlatbuffer(FBB& fbb, const DataType& type,
flatbuf::Type* out_type, Offset* offset) {
switch (type.id()) {
case Type::BOOL:
*out_type = flatbuf::Type_Bool;
*offset = flatbuf::CreateBool(fbb).Union();
break;
case Type::BOOL8:
*out_type = flatbuf::Type_Bool;
*offset = flatbuf::CreateBool(fbb, true).Union();
break;
case Type::UINT8:
INT_TO_FB_CASE(8, false);
case Type::INT8:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ namespace arrow {

static inline bool is_tensor_supported(Type::type type_id) {
switch (type_id) {
case Type::BOOL:
case Type::BOOL8:
case Type::UINT8:
case Type::INT8:
case Type::UINT16:
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/test-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ using ArrayVector = std::vector<std::shared_ptr<Array>>;

namespace test {

void randbool(int64_t N, std::vector<bool>* out) {
Random rng(random_seed());
bool val;
for (int64_t i = 0; i < N; ++i) {
val = rng.OneIn(2);
out->push_back(val);
}
}

template <typename T>
void randint(int64_t N, T lower, T upper, std::vector<T>* out) {
Random rng(random_seed());
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ bool DataType::Equals(const std::shared_ptr<DataType>& other) const {

std::string BooleanType::ToString() const { return name(); }

std::string Boolean8Type::ToString() const { return name(); }

FloatingPoint::Precision HalfFloatType::precision() const { return FloatingPoint::HALF; }

FloatingPoint::Precision FloatType::precision() const { return FloatingPoint::SINGLE; }
Expand Down Expand Up @@ -368,6 +370,7 @@ std::shared_ptr<Schema> schema(std::vector<std::shared_ptr<Field>>&& fields,

ACCEPT_VISITOR(NullType);
ACCEPT_VISITOR(BooleanType);
ACCEPT_VISITOR(Boolean8Type);
ACCEPT_VISITOR(BinaryType);
ACCEPT_VISITOR(FixedSizeBinaryType);
ACCEPT_VISITOR(StringType);
Expand All @@ -391,6 +394,7 @@ ACCEPT_VISITOR(DictionaryType);

TYPE_FACTORY(null, NullType);
TYPE_FACTORY(boolean, BooleanType);
TYPE_FACTORY(boolean8, Boolean8Type);
TYPE_FACTORY(int8, Int8Type);
TYPE_FACTORY(uint8, UInt8Type);
TYPE_FACTORY(int16, Int16Type);
Expand Down Expand Up @@ -464,6 +468,7 @@ static const BufferDescr kValidityBuffer(BufferType::VALIDITY, 1);
static const BufferDescr kOffsetBuffer(BufferType::OFFSET, 32);
static const BufferDescr kTypeBuffer(BufferType::TYPE, 32);
static const BufferDescr kBooleanBuffer(BufferType::DATA, 1);
static const BufferDescr kBoolean8Buffer(BufferType::DATA, 8);
static const BufferDescr kValues64(BufferType::DATA, 64);
static const BufferDescr kValues32(BufferType::DATA, 32);
static const BufferDescr kValues16(BufferType::DATA, 16);
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ struct Type {
/// Boolean as 1 bit, LSB bit-packed ordering
BOOL,

/// Boolean as 1 byte (may only be used in Tensors)
BOOL8,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm -1 on this type as it further complicates the spec. Introducing it will affect more than just the tensors. Any reasoning to also have this on the JAVA side of things, too?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t agree that adding metadata expands the “Arrow specification” in a strict sense — we need some way to accommodate this particular logical type (which occurs in NumPy, R, deep learning frameworks, etc), so if adding a new enumeration value won’t do it let’s try to come up with another solution. If this were something more esoteric I would agree, but we have put off dealing with this kind of data for a while (also an issue with Feather format).

Having an additional boolean type metadata does not mean that we have to implement two versions of every algorithm, or it may be that the byte-boolean version would cast to bit-packed (a cast we definitely need to implement — so having it in arrow::compute::Cast would be valuable) and then use the bit packed path.

Copy link
Contributor

@jacques-n jacques-n Oct 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per my comments, I'm mostly -1 on this as well. Just because someone else calls a 1 byte value boolean doesn't mean it has to map to arrow's boolean. This is just a uint8 from my perspective.


/// Unsigned 8-bit little-endian integer
UINT8,

Expand Down Expand Up @@ -332,6 +335,19 @@ class ARROW_EXPORT BooleanType : public FixedWidthType, public NoExtraMeta {
std::string name() const override { return "bool"; }
};

class ARROW_EXPORT Boolean8Type : public FixedWidthType, public NoExtraMeta {
public:
static constexpr Type::type type_id = Type::BOOL8;

Boolean8Type() : FixedWidthType(Type::BOOL8) {}

Status Accept(TypeVisitor* visitor) const override;
std::string ToString() const override;

int bit_width() const override { return 8; }
std::string name() const override { return "bool8"; }
};

class ARROW_EXPORT UInt8Type
: public detail::IntegerTypeImpl<UInt8Type, Type::UINT8, uint8_t> {
public:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/type_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class NullArray;
class NullBuilder;

class BooleanType;
class Boolean8Type;
class BooleanArray;
class BooleanBuilder;

Expand Down Expand Up @@ -132,6 +133,7 @@ using IntervalArray = NumericArray<IntervalType>;

std::shared_ptr<DataType> ARROW_EXPORT null();
std::shared_ptr<DataType> ARROW_EXPORT boolean();
std::shared_ptr<DataType> ARROW_EXPORT boolean8();
std::shared_ptr<DataType> ARROW_EXPORT int8();
std::shared_ptr<DataType> ARROW_EXPORT int16();
std::shared_ptr<DataType> ARROW_EXPORT int32();
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ ARRAY_VISITOR_DEFAULT(DecimalArray);

TYPE_VISITOR_DEFAULT(NullType);
TYPE_VISITOR_DEFAULT(BooleanType);
TYPE_VISITOR_DEFAULT(Boolean8Type);
TYPE_VISITOR_DEFAULT(Int8Type);
TYPE_VISITOR_DEFAULT(Int16Type);
TYPE_VISITOR_DEFAULT(Int32Type);
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ARROW_EXPORT TypeVisitor {

virtual Status Visit(const NullType& type);
virtual Status Visit(const BooleanType& type);
virtual Status Visit(const Boolean8Type& type);
virtual Status Visit(const Int8Type& type);
virtual Status Visit(const Int16Type& type);
virtual Status Visit(const Int32Type& type);
Expand Down
3 changes: 3 additions & 0 deletions format/Metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ table FloatingPoint {

The Boolean logical type is represented as a 1-bit wide primitive physical
type. The bits are numbered using least-significant bit (LSB) ordering.
Inside of tensors, boolean data may also be represented as
a 1-byte wide primitive physical type; in this case the
flag `is_byte` is set.

Like other fixed bit-width primitive types, boolean data appears as 2 buffers
in the data header (one bitmap for the validity vector and one for the values).
Expand Down
3 changes: 3 additions & 0 deletions format/Schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ table FixedSizeBinary {
}

table Bool {
/// If this flag is set, the bool is represented as a byte,
/// by default it is represented as a bit.
is_byte: bool;
}

table Decimal {
Expand Down