Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions cpp/src/arrow/ipc/json-internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class JsonSchemaWriter {
// Writes the metadata members for a time type: a "unit" key whose value is
// the name returned by GetTimeUnitName for the type's unit, followed by a
// "bitWidth" key whose value is the type's bit width.
void WriteTypeMetadata(const TimeType& type) {
  writer_->Key("unit");
  const auto& unit_name = GetTimeUnitName(type.unit);
  writer_->String(unit_name);

  writer_->Key("bitWidth");
  const int width = type.bit_width();
  writer_->Int(width);
}

void WriteTypeMetadata(const DateType& type) {
Expand Down Expand Up @@ -608,6 +610,9 @@ static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type
const auto& json_unit = json_type.FindMember("unit");
RETURN_NOT_STRING("unit", json_unit, json_type);

const auto& json_bit_width = json_type.FindMember("bitWidth");
RETURN_NOT_INT("bitWidth", json_bit_width, json_type);

std::string unit_str = json_unit->value.GetString();

if (unit_str == "SECOND") {
Expand All @@ -623,6 +628,14 @@ static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type
ss << "Invalid time unit: " << unit_str;
return Status::Invalid(ss.str());
}

const auto& fw_type = static_cast<const FixedWidthType&>(**type);

int bit_width = json_bit_width->value.GetInt();
if (bit_width != fw_type.bit_width()) {
return Status::Invalid("Indicated bit width does not match unit");
}

return Status::OK();
}

Expand Down
11 changes: 9 additions & 2 deletions cpp/src/arrow/ipc/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,19 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
case flatbuf::Type_Time: {
auto time_type = static_cast<const flatbuf::Time*>(type_data);
TimeUnit unit = FromFlatbufferUnit(time_type->unit());
int32_t bit_width = time_type->bitWidth();
switch (unit) {
case TimeUnit::SECOND:
case TimeUnit::MILLI:
if (bit_width != 32) {
return Status::Invalid("Time is 32 bits for second/milli unit");
}
*out = time32(unit);
break;
default:
if (bit_width != 64) {
return Status::Invalid("Time is 64 bits for micro/nano unit");
}
*out = time64(unit);
break;
}
Expand Down Expand Up @@ -386,12 +393,12 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
case Type::TIME32: {
const auto& time_type = static_cast<const Time32Type&>(*type);
*out_type = flatbuf::Type_Time;
*offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit)).Union();
*offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit), 32).Union();
} break;
case Type::TIME64: {
const auto& time_type = static_cast<const Time64Type&>(*type);
*out_type = flatbuf::Type_Time;
*offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit)).Union();
*offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit), 64).Union();
} break;
case Type::TIMESTAMP: {
const auto& ts_type = static_cast<const TimestampType&>(*type);
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/type-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ TEST(TestTimeType, Equals) {
Time64Type t4(TimeUnit::NANO);
Time64Type t5(TimeUnit::MICRO);

ASSERT_EQ(32, t0.bit_width());
ASSERT_EQ(64, t3.bit_width());

ASSERT_TRUE(t0.Equals(t2));
ASSERT_TRUE(t1.Equals(t1));
ASSERT_FALSE(t1.Equals(t3));
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ struct ARROW_EXPORT Time32Type : public TimeType {
static constexpr Type::type type_id = Type::TIME32;
using c_type = int32_t;

int bit_width() const override { return static_cast<int>(sizeof(c_type) * 4); }
int bit_width() const override { return static_cast<int>(sizeof(c_type) * 8); }

explicit Time32Type(TimeUnit unit = TimeUnit::MILLI);

Expand Down
145 changes: 113 additions & 32 deletions integration/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,14 @@ def _get_buffers(self):

class IntegerType(PrimitiveType):

def __init__(self, name, is_signed, bit_width, nullable=True):
def __init__(self, name, is_signed, bit_width, nullable=True,
min_value=TEST_INT_MIN,
max_value=TEST_INT_MAX):
PrimitiveType.__init__(self, name, nullable=nullable)
self.is_signed = is_signed
self.bit_width = bit_width
self.min_value = min_value
self.max_value = max_value

@property
def numpy_type(self):
Expand All @@ -194,14 +198,80 @@ def _get_type(self):
def generate_column(self, size):
iinfo = np.iinfo(self.numpy_type)
values = [int(x) for x in
np.random.randint(max(iinfo.min, TEST_INT_MIN),
min(iinfo.max, TEST_INT_MAX),
np.random.randint(max(iinfo.min, self.min_value),
min(iinfo.max, self.max_value),
size=size)]

is_valid = self._make_is_valid(size)
return PrimitiveColumn(self.name, size, is_valid, values)


class DateType(IntegerType):
    """Signed integer field described as a 'date' type.

    The unit selects the integer width: ``DAY`` is built as a 32-bit field,
    any other unit value as a 64-bit field.
    """

    # Unit selectors accepted by the constructor.
    DAY = 0
    MILLISECOND = 1

    def __init__(self, name, unit, nullable=True):
        # The unit is stored before the base initializer runs, as in the
        # original ordering.
        self.unit = unit
        if unit == self.DAY:
            width = 32
        else:
            width = 64
        IntegerType.__init__(self, name, True, width, nullable=nullable)

    def _get_type(self):
        """Return the ordered type description: name first, then unit."""
        if self.unit == self.DAY:
            unit_name = 'DAY'
        else:
            unit_name = 'MILLISECOND'

        type_info = OrderedDict()
        type_info['name'] = 'date'
        type_info['unit'] = unit_name
        return type_info


# Maps the short unit abbreviations accepted by the time-like field classes
# to the unit names used as the 'unit' value of a type description.
TIMEUNIT_NAMES = dict(
    s='SECOND',
    ms='MILLISECOND',
    us='MICROSECOND',
    ns='NANOSECOND',
)


class TimeType(IntegerType):
    """Signed integer field described as a 'time' type.

    The integer width is looked up from the unit abbreviation: 's' and 'ms'
    use 32 bits, 'us' and 'ns' use 64 bits. An abbreviation missing from
    ``BIT_WIDTHS`` raises ``KeyError`` during construction.
    """

    # Bit width of the underlying integer for each unit abbreviation.
    BIT_WIDTHS = dict(s=32, ms=32, us=64, ns=64)

    def __init__(self, name, unit='s', nullable=True):
        # The unit is recorded first, then the width lookup happens, which
        # mirrors the original evaluation order.
        self.unit = unit
        width = self.BIT_WIDTHS[unit]
        IntegerType.__init__(self, name, True, width, nullable=nullable)

    def _get_type(self):
        """Return the ordered type description: name, unit, bitWidth."""
        type_info = OrderedDict()
        type_info['name'] = 'time'
        type_info['unit'] = TIMEUNIT_NAMES[self.unit]
        type_info['bitWidth'] = self.bit_width
        return type_info


class TimestampType(IntegerType):
    """Signed 64-bit integer field described as a 'timestamp' type.

    ``tz`` is optional; when it is ``None`` the type description carries no
    'timezone' entry.
    """

    def __init__(self, name, unit='s', tz=None, nullable=True):
        self.unit = unit
        self.tz = tz
        IntegerType.__init__(self, name, True, 64, nullable=nullable)

    def _get_type(self):
        """Return the ordered type description: name, unit, then the
        timezone when one was given."""
        type_info = OrderedDict()
        type_info['name'] = 'timestamp'
        type_info['unit'] = TIMEUNIT_NAMES[self.unit]

        if self.tz is not None:
            type_info['timezone'] = self.tz

        return type_info


class FloatingPointType(PrimitiveType):

def __init__(self, name, bit_width, nullable=True):
Expand Down Expand Up @@ -509,6 +579,20 @@ def get_field(name, type_, nullable=True):
raise TypeError(dtype)


def _generate_file(fields, batch_sizes):
    """Build a JSONFile holding one record batch per entry of batch_sizes.

    The schema is constructed from ``fields``. For every size, each field's
    ``generate_column(size)`` is called in field order and the resulting
    columns are wrapped in a JSONRecordBatch of that size.
    """
    schema = JSONSchema(fields)

    def _make_batch(num_rows):
        # Columns are generated in the same order as the fields.
        columns = [field.generate_column(num_rows) for field in fields]
        return JSONRecordBatch(num_rows, columns)

    batches = [_make_batch(num_rows) for num_rows in batch_sizes]
    return JSONFile(schema, batches)


def generate_primitive_case():
types = ['bool', 'int8', 'int16', 'int32', 'int64',
'uint8', 'uint16', 'uint32', 'uint64',
Expand All @@ -520,19 +604,27 @@ def generate_primitive_case():
fields.append(get_field(type_ + "_nullable", type_, True))
fields.append(get_field(type_ + "_nonnullable", type_, False))

schema = JSONSchema(fields)

batch_sizes = [7, 10]
batches = []
for size in batch_sizes:
columns = []
for field in fields:
col = field.generate_column(size)
columns.append(col)
return _generate_file(fields, batch_sizes)

batches.append(JSONRecordBatch(size, columns))

return JSONFile(schema, batches)
def generate_datetime_case():
    """Build a file of date, time and timestamp fields named f0 through f10.

    f0-f1 are the two date units, f2-f5 are time fields for 's', 'ms', 'us'
    and 'ns', f6-f9 are timestamps for the same units, and f10 is a
    millisecond timestamp with a timezone. Batches have 7 and 10 rows.
    """
    units = ('s', 'ms', 'us', 'ns')

    fields = [
        DateType('f0', DateType.DAY),
        DateType('f1', DateType.MILLISECOND),
    ]

    # f2..f5: one time field per unit.
    for offset, unit in enumerate(units):
        fields.append(TimeType('f%d' % (2 + offset), unit))

    # f6..f9: one timezone-less timestamp field per unit.
    for offset, unit in enumerate(units):
        fields.append(TimestampType('f%d' % (6 + offset), unit))

    fields.append(TimestampType('f10', 'ms', tz='America/New_York'))

    return _generate_file(fields, [7, 10])


def generate_nested_case():
Expand All @@ -545,19 +637,8 @@ def generate_nested_case():
# ListType('list_nonnullable', get_field('item', 'int32'), False),
]

schema = JSONSchema(fields)

batch_sizes = [7, 10]
batches = []
for size in batch_sizes:
columns = []
for field in fields:
col = field.generate_column(size)
columns.append(col)

batches.append(JSONRecordBatch(size, columns))

return JSONFile(schema, batches)
return _generate_file(fields, batch_sizes)


def get_generated_json_files():
Expand All @@ -566,13 +647,13 @@ def get_generated_json_files():
def _temp_path():
return

file_objs = []

K = 10
for i in range(K):
file_objs.append(generate_primitive_case())

file_objs.append(generate_nested_case())
file_objs = [
generate_primitive_case(),
generate_primitive_case(),
generate_primitive_case(),
# generate_datetime_case(),
generate_nested_case()
]

generated_paths = []
for file_obj in file_objs:
Expand Down