Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 188 additions & 30 deletions rust/arrow/src/ipc/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ use flatbuffers::{
use std::collections::HashMap;
use std::sync::Arc;

use DataType::*;

/// Serialize a schema in IPC format
fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder {
pub(crate) fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder {
let mut fbb = FlatBufferBuilder::new();

let mut fields = vec![];
Expand Down Expand Up @@ -73,6 +75,47 @@ fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder {
fbb
}

/// Serialize `schema` into an existing flatbuffer builder and return the offset
/// of the resulting IPC `Schema` table.
///
/// Every field is written with its name, type, nullability and (when the type
/// provides them) children. The schema's key/value metadata is written as the
/// table's custom metadata. The builder itself is not finished here.
pub(crate) fn schema_to_fb_offset<'a: 'b, 'b>(
    mut fbb: &'a mut FlatBufferBuilder,
    schema: &Schema,
) -> WIPOffset<ipc::Schema<'b>> {
    // One `ipc::Field` table per schema field; the name string and the type
    // tables are created before the field's own table builder is opened.
    let mut field_offsets = Vec::new();
    for field in schema.fields() {
        let name_offset = fbb.create_string(field.name().as_str());
        let (type_tag, type_offset, child_offsets) =
            get_fb_field_type(field.data_type(), &mut fbb);

        let mut fb_field = ipc::FieldBuilder::new(&mut fbb);
        fb_field.add_name(name_offset);
        fb_field.add_type_type(type_tag);
        fb_field.add_nullable(field.is_nullable());
        if let Some(children) = child_offsets {
            fb_field.add_children(children);
        }
        fb_field.add_type_(type_offset);
        field_offsets.push(fb_field.finish());
    }

    // One `ipc::KeyValue` table per metadata entry.
    let mut metadata_offsets = Vec::new();
    for (key, value) in schema.metadata() {
        let key_offset = fbb.create_string(key.as_str());
        let value_offset = fbb.create_string(value.as_str());

        let mut entry = ipc::KeyValueBuilder::new(&mut fbb);
        entry.add_key(key_offset);
        entry.add_value(value_offset);
        metadata_offsets.push(entry.finish());
    }

    let fields_vector = fbb.create_vector(&field_offsets);
    let metadata_vector = fbb.create_vector(&metadata_offsets);

    let mut schema_builder = ipc::SchemaBuilder::new(&mut fbb);
    schema_builder.add_fields(fields_vector);
    schema_builder.add_custom_metadata(metadata_vector);
    schema_builder.finish()
}

/// Convert an IPC Field to Arrow Field
impl<'a> From<ipc::Field<'a>> for Field {
fn from(field: ipc::Field) -> Field {
Expand All @@ -85,7 +128,7 @@ impl<'a> From<ipc::Field<'a>> for Field {
}

/// Deserialize a Schema table from IPC format to Schema data type
pub fn fb_to_schema(fb: ipc::Schema) -> Schema {
pub(crate) fn fb_to_schema(fb: ipc::Schema) -> Schema {
let mut fields: Vec<Field> = vec![];
let c_fields = fb.fields().unwrap();
let len = c_fields.len();
Expand All @@ -110,7 +153,7 @@ pub fn fb_to_schema(fb: ipc::Schema) -> Schema {
}

/// Get the Arrow data type from the flatbuffer Field table
fn get_data_type(field: ipc::Field) -> DataType {
pub(crate) fn get_data_type(field: ipc::Field) -> DataType {
match field.type_type() {
ipc::Type::Bool => DataType::Boolean,
ipc::Type::Int => {
Expand Down Expand Up @@ -233,22 +276,28 @@ fn get_data_type(field: ipc::Field) -> DataType {
}

/// Get the IPC type of a data type
fn get_fb_field_type<'a: 'b, 'b>(
pub(crate) fn get_fb_field_type<'a: 'b, 'b>(
data_type: &DataType,
mut fbb: &mut FlatBufferBuilder<'a>,
) -> (
ipc::Type,
WIPOffset<UnionWIPOffset>,
Option<WIPOffset<Vector<'b, ForwardsUOffset<ipc::Field<'b>>>>>,
) {
use DataType::*;
// some IPC implementations expect an empty list for child data, instead of a null value.
// An empty field list is thus returned for primitive types
let empty_fields: Vec<WIPOffset<ipc::Field>> = vec![];
match data_type {
Boolean => (
ipc::Type::Bool,
ipc::BoolBuilder::new(&mut fbb).finish().as_union_value(),
None,
),
Boolean => {
let children = fbb.create_vector(&empty_fields[..]);
(
ipc::Type::Bool,
ipc::BoolBuilder::new(&mut fbb).finish().as_union_value(),
Some(children),
)
}
UInt8 | UInt16 | UInt32 | UInt64 => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::IntBuilder::new(&mut fbb);
builder.add_is_signed(false);
match data_type {
Expand All @@ -258,9 +307,14 @@ fn get_fb_field_type<'a: 'b, 'b>(
UInt64 => builder.add_bitWidth(64),
_ => {}
};
(ipc::Type::Int, builder.finish().as_union_value(), None)
(
ipc::Type::Int,
builder.finish().as_union_value(),
Some(children),
)
}
Int8 | Int16 | Int32 | Int64 => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::IntBuilder::new(&mut fbb);
builder.add_is_signed(true);
match data_type {
Expand All @@ -270,9 +324,14 @@ fn get_fb_field_type<'a: 'b, 'b>(
Int64 => builder.add_bitWidth(64),
_ => {}
};
(ipc::Type::Int, builder.finish().as_union_value(), None)
(
ipc::Type::Int,
builder.finish().as_union_value(),
Some(children),
)
}
Float16 | Float32 | Float64 => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::FloatingPointBuilder::new(&mut fbb);
match data_type {
Float16 => builder.add_precision(ipc::Precision::HALF),
Expand All @@ -283,30 +342,57 @@ fn get_fb_field_type<'a: 'b, 'b>(
(
ipc::Type::FloatingPoint,
builder.finish().as_union_value(),
None,
Some(children),
)
}
Binary => {
let children = fbb.create_vector(&empty_fields[..]);
(
ipc::Type::Binary,
ipc::BinaryBuilder::new(&mut fbb).finish().as_union_value(),
Some(children),
)
}
Utf8 => {
let children = fbb.create_vector(&empty_fields[..]);
(
ipc::Type::Utf8,
ipc::Utf8Builder::new(&mut fbb).finish().as_union_value(),
Some(children),
)
}
FixedSizeBinary(len) => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::FixedSizeBinaryBuilder::new(&mut fbb);
builder.add_byteWidth(*len as i32);
(
ipc::Type::FixedSizeBinary,
builder.finish().as_union_value(),
Some(children),
)
}
Binary => (
ipc::Type::Binary,
ipc::BinaryBuilder::new(&mut fbb).finish().as_union_value(),
None,
),
Utf8 => (
ipc::Type::Utf8,
ipc::Utf8Builder::new(&mut fbb).finish().as_union_value(),
None,
),
Date32(_) => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::DateBuilder::new(&mut fbb);
builder.add_unit(ipc::DateUnit::DAY);
(ipc::Type::Date, builder.finish().as_union_value(), None)
(
ipc::Type::Date,
builder.finish().as_union_value(),
Some(children),
)
}
Date64(_) => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::DateBuilder::new(&mut fbb);
builder.add_unit(ipc::DateUnit::MILLISECOND);
(ipc::Type::Date, builder.finish().as_union_value(), None)
(
ipc::Type::Date,
builder.finish().as_union_value(),
Some(children),
)
}
Time32(unit) | Time64(unit) => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::TimeBuilder::new(&mut fbb);
match unit {
TimeUnit::Second => {
Expand All @@ -326,9 +412,14 @@ fn get_fb_field_type<'a: 'b, 'b>(
builder.add_unit(ipc::TimeUnit::NANOSECOND);
}
}
(ipc::Type::Time, builder.finish().as_union_value(), None)
(
ipc::Type::Time,
builder.finish().as_union_value(),
Some(children),
)
}
Timestamp(unit, tz) => {
let children = fbb.create_vector(&empty_fields[..]);
let tz = tz.clone().unwrap_or(Arc::new(String::new()));
let tz_str = fbb.create_string(tz.as_str());
let mut builder = ipc::TimestampBuilder::new(&mut fbb);
Expand All @@ -345,19 +436,25 @@ fn get_fb_field_type<'a: 'b, 'b>(
(
ipc::Type::Timestamp,
builder.finish().as_union_value(),
None,
Some(children),
)
}
Interval(unit) => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::IntervalBuilder::new(&mut fbb);
let interval_unit = match unit {
IntervalUnit::YearMonth => ipc::IntervalUnit::YEAR_MONTH,
IntervalUnit::DayTime => ipc::IntervalUnit::DAY_TIME,
};
builder.add_unit(interval_unit);
(ipc::Type::Interval, builder.finish().as_union_value(), None)
(
ipc::Type::Interval,
builder.finish().as_union_value(),
Some(children),
)
}
Duration(unit) => {
let children = fbb.create_vector(&empty_fields[..]);
let mut builder = ipc::DurationBuilder::new(&mut fbb);
let time_unit = match unit {
TimeUnit::Second => ipc::TimeUnit::SECOND,
Expand All @@ -366,7 +463,11 @@ fn get_fb_field_type<'a: 'b, 'b>(
TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND,
};
builder.add_unit(time_unit);
(ipc::Type::Duration, builder.finish().as_union_value(), None)
(
ipc::Type::Duration,
builder.finish().as_union_value(),
Some(children),
)
}
List(ref list_type) => {
let inner_types = get_fb_field_type(list_type, &mut fbb);
Expand All @@ -389,6 +490,29 @@ fn get_fb_field_type<'a: 'b, 'b>(
Some(children),
)
}
FixedSizeList(ref list_type, len) => {
let inner_types = get_fb_field_type(list_type, &mut fbb);
let child = ipc::Field::create(
&mut fbb,
&ipc::FieldArgs {
name: None,
nullable: false,
type_type: inner_types.0,
type_: Some(inner_types.1),
dictionary: None,
children: inner_types.2,
custom_metadata: None,
},
);
let children = fbb.create_vector(&[child]);
let mut builder = ipc::FixedSizeListBuilder::new(&mut fbb);
builder.add_listSize(*len as i32);
(
ipc::Type::FixedSizeList,
builder.finish().as_union_value(),
Some(children),
)
}
Struct(fields) => {
// struct's fields are children
let mut children = vec![];
Expand All @@ -415,7 +539,6 @@ fn get_fb_field_type<'a: 'b, 'b>(
Some(children),
)
}
t @ _ => panic!("Unsupported Arrow Data Type {:?}", t),
}
}

