diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc
index 1e2385b73f8..124c21b8fc0 100644
--- a/cpp/src/arrow/ipc/json-internal.cc
+++ b/cpp/src/arrow/ipc/json-internal.cc
@@ -175,6 +175,8 @@ class JsonSchemaWriter {
   void WriteTypeMetadata(const TimeType& type) {
     writer_->Key("unit");
     writer_->String(GetTimeUnitName(type.unit));
+    writer_->Key("bitWidth");
+    writer_->Int(type.bit_width());
   }
 
   void WriteTypeMetadata(const DateType& type) {
@@ -608,6 +610,9 @@ static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type
   const auto& json_unit = json_type.FindMember("unit");
   RETURN_NOT_STRING("unit", json_unit, json_type);
 
+  const auto& json_bit_width = json_type.FindMember("bitWidth");
+  RETURN_NOT_INT("bitWidth", json_bit_width, json_type);
+
   std::string unit_str = json_unit->value.GetString();
 
   if (unit_str == "SECOND") {
@@ -623,6 +628,14 @@ static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type
     ss << "Invalid time unit: " << unit_str;
     return Status::Invalid(ss.str());
   }
+
+  const auto& fw_type = static_cast<const FixedWidthType&>(**type);
+
+  int bit_width = json_bit_width->value.GetInt();
+  if (bit_width != fw_type.bit_width()) {
+    return Status::Invalid("Indicated bit width does not match unit");
+  }
+
   return Status::OK();
 }
 
diff --git a/cpp/src/arrow/ipc/metadata.cc b/cpp/src/arrow/ipc/metadata.cc
index 5007f130908..92fa623846b 100644
--- a/cpp/src/arrow/ipc/metadata.cc
+++ b/cpp/src/arrow/ipc/metadata.cc
@@ -255,12 +255,19 @@ static Status TypeFromFlatbuffer(flatbuf::Type type, const void* type_data,
     case flatbuf::Type_Time: {
       auto time_type = static_cast<const flatbuf::Time*>(type_data);
       TimeUnit unit = FromFlatbufferUnit(time_type->unit());
+      int32_t bit_width = time_type->bitWidth();
       switch (unit) {
         case TimeUnit::SECOND:
         case TimeUnit::MILLI:
+          if (bit_width != 32) {
+            return Status::Invalid("Time is 32 bits for second/milli unit");
+          }
           *out = time32(unit);
           break;
         default:
+          if (bit_width != 64) {
+            return Status::Invalid("Time is 64 bits for micro/nano unit");
+          }
          *out = time64(unit);
           break;
       }
@@ -386,12 +393,12 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
     case Type::TIME32: {
       const auto& time_type = static_cast<const Time32Type&>(*type);
       *out_type = flatbuf::Type_Time;
-      *offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit)).Union();
+      *offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit), 32).Union();
     } break;
     case Type::TIME64: {
       const auto& time_type = static_cast<const Time64Type&>(*type);
       *out_type = flatbuf::Type_Time;
-      *offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit)).Union();
+      *offset = flatbuf::CreateTime(fbb, ToFlatbufferUnit(time_type.unit), 64).Union();
     } break;
     case Type::TIMESTAMP: {
       const auto& ts_type = static_cast<const TimestampType&>(*type);
diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc
index dafadc168c1..7ff9479dea6 100644
--- a/cpp/src/arrow/type-test.cc
+++ b/cpp/src/arrow/type-test.cc
@@ -207,6 +207,9 @@ TEST(TestTimeType, Equals) {
   Time64Type t4(TimeUnit::NANO);
   Time64Type t5(TimeUnit::MICRO);
 
+  ASSERT_EQ(32, t0.bit_width());
+  ASSERT_EQ(64, t3.bit_width());
+
   ASSERT_TRUE(t0.Equals(t2));
   ASSERT_TRUE(t1.Equals(t1));
   ASSERT_FALSE(t1.Equals(t3));
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 6b936f348d4..7b5fb9a5d72 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -509,7 +509,7 @@ struct ARROW_EXPORT Time32Type : public TimeType {
   static constexpr Type::type type_id = Type::TIME32;
   using c_type = int32_t;
 
-  int bit_width() const override { return static_cast<int>(sizeof(c_type) * 4); }
+  int bit_width() const override { return static_cast<int>(sizeof(c_type) * 8); }
 
   explicit Time32Type(TimeUnit unit = TimeUnit::MILLI);
 
diff --git a/integration/integration_test.py b/integration/integration_test.py
index ec2a38d840d..37af4009858 100644
--- a/integration/integration_test.py
+++ b/integration/integration_test.py
@@ -175,10 +175,14 @@ def _get_buffers(self):
 
 class IntegerType(PrimitiveType):
 
-    def __init__(self, name, is_signed, bit_width, nullable=True):
+    def __init__(self, name, is_signed, bit_width, nullable=True,
+                 min_value=TEST_INT_MIN,
+                 max_value=TEST_INT_MAX):
         PrimitiveType.__init__(self, name, nullable=nullable)
         self.is_signed = is_signed
         self.bit_width = bit_width
+        self.min_value = min_value
+        self.max_value = max_value
 
     @property
     def numpy_type(self):
@@ -194,14 +198,80 @@ def _get_type(self):
     def generate_column(self, size):
         iinfo = np.iinfo(self.numpy_type)
         values = [int(x) for x in
-                  np.random.randint(max(iinfo.min, TEST_INT_MIN),
-                                    min(iinfo.max, TEST_INT_MAX),
+                  np.random.randint(max(iinfo.min, self.min_value),
+                                    min(iinfo.max, self.max_value),
                                     size=size)]
 
         is_valid = self._make_is_valid(size)
         return PrimitiveColumn(self.name, size, is_valid, values)
 
 
+class DateType(IntegerType):
+
+    DAY = 0
+    MILLISECOND = 1
+
+    def __init__(self, name, unit, nullable=True):
+        self.unit = unit
+        bit_width = 32 if unit == self.DAY else 64
+        IntegerType.__init__(self, name, True, bit_width, nullable=nullable)
+
+    def _get_type(self):
+        return OrderedDict([
+            ('name', 'date'),
+            ('unit', 'DAY' if self.unit == self.DAY else 'MILLISECOND')
+        ])
+
+
+TIMEUNIT_NAMES = {
+    's': 'SECOND',
+    'ms': 'MILLISECOND',
+    'us': 'MICROSECOND',
+    'ns': 'NANOSECOND'
+}
+
+
+class TimeType(IntegerType):
+
+    BIT_WIDTHS = {
+        's': 32,
+        'ms': 32,
+        'us': 64,
+        'ns': 64
+    }
+
+    def __init__(self, name, unit='s', nullable=True):
+        self.unit = unit
+        IntegerType.__init__(self, name, True, self.BIT_WIDTHS[unit],
+                             nullable=nullable)
+
+    def _get_type(self):
+        return OrderedDict([
+            ('name', 'time'),
+            ('unit', TIMEUNIT_NAMES[self.unit]),
+            ('bitWidth', self.bit_width)
+        ])
+
+
+class TimestampType(IntegerType):
+
+    def __init__(self, name, unit='s', tz=None, nullable=True):
+        self.unit = unit
+        self.tz = tz
+        IntegerType.__init__(self, name, True, 64, nullable=nullable)
+
+    def _get_type(self):
+        fields = [
+            ('name', 'timestamp'),
+            ('unit', TIMEUNIT_NAMES[self.unit])
+        ]
+
+        if self.tz is not None:
+            fields.append(('timezone', self.tz))
+
+        return OrderedDict(fields)
+
+
 class FloatingPointType(PrimitiveType):
 
     def __init__(self, name, bit_width, nullable=True):
@@ -509,6 +579,20 @@ def get_field(name, type_, nullable=True):
     raise TypeError(dtype)
 
 
+def _generate_file(fields, batch_sizes):
+    schema = JSONSchema(fields)
+    batches = []
+    for size in batch_sizes:
+        columns = []
+        for field in fields:
+            col = field.generate_column(size)
+            columns.append(col)
+
+        batches.append(JSONRecordBatch(size, columns))
+
+    return JSONFile(schema, batches)
+
+
 def generate_primitive_case():
     types = ['bool', 'int8', 'int16', 'int32', 'int64',
              'uint8', 'uint16', 'uint32', 'uint64',
@@ -520,19 +604,27 @@ def generate_primitive_case():
         fields.append(get_field(type_ + "_nullable", type_, True))
         fields.append(get_field(type_ + "_nonnullable", type_, False))
 
-    schema = JSONSchema(fields)
-
     batch_sizes = [7, 10]
-    batches = []
-    for size in batch_sizes:
-        columns = []
-        for field in fields:
-            col = field.generate_column(size)
-            columns.append(col)
+    return _generate_file(fields, batch_sizes)
 
-        batches.append(JSONRecordBatch(size, columns))
 
-    return JSONFile(schema, batches)
+def generate_datetime_case():
+    fields = [
+        DateType('f0', DateType.DAY),
+        DateType('f1', DateType.MILLISECOND),
+        TimeType('f2', 's'),
+        TimeType('f3', 'ms'),
+        TimeType('f4', 'us'),
+        TimeType('f5', 'ns'),
+        TimestampType('f6', 's'),
+        TimestampType('f7', 'ms'),
+        TimestampType('f8', 'us'),
+        TimestampType('f9', 'ns'),
+        TimestampType('f10', 'ms', tz='America/New_York')
+    ]
+
+    batch_sizes = [7, 10]
+    return _generate_file(fields, batch_sizes)
 
 
 def generate_nested_case():
@@ -545,19 +637,8 @@ def generate_nested_case():
         # ListType('list_nonnullable', get_field('item', 'int32'), False),
     ]
 
-    schema = JSONSchema(fields)
-
     batch_sizes = [7, 10]
-    batches = []
-    for size in batch_sizes:
-        columns = []
-        for field in fields:
-            col = field.generate_column(size)
-            columns.append(col)
-
-        batches.append(JSONRecordBatch(size, columns))
-
-    return JSONFile(schema, batches)
+    return _generate_file(fields, batch_sizes)
 
 
 def get_generated_json_files():
@@ -566,13 +647,13 @@ def get_generated_json_files():
     def _temp_path():
         return
 
-    file_objs = []
-
-    K = 10
-    for i in range(K):
-        file_objs.append(generate_primitive_case())
-
-    file_objs.append(generate_nested_case())
+    file_objs = [
+        generate_primitive_case(),
+        generate_primitive_case(),
+        generate_primitive_case(),
+        # generate_datetime_case(),
+        generate_nested_case()
+    ]
 
     generated_paths = []
     for file_obj in file_objs: