diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs
index c2e01fb6ccc..d253b83ada2 100644
--- a/rust/arrow-flight/src/utils.rs
+++ b/rust/arrow-flight/src/utils.rs
@@ -144,6 +144,7 @@ pub fn flight_data_to_arrow_batch(
                     batch,
                     schema,
                     &dictionaries_by_field,
+                    None,
                 ))
             },
         )
diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs
index d3f282961d1..08e81ad8e61 100644
--- a/rust/arrow/src/ipc/reader.rs
+++ b/rust/arrow/src/ipc/reader.rs
@@ -24,13 +24,13 @@ use std::collections::HashMap;
 use std::io::{BufReader, Read, Seek, SeekFrom};
 use std::sync::Arc;
 
-use crate::array::*;
 use crate::buffer::Buffer;
 use crate::compute::cast;
 use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef};
 use crate::error::{ArrowError, Result};
 use crate::ipc;
 use crate::record_batch::{RecordBatch, RecordBatchReader};
+use crate::{array::*, record_batch::RecordBatchOptions};
 
 use ipc::CONTINUATION_MARKER;
 use DataType::*;
@@ -406,12 +406,43 @@ fn create_dictionary_array(
     }
 }
 
+/// Read optional custom metadata from flatbuffers.
+pub fn read_custom_metadata<'a>(
+    fields: Option<
+        flatbuffers::Vector<'a, flatbuffers::ForwardsUOffset<ipc::KeyValue<'a>>>,
+    >,
+) -> Option<HashMap<String, String>> {
+    let fields = fields?;
+    if fields.is_empty() {
+        return None;
+    }
+
+    let len = fields.len();
+    let mut metadata = HashMap::default();
+
+    for i in 0..len {
+        let kv = fields.get(i);
+        if let Some(k) = kv.key() {
+            if let Some(v) = kv.value() {
+                metadata.insert(k.to_string(), v.to_string());
+            }
+        }
+    }
+
+    if metadata.is_empty() {
+        return None;
+    }
+
+    Some(metadata)
+}
+
 /// Creates a record batch from binary data using the `ipc::RecordBatch` indexes and the `Schema`
 pub fn read_record_batch(
     buf: &[u8],
     batch: ipc::RecordBatch,
     schema: SchemaRef,
     dictionaries: &[Option<ArrayRef>],
+    custom_metadata: Option<HashMap<String, String>>,
 ) -> Result<RecordBatch> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -440,16 +471,22 @@ pub fn read_record_batch(
         arrays.push(triple.0);
     }
 
-    RecordBatch::try_new(schema, arrays)
+    if let Some(metadata) = custom_metadata {
+        let opts = RecordBatchOptions::new(true, metadata);
+        RecordBatch::try_new_with_options(schema, arrays, &opts)
+    } else {
+        RecordBatch::try_new(schema, arrays)
+    }
 }
 
-/// Read the dictionary from the buffer and provided metadata,
+/// Read the dictionary from the buffer, provided metadata and optional custom metadata,
 /// updating the `dictionaries_by_field` with the resulting dictionary
 fn read_dictionary(
     buf: &[u8],
     batch: ipc::DictionaryBatch,
     schema: &Schema,
     dictionaries_by_field: &mut [Option<ArrayRef>],
+    custom_metadata: Option<HashMap<String, String>>,
 ) -> Result<()> {
     if batch.isDelta() {
         return Err(ArrowError::IoError(
@@ -479,6 +516,7 @@ fn read_dictionary(
                 batch.data().unwrap(),
                 Arc::new(schema),
                 &dictionaries_by_field,
+                custom_metadata,
             )?;
             Some(record_batch.column(0).clone())
         }
@@ -528,6 +566,9 @@ pub struct FileReader<R: Read + Seek> {
 
     /// Metadata version
     metadata_version: ipc::MetadataVersion,
+
+    /// Optional custom metadata.
+    custom_metadata: Option<HashMap<String, String>>,
 }
 
 impl<R: Read + Seek> FileReader<R> {
@@ -567,6 +608,8 @@ impl<R: Read + Seek> FileReader<R> {
             ArrowError::IoError(format!("Unable to get root as footer: {:?}", err))
         })?;
 
+        let footer_custom_metadata = read_custom_metadata(footer.custom_metadata());
+
         let blocks = footer.recordBatches().ok_or_else(|| {
             ArrowError::IoError(
                 "Unable to get record batches from IPC Footer".to_string(),
             )
         })?;
@@ -611,7 +654,15 @@ impl<R: Read + Seek> FileReader<R> {
                     ))?;
                     reader.read_exact(&mut buf)?;
 
-                    read_dictionary(&buf, batch, &schema, &mut dictionaries_by_field)?;
+                    let custom_metadata = read_custom_metadata(message.custom_metadata());
+
+                    read_dictionary(
+                        &buf,
+                        batch,
+                        &schema,
+                        &mut dictionaries_by_field,
+                        custom_metadata,
+                    )?;
                 }
                 t => {
                     return Err(ArrowError::IoError(format!(
@@ -630,6 +681,7 @@ impl<R: Read + Seek> FileReader<R> {
             total_blocks,
             dictionaries_by_field,
             metadata_version: footer.version(),
+            custom_metadata: footer_custom_metadata,
         })
     }
 
@@ -705,11 +757,14 @@ impl<R: Read + Seek> FileReader<R> {
                 ))?;
                 self.reader.read_exact(&mut buf)?;
 
+                let custom_metadata = read_custom_metadata(message.custom_metadata());
+
                 read_record_batch(
                     &buf,
                     batch,
                     self.schema(),
                     &self.dictionaries_by_field,
+                    custom_metadata,
                 ).map(Some)
             }
             ipc::MessageHeader::NONE => {
@@ -858,6 +913,8 @@ impl<R: Read> StreamReader<R> {
             ArrowError::IoError(format!("Unable to get root as message: {:?}", err))
         })?;
 
+        let custom_metadata = read_custom_metadata(message.custom_metadata());
+
         match message.header_type() {
             ipc::MessageHeader::Schema => Err(ArrowError::IoError(
                 "Not expecting a schema when messages are read".to_string(),
             )),
@@ -872,7 +929,7 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = vec![0; message.bodyLength() as usize];
                 self.reader.read_exact(&mut buf)?;
 
-                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some)
+                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field, custom_metadata).map(Some)
             }
             ipc::MessageHeader::DictionaryBatch => {
                 let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -885,7 +942,7 @@ impl<R: Read> StreamReader<R> {
                 self.reader.read_exact(&mut buf)?;
 
                 read_dictionary(
-                    &buf, batch, &self.schema, &mut self.dictionaries_by_field
+                    &buf, batch, &self.schema, &mut self.dictionaries_by_field, custom_metadata
                 )?;
 
                 // read the next message until we encounter a RecordBatch
diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs
index fdec26c1b79..e60b33c1883 100644
--- a/rust/arrow/src/ipc/writer.rs
+++ b/rust/arrow/src/ipc/writer.rs
@@ -45,6 +45,8 @@ pub struct IpcWriteOptions {
     write_legacy_ipc_format: bool,
     /// The metadata version to write. The Rust IPC writer supports V4+
     metadata_version: ipc::MetadataVersion,
+    /// Optional custom metadata.
+    custom_metadata: HashMap<String, String>,
 }
 
 impl IpcWriteOptions {
@@ -69,6 +71,7 @@ impl IpcWriteOptions {
                 alignment,
                 write_legacy_ipc_format,
                 metadata_version,
+                custom_metadata: HashMap::default(),
             }),
             ipc::MetadataVersion::V5 => {
                 if write_legacy_ipc_format {
@@ -81,6 +84,7 @@ impl IpcWriteOptions {
                         alignment,
                         write_legacy_ipc_format,
                         metadata_version,
+                        custom_metadata: HashMap::default(),
                     })
                 }
             }
@@ -95,6 +99,7 @@ impl Default for IpcWriteOptions {
             alignment: 8,
             write_legacy_ipc_format: true,
             metadata_version: ipc::MetadataVersion::V4,
+            custom_metadata: HashMap::default(),
         }
     }
 }
@@ -114,13 +119,34 @@ impl IpcDataGenerator {
             fb.as_union_value()
         };
 
-        let mut message = ipc::MessageBuilder::new(&mut fbb);
-        message.add_version(write_options.metadata_version);
-        message.add_header_type(ipc::MessageHeader::Schema);
-        message.add_bodyLength(0);
-        message.add_header(schema);
-        // TODO: custom metadata
-        let data = message.finish();
+        // Optional custom metadata.
+        let mut fb_metadata = None;
+        if !write_options.custom_metadata.is_empty() {
+            let mut kv_vec = vec![];
+            for (k, v) in &write_options.custom_metadata {
+                let kv_args = ipc::KeyValueArgs {
+                    key: Some(fbb.create_string(k.as_str())),
+                    value: Some(fbb.create_string(v.as_str())),
+                };
+                let kv_offset = ipc::KeyValue::create(&mut fbb, &kv_args);
+                kv_vec.push(kv_offset);
+            }
+
+            fb_metadata = Some(fbb.create_vector(&kv_vec));
+        }
+
+        let message_args = ipc::MessageArgs {
+            version: write_options.metadata_version,
+            header_type: ipc::MessageHeader::Schema,
+            header: Some(schema),
+            bodyLength: 0,
+            custom_metadata: fb_metadata,
+        };
+
+        // NOTE: as of crate `flatbuffers` 0.8.0, building this message with
+        // `ipc::MessageBuilder::new(&mut fbb)` hits a "multiple mutable references
+        // to `fbb`" compilation error, so it is built via `ipc::Message::create` instead.
+        let data = ipc::Message::create(&mut fbb, &message_args);
         fbb.finish(data, None);
 
         let data = fbb.finished_data();
diff --git a/rust/arrow/src/record_batch.rs b/rust/arrow/src/record_batch.rs
index b8b6098a1c7..0b98bd61188 100644
--- a/rust/arrow/src/record_batch.rs
+++ b/rust/arrow/src/record_batch.rs
@@ -18,7 +18,7 @@
 //! A two-dimensional batch of column-oriented data with a defined
 //! [schema](crate::datatypes::Schema).
 
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 use crate::array::*;
 use crate::datatypes::*;
@@ -40,6 +40,7 @@ use crate::error::{ArrowError, Result};
 pub struct RecordBatch {
     schema: SchemaRef,
     columns: Vec<Arc<dyn Array>>,
+    custom_metadata: HashMap<String, String>,
 }
 
 impl RecordBatch {
@@ -77,7 +78,11 @@ impl RecordBatch {
     pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self> {
         let options = RecordBatchOptions::default();
         Self::validate_new_batch(&schema, columns.as_slice(), &options)?;
-        Ok(RecordBatch { schema, columns })
+        Ok(RecordBatch {
+            schema,
+            columns,
+            custom_metadata: HashMap::default(),
+        })
     }
 
     /// Creates a `RecordBatch` from a schema and columns, with additional options,
@@ -90,7 +95,11 @@ impl RecordBatch {
         options: &RecordBatchOptions,
     ) -> Result<Self> {
         Self::validate_new_batch(&schema, columns.as_slice(), options)?;
-        Ok(RecordBatch { schema, columns })
+        Ok(RecordBatch {
+            schema,
+            columns,
+            custom_metadata: options.custom_metadata.clone(),
+        })
     }
 
     /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
@@ -240,12 +249,26 @@ impl RecordBatch {
 pub struct RecordBatchOptions {
     /// Match field names of structs and lists. If set to `true`, the names must match.
     pub match_field_names: bool,
+    pub custom_metadata: HashMap<String, String>,
+}
+
+impl RecordBatchOptions {
+    pub fn new(
+        match_field_names: bool,
+        custom_metadata: HashMap<String, String>,
+    ) -> Self {
+        Self {
+            match_field_names,
+            custom_metadata,
+        }
+    }
 }
 
 impl Default for RecordBatchOptions {
     fn default() -> Self {
         Self {
             match_field_names: true,
+            custom_metadata: HashMap::default(),
         }
     }
 }
@@ -261,6 +284,7 @@ impl From<&StructArray> for RecordBatch {
             RecordBatch {
                 schema: Arc::new(schema),
                 columns,
+                custom_metadata: HashMap::default(),
             }
         } else {
             unreachable!("unable to get datatype as struct")
@@ -378,6 +402,7 @@ mod tests {
         // creating the batch without field name validation should pass
         let options = RecordBatchOptions {
             match_field_names: false,
+            custom_metadata: HashMap::default(),
        };
         let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
         assert!(batch.is_ok());
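
The following snippet is not part of the patch; it is a minimal sketch of how the new `RecordBatchOptions::new` constructor and `RecordBatch::try_new_with_options` shown above could be used to attach custom key/value metadata to a batch. The schema, column values, and metadata keys are illustrative assumptions, not taken from the patch.

// Sketch only: builds a RecordBatch carrying custom metadata via the APIs added above.
use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::Result;
use arrow::record_batch::{RecordBatch, RecordBatchOptions};

fn main() -> Result<()> {
    // A one-column schema and a matching array (illustrative values).
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let a = Int32Array::from(vec![1, 2, 3]);

    // Custom key/value metadata to carry on the batch (hypothetical keys).
    let mut metadata = HashMap::new();
    metadata.insert("origin".to_string(), "example".to_string());

    // `true` keeps the default field-name validation; the metadata map is stored on the
    // batch, mirroring what `read_record_batch` does when an IPC message carries it.
    let options = RecordBatchOptions::new(true, metadata);
    let batch = RecordBatch::try_new_with_options(schema, vec![Arc::new(a)], &options)?;
    assert_eq!(batch.num_rows(), 3);
    Ok(())
}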