Expand Down Expand Up @@ -530,4 +653,39 @@ mod tests {
let schema2 = fb_to_schema(ipc);
assert_eq!(schema, schema2);
}

#[test]
fn schema_from_bytes() {
    // bytes of a schema generated from python (0.14.0), saved as an `ipc::Message`.
    // the schema is: Field("field1", DataType::UInt32, false)
    let python_bytes: Vec<u8> = vec![
        16, 0, 0, 0, 0, 0, 10, 0, 12, 0, 6, 0, 5, 0, 8, 0, 10, 0, 0, 0, 0, 1, 3, 0,
        12, 0, 0, 0, 8, 0, 8, 0, 0, 0, 4, 0, 8, 0, 0, 0, 4, 0, 0, 0, 1, 0, 0, 0, 20,
        0, 0, 0, 16, 0, 20, 0, 8, 0, 0, 0, 7, 0, 12, 0, 0, 0, 16, 0, 16, 0, 0, 0, 0,
        0, 0, 2, 32, 0, 0, 0, 20, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 8, 0,
        4, 0, 6, 0, 0, 0, 32, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49, 0, 0,
        0, 0, 0, 0,
    ];
    let ipc = ipc::get_root_as_message(&python_bytes[..]);
    let schema = ipc.header_as_schema().unwrap();

    // a message generated from Rust, same as the Python one
    let rust_bytes: Vec<u8> = vec![
        16, 0, 0, 0, 0, 0, 10, 0, 14, 0, 12, 0, 11, 0, 4, 0, 10, 0, 0, 0, 20, 0, 0,
        0, 0, 0, 0, 1, 3, 0, 10, 0, 12, 0, 0, 0, 8, 0, 4, 0, 10, 0, 0, 0, 8, 0, 0, 0,
        8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 16, 0, 0, 0, 12, 0, 18, 0, 12, 0, 0, 0,
        11, 0, 4, 0, 12, 0, 0, 0, 20, 0, 0, 0, 0, 0, 0, 2, 20, 0, 0, 0, 0, 0, 6, 0,
        8, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, 6, 0, 0, 0, 102, 105, 101, 108, 100, 49,
        0, 0,
    ];
    let ipc2 = ipc::get_root_as_message(&rust_bytes[..]);
    // Read the header of the Rust-generated message (`ipc2`), not `ipc` again;
    // otherwise the schema comparison below would compare a value with itself.
    let schema2 = ipc2.header_as_schema().unwrap();

    // The two byte arrays differ, so compare the decoded Arrow schemas rather
    // than the flatbuffer tables that point into each buffer.
    assert_eq!(fb_to_schema(schema), fb_to_schema(schema2));
    assert_eq!(ipc.version(), ipc2.version());
    assert_eq!(ipc.header_type(), ipc2.header_type());
    assert_eq!(ipc.bodyLength(), ipc2.bodyLength());
    assert!(ipc.custom_metadata().is_none());
    assert!(ipc2.custom_metadata().is_none());
}
}
3 changes: 3 additions & 0 deletions rust/arrow/src/ipc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

pub mod convert;
pub mod reader;
pub mod writer;

pub mod gen;

Expand All @@ -25,3 +26,5 @@ pub use self::gen::Message::*;
pub use self::gen::Schema::*;
pub use self::gen::SparseTensor::*;
pub use self::gen::Tensor::*;

/// Six-byte `ARROW1` marker; the file reader compares it against the first
/// and the last six bytes of a file.
static ARROW_MAGIC: [u8; 6] = *b"ARROW1";
6 changes: 2 additions & 4 deletions rust/arrow/src/ipc/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ use crate::ipc;
use crate::record_batch::{RecordBatch, RecordBatchReader};
use DataType::*;

static ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1'];

/// Read a buffer based on offset and length
fn read_buffer(buf: &ipc::Buffer, a_data: &Vec<u8>) -> Buffer {
let start_offset = buf.offset() as usize;
Expand Down Expand Up @@ -410,14 +408,14 @@ impl<R: Read + Seek> FileReader<R> {
// check if header and footer contain correct magic bytes
let mut magic_buffer: [u8; 6] = [0; 6];
reader.read_exact(&mut magic_buffer)?;
if magic_buffer != ARROW_MAGIC {
if magic_buffer != super::ARROW_MAGIC {
return Err(ArrowError::IoError(
"Arrow file does not contain correct header".to_string(),
));
}
reader.seek(SeekFrom::End(-6))?;
reader.read_exact(&mut magic_buffer)?;
if magic_buffer != ARROW_MAGIC {
if magic_buffer != super::ARROW_MAGIC {
return Err(ArrowError::IoError(
"Arrow file does not contain correct footer".to_string(),
));
Expand Down
Loading