From 9fd3f5876baf66e2b0af140f7a8774f376e298c2 Mon Sep 17 00:00:00 2001
From: "Carol (Nichols || Goulding)"
Date: Thu, 5 Nov 2020 15:44:13 -0500
Subject: [PATCH 01/37] Extract integration test json functions to the lib

---
 .../src/bin/arrow-json-integration-test.rs   | 570 +-----------------
 rust/integration-testing/src/lib.rs          | 566 +++++++++++++++++
 2 files changed, 568 insertions(+), 568 deletions(-)

diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs
index b1bec677cf1..cd89a8edf1d 100644
--- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -15,27 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::HashMap;
 use std::fs::File;
-use std::io::BufReader;
-use std::sync::Arc;
 
 use clap::{App, Arg};
-use hex::decode;
-use serde_json::Value;
 
-use arrow::array::*;
-use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema};
 use arrow::error::{ArrowError, Result};
 use arrow::ipc::reader::FileReader;
 use arrow::ipc::writer::FileWriter;
-use arrow::record_batch::RecordBatch;
-use arrow::{
-    buffer::Buffer,
-    buffer::MutableBuffer,
-    datatypes::ToByteSlice,
-    util::{bit_util, integration_util::*},
-};
+use arrow::util::integration_util::*;
+use arrow_integration_testing::read_json_file;
 
 fn main() -> Result<()> {
     let matches = App::new("rust arrow-json-integration-test")
@@ -93,520 +81,6 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()>
 
     Ok(())
 }
 
-fn record_batch_from_json(
-    schema: &Schema,
-    json_batch: ArrowJsonBatch,
-    json_dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
-) -> Result<RecordBatch> {
-    let mut columns = vec![];
-
-    for (field, json_col) in schema.fields().iter().zip(json_batch.columns) {
-        let col = array_from_json(field, json_col, json_dictionaries)?;
-        columns.push(col);
-    }
-
-    RecordBatch::try_new(Arc::new(schema.clone()), columns)
-}
-
-/// Construct an Arrow array from a partially typed JSON column
-fn array_from_json(
-    field: &Field,
-    json_col: ArrowJsonColumn,
-    dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
-) -> Result<ArrayRef> {
-    match field.data_type() {
-        DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))),
-        DataType::Boolean => {
-            let mut b = BooleanBuilder::new(json_col.count);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => b.append_value(value.as_bool().unwrap()),
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::Int8 => {
-            let mut b = Int8Builder::new(json_col.count);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => b.append_value(value.as_i64().ok_or_else(|| {
-                        ArrowError::JsonError(format!(
-                            "Unable to get {:?} as int64",
-                            value
-                        ))
-                    })? 
as i8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int16 => { - let mut b = Int16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int32 - | DataType::Date32(DateUnit::Day) - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - let mut b = Int32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i32), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::Int64 - | DataType::Date64(DateUnit::Millisecond) - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - let mut b = Int64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(match value { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } - _ => panic!("Unable to parse {:?} as number", value), - }), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::UInt8 => { - let mut b = UInt8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt16 => { - let mut b = UInt16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt32 => { - let mut b = UInt32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt64 => { - let mut b = UInt64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value( - value - .as_str() - .unwrap() - .parse() - .expect("Unable to parse string as u64"), - ), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float32 => { - let mut b = Float32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap() as f32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float64 => { - let mut b = Float64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match 
is_valid {
-                    1 => b.append_value(value.as_f64().unwrap()),
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::Binary => {
-            let mut b = BinaryBuilder::new(json_col.count);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => {
-                        let v = decode(value.as_str().unwrap()).unwrap();
-                        b.append_value(&v)
-                    }
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::LargeBinary => {
-            let mut b = LargeBinaryBuilder::new(json_col.count);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => {
-                        let v = decode(value.as_str().unwrap()).unwrap();
-                        b.append_value(&v)
-                    }
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::Utf8 => {
-            let mut b = StringBuilder::new(json_col.count);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => b.append_value(value.as_str().unwrap()),
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::LargeUtf8 => {
-            let mut b = LargeStringBuilder::new(json_col.count);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => b.append_value(value.as_str().unwrap()),
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::FixedSizeBinary(len) => {
-            let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len);
-            for (is_valid, value) in json_col
-                .validity
-                .as_ref()
-                .unwrap()
-                .iter()
-                .zip(json_col.data.unwrap())
-            {
-                match is_valid {
-                    1 => {
-                        let v = hex::decode(value.as_str().unwrap()).unwrap();
-                        b.append_value(&v)
-                    }
-                    _ => b.append_null(),
-                }?;
-            }
-            Ok(Arc::new(b.finish()))
-        }
-        DataType::List(child_field) => {
-            let null_buf = create_null_buf(&json_col);
-            let children = json_col.children.clone().unwrap();
-            let child_array = array_from_json(
-                &child_field,
-                children.get(0).unwrap().clone(),
-                dictionaries,
-            )?;
-            let offsets: Vec<i32> = json_col
-                .offset
-                .unwrap()
-                .iter()
-                .map(|v| v.as_i64().unwrap() as i32)
-                .collect();
-            let list_data = ArrayData::builder(field.data_type().clone())
-                .len(json_col.count)
-                .offset(0)
-                .add_buffer(Buffer::from(&offsets.to_byte_slice()))
-                .add_child_data(child_array.data())
-                .null_bit_buffer(null_buf)
-                .build();
-            Ok(Arc::new(ListArray::from(list_data)))
-        }
-        DataType::LargeList(child_field) => {
-            let null_buf = create_null_buf(&json_col);
-            let children = json_col.children.clone().unwrap();
-            let child_array = array_from_json(
-                &child_field,
-                children.get(0).unwrap().clone(),
-                dictionaries,
-            )?;
-            let offsets: Vec<i64> = json_col
-                .offset
-                .unwrap()
-                .iter()
-                .map(|v| match v {
-                    Value::Number(n) => n.as_i64().unwrap(),
-                    Value::String(s) => s.parse::<i64>().unwrap(),
-                    _ => panic!("64-bit offset must be either string or number"),
-                })
-                .collect();
-            let list_data = ArrayData::builder(field.data_type().clone())
-                .len(json_col.count)
-                .offset(0)
-                .add_buffer(Buffer::from(&offsets.to_byte_slice()))
-                .add_child_data(child_array.data())
-                .null_bit_buffer(null_buf)
-                .build();
-            Ok(Arc::new(LargeListArray::from(list_data)))
-        }
-        DataType::FixedSizeList(child_field, _) => {
-            let children = json_col.children.clone().unwrap();
-            let child_array = array_from_json(
-                &child_field,
-                children.get(0).unwrap().clone(),
-                dictionaries,
-            )?;
-            let null_buf = create_null_buf(&json_col);
-            let list_data = ArrayData::builder(field.data_type().clone())
-                .len(json_col.count)
-                .add_child_data(child_array.data())
-                .null_bit_buffer(null_buf)
-                .build();
-            Ok(Arc::new(FixedSizeListArray::from(list_data)))
-        }
-        DataType::Struct(fields) => {
-            // construct struct with null data
-            let null_buf = create_null_buf(&json_col);
-            let mut array_data = ArrayData::builder(field.data_type().clone())
-                .len(json_col.count)
-                .null_bit_buffer(null_buf);
-
-            for (field, col) in fields.iter().zip(json_col.children.unwrap()) {
-                let array = array_from_json(field, col, dictionaries)?;
-                array_data = array_data.add_child_data(array.data());
-            }
-
-            let array = StructArray::from(array_data.build());
-            Ok(Arc::new(array))
-        }
-        DataType::Dictionary(key_type, value_type) => {
-            let dict_id = field.dict_id().ok_or_else(|| {
-                ArrowError::JsonError(format!(
-                    "Unable to find dict_id for field {:?}",
-                    field
-                ))
-            })?;
-            // find dictionary
-            let dictionary = dictionaries
-                .ok_or_else(|| {
-                    ArrowError::JsonError(format!(
-                        "Unable to find any dictionaries for field {:?}",
-                        field
-                    ))
-                })?
-                .get(&dict_id);
-            match dictionary {
-                Some(dictionary) => dictionary_array_from_json(
-                    field, json_col, key_type, value_type, dictionary,
-                ),
-                None => Err(ArrowError::JsonError(format!(
-                    "Unable to find dictionary for field {:?}",
-                    field
-                ))),
-            }
-        }
-        t => Err(ArrowError::JsonError(format!(
-            "data type {:?} not supported",
-            t
-        ))),
-    }
-}
-
-fn dictionary_array_from_json(
-    field: &Field,
-    json_col: ArrowJsonColumn,
-    dict_key: &DataType,
-    dict_value: &DataType,
-    dictionary: &ArrowJsonDictionaryBatch,
-) -> Result<ArrayRef> {
-    match dict_key {
-        DataType::Int8
-        | DataType::Int16
-        | DataType::Int32
-        | DataType::Int64
-        | DataType::UInt8
-        | DataType::UInt16
-        | DataType::UInt32
-        | DataType::UInt64 => {
-            let null_buf = create_null_buf(&json_col);
-
-            // build the key data into a buffer, then construct values separately
-            let key_field = Field::new_dict(
-                "key",
-                dict_key.clone(),
-                field.is_nullable(),
-                field
-                    .dict_id()
-                    .expect("Dictionary fields must have a dict_id value"),
-                field
-                    .dict_is_ordered()
-                    .expect("Dictionary fields must have a dict_is_ordered value"),
-            );
-            let keys = array_from_json(&key_field, json_col, None)?;
-            // note: not enough info on nullability of dictionary
-            let value_field = Field::new("value", dict_value.clone(), true);
-            println!("dictionary value type: {:?}", dict_value);
-            let values =
-                array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?;
-
-            // convert key and value to dictionary data
-            let dict_data = ArrayData::builder(field.data_type().clone())
-                .len(keys.len())
-                .add_buffer(keys.data().buffers()[0].clone())
-                .null_bit_buffer(null_buf)
-                .add_child_data(values.data())
-                .build();
-
-            let array = match dict_key {
-                DataType::Int8 => {
-                    Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef
-                }
-                DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)),
-                DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)),
-                DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)),
-                DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)),
-                DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)),
-                DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)),
-                DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)),
-                _ => unreachable!(),
-            };
-            Ok(array)
-        }
-        _ => Err(ArrowError::JsonError(format!(
-            "Dictionary key type {:?} not supported",
-            dict_key
-        ))),
-    }
-}
-
-/// A helper to create a null buffer from a Vec<u8>
-fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer {
-    let num_bytes = bit_util::ceil(json_col.count, 8);
-    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false);
-    json_col
-        .validity
-        .clone()
-        .unwrap()
-        .iter()
-        .enumerate()
-        .for_each(|(i, v)| {
-            let null_slice = null_buf.data_mut();
-            if *v != 0 {
-                bit_util::set_bit(null_slice, i);
-            }
-        });
-    null_buf.freeze()
-}
-
 fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
     if verbose {
         eprintln!("Converting {} to {}", arrow_name, json_name);
@@ -702,43 +176,3 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
 
     Ok(())
 }
-
-struct ArrowFile {
-    schema: Schema,
-    // we can evolve this into a concrete Arrow type
-    // this is temporarily not being read from
-    _dictionaries: HashMap<i64, ArrowJsonDictionaryBatch>,
-    batches: Vec<RecordBatch>,
-}
-
-fn read_json_file(json_name: &str) -> Result<ArrowFile> {
-    let json_file = File::open(json_name)?;
-    let reader = BufReader::new(json_file);
-    let arrow_json: Value = serde_json::from_reader(reader).unwrap();
-    let schema = Schema::from(&arrow_json["schema"])?;
-    // read dictionaries
-    let mut dictionaries = HashMap::new();
-    if let Some(dicts) = arrow_json.get("dictionaries") {
-        for d in dicts
-            .as_array()
-            .expect("Unable to get dictionaries as array")
-        {
-            let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone())
-                .expect("Unable to get dictionary from JSON");
-            // TODO: convert to a concrete Arrow type
-            dictionaries.insert(json_dict.id, json_dict);
-        }
-    }
-
-    let mut batches = vec![];
-    for b in arrow_json["batches"].as_array().unwrap() {
-        let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap();
-        let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?;
-        batches.push(batch);
-    }
-    Ok(ArrowFile {
-        schema,
-        _dictionaries: dictionaries,
-        batches,
-    })
-}
diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs
index 596017a79bd..eb101f2f474 100644
--- a/rust/integration-testing/src/lib.rs
+++ b/rust/integration-testing/src/lib.rs
@@ -16,3 +16,569 @@
 // under the License.
 
 //! Common code used in the integration test binaries
+
+use hex::decode;
+use serde_json::Value;
+
+use arrow::util::integration_util::{ArrowJsonBatch};
+
+use arrow::array::*;
+use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema};
+use arrow::error::{ArrowError, Result};
+use arrow::record_batch::RecordBatch;
+use arrow::{
+    buffer::Buffer,
+    buffer::MutableBuffer,
+    datatypes::ToByteSlice,
+    util::{bit_util, integration_util::*},
+};
+
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::BufReader;
+use std::sync::Arc;
+
+pub struct ArrowFile {
+    pub schema: Schema,
+    // we can evolve this into a concrete Arrow type
+    // this is temporarily not being read from
+    pub _dictionaries: HashMap<i64, ArrowJsonDictionaryBatch>,
+    pub batches: Vec<RecordBatch>,
+}
+
+pub fn read_json_file(json_name: &str) -> Result<ArrowFile> {
+    let json_file = File::open(json_name)?;
+    let reader = BufReader::new(json_file);
+    let arrow_json: Value = serde_json::from_reader(reader).unwrap();
+    let schema = Schema::from(&arrow_json["schema"])?;
+    // read dictionaries
+    let mut dictionaries = HashMap::new();
+    if let Some(dicts) = arrow_json.get("dictionaries") {
+        for d in dicts
+            .as_array()
+            .expect("Unable to get dictionaries as array")
+        {
+            let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone())
+                .expect("Unable to get dictionary from JSON");
+            // TODO: convert to a concrete Arrow type
+            dictionaries.insert(json_dict.id, json_dict);
+        }
+    }
+
+    let mut batches = vec![];
+    for b in arrow_json["batches"].as_array().unwrap() {
+        let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap();
+        let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?;
+        batches.push(batch);
+    }
+    Ok(ArrowFile {
+        schema,
+        _dictionaries: dictionaries,
+        batches,
+    })
+}
+
+fn record_batch_from_json(
+    schema: &Schema,
+    json_batch: ArrowJsonBatch,
+    json_dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
+) -> Result<RecordBatch> {
+    let mut columns = vec![];
+
+    for (field, json_col) in schema.fields().iter().zip(json_batch.columns) {
+        let col = array_from_json(field, json_col, json_dictionaries)?;
+        columns.push(col);
+    }
+
+    RecordBatch::try_new(Arc::new(schema.clone()), columns)
+}
+
+/// Construct an Arrow array from a partially typed JSON column
+fn array_from_json(
+    field: &Field,
+    json_col: ArrowJsonColumn,
+    dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
+) -> Result<ArrayRef> {
+    match field.data_type() {
+        DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))),
+        DataType::Boolean => {
+            let mut b = BooleanBuilder::new(json_col.count);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => b.append_value(value.as_bool().unwrap()),
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::Int8 => {
+            let mut b = Int8Builder::new(json_col.count);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => b.append_value(value.as_i64().ok_or_else(|| {
+                        ArrowError::JsonError(format!(
+                            "Unable to get {:?} as int64",
+                            value
+                        ))
+                    })? 
as i8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int16 => { + let mut b = Int16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int32 + | DataType::Date32(DateUnit::Day) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + let mut b = Int32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i32), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::Int64 + | DataType::Date64(DateUnit::Millisecond) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + let mut b = Int64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => { + s.parse().expect("Unable to parse string as i64") + } + _ => panic!("Unable to parse {:?} as number", value), + }), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::UInt8 => { + let mut b = UInt8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt16 => { + let mut b = UInt16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt32 => { + let mut b = UInt32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt64 => { + let mut b = UInt64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value( + value + .as_str() + .unwrap() + .parse() + .expect("Unable to parse string as u64"), + ), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float32 => { + let mut b = Float32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_f64().unwrap() as f32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float64 => { + let mut b = Float64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match 
is_valid {
+                    1 => b.append_value(value.as_f64().unwrap()),
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::Binary => {
+            let mut b = BinaryBuilder::new(json_col.count);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => {
+                        let v = decode(value.as_str().unwrap()).unwrap();
+                        b.append_value(&v)
+                    }
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::LargeBinary => {
+            let mut b = LargeBinaryBuilder::new(json_col.count);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => {
+                        let v = decode(value.as_str().unwrap()).unwrap();
+                        b.append_value(&v)
+                    }
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::Utf8 => {
+            let mut b = StringBuilder::new(json_col.count);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => b.append_value(value.as_str().unwrap()),
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::LargeUtf8 => {
+            let mut b = LargeStringBuilder::new(json_col.count);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => b.append_value(value.as_str().unwrap()),
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::FixedSizeBinary(len) => {
+            let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len);
+            for (is_valid, value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => {
+                        let v = hex::decode(value.as_str().unwrap()).unwrap();
+                        b.append_value(&v)
+                    }
+                    _ => b.append_null(),
+                }?;
+            }
+            Ok(Arc::new(b.finish()))
+        }
+        DataType::List(child_field) => {
+            let null_buf = create_null_buf(&json_col);
+            let children = json_col.children.clone().unwrap();
+            let child_array = array_from_json(
+                &child_field,
+                children.get(0).unwrap().clone(),
+                dictionaries,
+            )?;
+            let offsets: Vec<i32> = json_col
+                .offset
+                .unwrap()
+                .iter()
+                .map(|v| v.as_i64().unwrap() as i32)
+                .collect();
+            let list_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .offset(0)
+                .add_buffer(Buffer::from(&offsets.to_byte_slice()))
+                .add_child_data(child_array.data())
+                .null_bit_buffer(null_buf)
+                .build();
+            Ok(Arc::new(ListArray::from(list_data)))
+        }
+        DataType::LargeList(child_field) => {
+            let null_buf = create_null_buf(&json_col);
+            let children = json_col.children.clone().unwrap();
+            let child_array = array_from_json(
+                &child_field,
+                children.get(0).unwrap().clone(),
+                dictionaries,
+            )?;
+            let offsets: Vec<i64> = json_col
+                .offset
+                .unwrap()
+                .iter()
+                .map(|v| match v {
+                    Value::Number(n) => n.as_i64().unwrap(),
+                    Value::String(s) => s.parse::<i64>().unwrap(),
+                    _ => panic!("64-bit offset must be either string or number"),
+                })
+                .collect();
+            let list_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .offset(0)
+                .add_buffer(Buffer::from(&offsets.to_byte_slice()))
+                .add_child_data(child_array.data())
+                .null_bit_buffer(null_buf)
+                .build();
+            Ok(Arc::new(LargeListArray::from(list_data)))
+        }
+        DataType::FixedSizeList(child_field, _) => {
+            let children = json_col.children.clone().unwrap();
+            let child_array = array_from_json(
+                &child_field,
+                children.get(0).unwrap().clone(),
+                dictionaries,
+            )?;
+            let null_buf = create_null_buf(&json_col);
+            let list_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .add_child_data(child_array.data())
+                .null_bit_buffer(null_buf)
+                .build();
+            Ok(Arc::new(FixedSizeListArray::from(list_data)))
+        }
+        DataType::Struct(fields) => {
+            // construct struct with null data
+            let null_buf = create_null_buf(&json_col);
+            let mut array_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .null_bit_buffer(null_buf);
+
+            for (field, col) in fields.iter().zip(json_col.children.unwrap()) {
+                let array = array_from_json(field, col, dictionaries)?;
+                array_data = array_data.add_child_data(array.data());
+            }
+
+            let array = StructArray::from(array_data.build());
+            Ok(Arc::new(array))
+        }
+        DataType::Dictionary(key_type, value_type) => {
+            let dict_id = field.dict_id();
+            // find dictionary
+            let dictionary = dictionaries
+                .ok_or_else(|| {
+                    ArrowError::JsonError(format!(
+                        "Unable to find any dictionaries for field {:?}",
+                        field
+                    ))
+                })?
+                .get(&dict_id);
+            match dictionary {
+                Some(dictionary) => dictionary_array_from_json(
+                    field, json_col, key_type, value_type, dictionary,
+                ),
+                None => Err(ArrowError::JsonError(format!(
+                    "Unable to find dictionary for field {:?}",
+                    field
+                ))),
+            }
+        }
+        t => Err(ArrowError::JsonError(format!(
+            "data type {:?} not supported",
+            t
+        ))),
+    }
+}
+
+fn dictionary_array_from_json(
+    field: &Field,
+    json_col: ArrowJsonColumn,
+    dict_key: &DataType,
+    dict_value: &DataType,
+    dictionary: &ArrowJsonDictionaryBatch,
+) -> Result<ArrayRef> {
+    match dict_key {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64 => {
+            let null_buf = create_null_buf(&json_col);
+
+            // build the key data into a buffer, then construct values separately
+            let key_field = Field::new_dict(
+                "key",
+                dict_key.clone(),
+                field.is_nullable(),
+                field.dict_id(),
+                field.dict_is_ordered(),
+            );
+            let keys = array_from_json(&key_field, json_col, None)?;
+            // note: not enough info on nullability of dictionary
+            let value_field = Field::new("value", dict_value.clone(), true);
+            println!("dictionary value type: {:?}", dict_value);
+            let values =
+                array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?;
+
+            // convert key and value to dictionary data
+            let dict_data = ArrayData::builder(field.data_type().clone())
+                .len(keys.len())
+                .add_buffer(keys.data().buffers()[0].clone())
+                .null_bit_buffer(null_buf)
+                .add_child_data(values.data())
+                .build();
+
+            let array = match dict_key {
+                DataType::Int8 => {
+                    Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef
+                }
+                DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)),
+                DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)),
+                DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)),
+                DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)),
+                DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)),
+                DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)),
+                DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)),
+                _ => unreachable!(),
+            };
+            Ok(array)
+        }
+        _ => Err(ArrowError::JsonError(format!(
+            "Dictionary key type {:?} not supported",
+            dict_key
+        ))),
+    }
+}
+
+/// A helper to create a null buffer from a Vec<u8>
+fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer {
+    let num_bytes = bit_util::ceil(json_col.count, 8);
+    let mut null_buf = 
MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + json_col + .validity + .clone() + .unwrap() + .iter() + .enumerate() + .for_each(|(i, v)| { + let null_slice = null_buf.data_mut(); + if *v != 0 { + bit_util::set_bit(null_slice, i); + } + }); + null_buf.freeze() +} From dd142d10646118f7b896e27b3729337c0e3146de Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 29 Oct 2020 16:28:45 -0400 Subject: [PATCH 02/37] ARROW-8853: [Rust] [Integration Testing] Enable Flight tests This adds Flight client and server implementations modeled after the behavior of the C++ Flight integration tests, and enables the Flight tests in the integration tests. --- .../archery/integration/tester_rust.py | 86 +-- rust/integration-testing/Cargo.toml | 16 +- .../src/bin/flight-test-integration-client.rs | 377 +++++++++++ .../src/bin/flight-test-integration-server.rs | 634 ++++++++++++++++++ rust/integration-testing/src/lib.rs | 7 +- 5 files changed, 1080 insertions(+), 40 deletions(-) create mode 100644 rust/integration-testing/src/bin/flight-test-integration-client.rs create mode 100644 rust/integration-testing/src/bin/flight-test-integration-server.rs diff --git a/dev/archery/archery/integration/tester_rust.py b/dev/archery/archery/integration/tester_rust.py index 23c2d37386a..bca80ebae3c 100644 --- a/dev/archery/archery/integration/tester_rust.py +++ b/dev/archery/archery/integration/tester_rust.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +import contextlib import os +import subprocess from .tester import Tester from .util import run_cmd, ARROW_ROOT_DEFAULT, log @@ -24,8 +26,8 @@ class RustTester(Tester): PRODUCER = True CONSUMER = True - # FLIGHT_SERVER = True - # FLIGHT_CLIENT = True + FLIGHT_SERVER = True + FLIGHT_CLIENT = True EXE_PATH = os.path.join(ARROW_ROOT_DEFAULT, 'rust/target/debug') @@ -34,11 +36,11 @@ class RustTester(Tester): STREAM_TO_FILE = os.path.join(EXE_PATH, 'arrow-stream-to-file') FILE_TO_STREAM = os.path.join(EXE_PATH, 'arrow-file-to-stream') - # FLIGHT_SERVER_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-server')] - # FLIGHT_CLIENT_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-client'), - # "-host", "localhost"] + FLIGHT_SERVER_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-server')] + FLIGHT_CLIENT_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-client'), + "--host", "localhost"] name = 'Rust' @@ -72,34 +74,42 @@ def file_to_stream(self, file_path, stream_path): cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path] self.run_shell_command(cmd) - # @contextlib.contextmanager - # def flight_server(self): - # cmd = self.FLIGHT_SERVER_CMD + ['-port=0'] - # if self.debug: - # log(' '.join(cmd)) - # server = subprocess.Popen(cmd, - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE) - # try: - # output = server.stdout.readline().decode() - # if not output.startswith("Server listening on localhost:"): - # server.kill() - # out, err = server.communicate() - # raise RuntimeError( - # "Flight-C++ server did not start properly, " - # "stdout:\n{}\n\nstderr:\n{}\n" - # .format(output + out.decode(), err.decode())) - # port = int(output.split(":")[1]) - # yield port - # finally: - # server.kill() - # server.wait(5) - - # def flight_request(self, port, json_path): - # cmd = self.FLIGHT_CLIENT_CMD + [ - # '-port=' + str(port), - # '-path=' + json_path, - # ] - # if self.debug: - # log(' '.join(cmd)) - # run_cmd(cmd) + @contextlib.contextmanager + def 
flight_server(self, scenario_name=None): + cmd = self.FLIGHT_SERVER_CMD + ['--port=0'] + if scenario_name: + cmd = cmd + ["--scenario", scenario_name] + if self.debug: + log(' '.join(cmd)) + server = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + try: + output = server.stdout.readline().decode() + if not output.startswith("Server listening on localhost:"): + server.kill() + out, err = server.communicate() + raise RuntimeError( + "Flight-Rust server did not start properly, " + "stdout:\n{}\n\nstderr:\n{}\n" + .format(output + out.decode(), err.decode())) + port = int(output.split(":")[1]) + yield port + finally: + server.kill() + server.wait(5) + + def flight_request(self, port, json_path=None, scenario_name=None): + cmd = self.FLIGHT_CLIENT_CMD + [ + '--port=' + str(port), + ] + if json_path: + cmd.extend(('--path', json_path)) + elif scenario_name: + cmd.extend(('--scenario', scenario_name)) + else: + raise TypeError("Must provide one of json_path or scenario_name") + + if self.debug: + log(' '.join(cmd)) + run_cmd(cmd) diff --git a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml index 1c2687086fb..63f9ad6f1ab 100644 --- a/rust/integration-testing/Cargo.toml +++ b/rust/integration-testing/Cargo.toml @@ -27,11 +27,17 @@ edition = "2018" [dependencies] arrow = { path = "../arrow" } +arrow-flight = { path = "../arrow-flight" } +async-trait = "0.1.41" clap = "2.33" +futures = "0.3" +hex = "0.4" +prost = "0.6" serde = { version = "1.0", features = ["rc"] } serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } -hex = "0.4" +tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] } +tonic = "0.3" [[bin]] name = "arrow-file-to-stream" @@ -44,3 +50,11 @@ path = "src/bin/arrow-stream-to-file.rs" [[bin]] name = "arrow-json-integration-test" path = "src/bin/arrow-json-integration-test.rs" + +[[bin]] +name = "flight-test-integration-server" +path = "src/bin/flight-test-integration-server.rs" + +[[bin]] +name = "flight-test-integration-client" +path = "src/bin/flight-test-integration-client.rs" diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs new file mode 100644 index 00000000000..2a299cc401d --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -0,0 +1,377 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
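+// NOTE: This binary is the client half of the Flight integration tests.
+// Depending on --scenario it exercises the "middleware" or
+// "auth:basic_proto" scenario; otherwise --path names a JSON dataset
+// that is round-tripped through do_put / get_flight_info / do_get and
+// verified batch by batch against the original data.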
+
+use arrow_integration_testing::{
+    read_json_file, ArrowFile, AUTH_PASSWORD, AUTH_USERNAME,
+};
+
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::{
+    flight_descriptor::DescriptorType, BasicAuth, FlightData, HandshakeRequest, Location,
+    Ticket,
+};
+use arrow_flight::{utils::flight_data_to_arrow_batch, FlightDescriptor};
+
+use clap::{App, Arg};
+use futures::{channel::mpsc, sink::SinkExt, StreamExt};
+use prost::Message;
+use tonic::{metadata::MetadataValue, Request, Status};
+
+use std::sync::Arc;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+type Client = FlightServiceClient<tonic::transport::Channel>;
+
+#[tokio::main]
+async fn main() -> Result {
+    let matches = App::new("rust flight-test-integration-client")
+        .arg(Arg::with_name("host").long("host").takes_value(true))
+        .arg(Arg::with_name("port").long("port").takes_value(true))
+        .arg(Arg::with_name("path").long("path").takes_value(true))
+        .arg(
+            Arg::with_name("scenario")
+                .long("scenario")
+                .takes_value(true),
+        )
+        .get_matches();
+
+    let host = matches.value_of("host").expect("Host is required");
+    let port = matches.value_of("port").expect("Port is required");
+
+    match matches.value_of("scenario") {
+        Some("middleware") => middleware_scenario(host, port).await?,
+        Some("auth:basic_proto") => auth_basic_proto_scenario(host, port).await?,
+        Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name),
+        None => {
+            let path = matches
+                .value_of("path")
+                .expect("Path is required if scenario is not specified");
+            integration_test_scenario(host, port, path).await?;
+        }
+    }
+
+    Ok(())
+}
+
+async fn middleware_scenario(host: &str, port: &str) -> Result {
+    let url = format!("http://{}:{}", host, port);
+    let conn = tonic::transport::Endpoint::new(url)?.connect().await?;
+    let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor);
+
+    let mut descriptor = FlightDescriptor::default();
+    descriptor.set_type(DescriptorType::Cmd);
+    descriptor.cmd = b"".to_vec();
+
+    // This call is expected to fail.
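+    // The interceptor below attaches "x-middleware: expected value" to the
+    // outgoing request; the server is expected to echo that header back on
+    // both this failing call and the succeeding call that follows.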
+    let resp = client
+        .get_flight_info(Request::new(descriptor.clone()))
+        .await;
+    match resp {
+        Ok(_) => return Err(Box::new(Status::internal("Expected call to fail"))),
+        Err(e) => {
+            let headers = e.metadata();
+            let middleware_header = headers.get("x-middleware");
+            let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or("");
+
+            if value != "expected value" {
+                let msg = format!(
+                    "Expected to receive header 'x-middleware: expected value', \
+                    but instead got: '{}'",
+                    value
+                );
+                return Err(Box::new(Status::internal(msg)));
+            }
+
+            eprintln!("Headers received successfully on failing call.");
+        }
+    }
+
+    // This call should succeed
+    descriptor.cmd = b"success".to_vec();
+    let resp = client.get_flight_info(Request::new(descriptor)).await?;
+
+    let headers = resp.metadata();
+    let middleware_header = headers.get("x-middleware");
+    let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or("");
+
+    if value != "expected value" {
+        let msg = format!(
+            "Expected to receive header 'x-middleware: expected value', \
+            but instead got: '{}'",
+            value
+        );
+        return Err(Box::new(Status::internal(msg)));
+    }
+
+    eprintln!("Headers received successfully on passing call.");
+
+    Ok(())
+}
+
+fn middleware_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
+    let metadata = req.metadata_mut();
+    metadata.insert("x-middleware", "expected value".parse().unwrap());
+    Ok(req)
+}
+
+async fn auth_basic_proto_scenario(host: &str, port: &str) -> Result {
+    let url = format!("http://{}:{}", host, port);
+    let mut client = FlightServiceClient::connect(url).await?;
+
+    let action = arrow_flight::Action::default();
+
+    let resp = client.do_action(Request::new(action.clone())).await;
+    // This client is unauthenticated and should fail.
+    match resp {
+        Err(e) => {
+            if e.code() != tonic::Code::Unauthenticated {
+                return Err(Box::new(Status::internal(format!(
+                    "Expected UNAUTHENTICATED but got {:?}",
+                    e
+                ))));
+            }
+        }
+        Ok(other) => {
+            return Err(Box::new(Status::internal(format!(
+                "Expected UNAUTHENTICATED but got {:?}",
+                other
+            ))));
+        }
+    }
+
+    let token = authenticate(&mut client, AUTH_USERNAME, AUTH_PASSWORD)
+        .await
+        .expect("must respond successfully from handshake");
+
+    let mut request = Request::new(action);
+    let metadata = request.metadata_mut();
+    metadata.insert_bin(
+        "auth-token-bin",
+        MetadataValue::from_bytes(token.as_bytes()),
+    );
+
+    let resp = client.do_action(request).await?;
+    let mut resp = resp.into_inner();
+
+    let r = resp
+        .next()
+        .await
+        .expect("No response received")
+        .expect("Invalid response received");
+
+    let body = String::from_utf8(r.body).unwrap();
+    assert_eq!(body, AUTH_USERNAME);
+
+    Ok(())
+}
+
+// TODO: should this be extended, abstracted, and moved out of test code and into production code?
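+//
+// `authenticate` drives the Flight handshake: it streams a single
+// protobuf-encoded BasicAuth message to the server and expects exactly one
+// HandshakeResponse back, whose payload is then used as the bearer token
+// ("auth-token-bin" metadata) for subsequent authenticated calls.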
+async fn authenticate(
+    client: &mut Client,
+    username: &str,
+    password: &str,
+) -> Result<String> {
+    let (mut tx, rx) = mpsc::channel(10);
+    let rx = client.handshake(Request::new(rx)).await?;
+    let mut rx = rx.into_inner();
+
+    let auth = BasicAuth {
+        username: username.into(),
+        password: password.into(),
+    };
+
+    let mut payload = vec![];
+    auth.encode(&mut payload)?;
+
+    tx.send(HandshakeRequest {
+        payload,
+        ..HandshakeRequest::default()
+    })
+    .await?;
+    drop(tx);
+
+    let r = rx.next().await.expect("must respond from handshake")?;
+    assert!(rx.next().await.is_none(), "must not respond a second time");
+
+    Ok(String::from_utf8(r.payload).unwrap())
+}
+
+async fn integration_test_scenario(host: &str, port: &str, path: &str) -> Result {
+    let url = format!("http://{}:{}", host, port);
+
+    let client = FlightServiceClient::connect(url).await?;
+
+    let ArrowFile {
+        schema, batches, ..
+    } = read_json_file(path)?;
+
+    let schema = Arc::new(schema);
+
+    let mut descriptor = FlightDescriptor::default();
+    descriptor.set_type(DescriptorType::Path);
+    descriptor.path = vec![path.to_string()];
+
+    upload_data(
+        client.clone(),
+        schema.clone(),
+        descriptor.clone(),
+        batches.clone(),
+    )
+    .await?;
+    verify_data(client, descriptor, schema, &batches).await?;
+
+    Ok(())
+}
+
+async fn upload_data(
+    mut client: Client,
+    schema: SchemaRef,
+    descriptor: FlightDescriptor,
+    original_data: Vec<RecordBatch>,
+) -> Result {
+    let (mut upload_tx, upload_rx) = mpsc::channel(10);
+
+    let mut schema_flight_data = FlightData::from(&*schema);
+    schema_flight_data.flight_descriptor = Some(descriptor.clone());
+    schema_flight_data.app_metadata = "hello".as_bytes().to_vec();
+    upload_tx.send(schema_flight_data).await?;
+
+    let resp = client.do_put(Request::new(upload_rx)).await?;
+    let mut resp = resp.into_inner();
+
+    let r = resp
+        .next()
+        .await
+        .expect("No response received")
+        .expect("Invalid response received");
+
+    assert_eq!(r.app_metadata, "hello".as_bytes());
+
+    tokio::spawn(async move {
+        for (counter, batch) in original_data.iter().enumerate() {
+            let metadata = counter.to_string().into_bytes();
+
+            let mut batch = FlightData::from(batch);
+            batch.flight_descriptor = Some(descriptor.clone());
+            batch.app_metadata = metadata.clone();
+
+            upload_tx.send(batch).await?;
+            let r = resp
+                .next()
+                .await
+                .expect("No response received")
+                .expect("Invalid response received");
+            assert_eq!(metadata, r.app_metadata);
+        }
+
+        Ok(())
+    })
+    .await? 
+} + +async fn verify_data( + mut client: Client, + descriptor: FlightDescriptor, + expected_schema: SchemaRef, + expected_data: &[RecordBatch], +) -> Result { + let resp = client.get_flight_info(Request::new(descriptor)).await?; + let info = resp.into_inner(); + + assert!( + !info.endpoint.is_empty(), + "No endpoints returned from Flight server", + ); + for endpoint in info.endpoint { + let ticket = endpoint + .ticket + .expect("No ticket returned from Flight server"); + + assert!( + !endpoint.location.is_empty(), + "No locations returned from Flight server", + ); + for location in endpoint.location { + println!("Verifying location {:?}", location); + consume_flight_location( + location, + ticket.clone(), + &expected_data, + expected_schema.clone(), + ) + .await?; + } + } + + Ok(()) +} + +async fn consume_flight_location( + location: Location, + ticket: Ticket, + expected_data: &[RecordBatch], + schema: SchemaRef, +) -> Result { + let mut client = FlightServiceClient::connect(location.uri).await?; + + let resp = client.do_get(ticket).await?; + let mut resp = resp.into_inner(); + + for (counter, expected_batch) in expected_data.iter().enumerate() { + let actual_batch = resp.next().await.unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + })?; + + let metadata = counter.to_string().into_bytes(); + assert_eq!(metadata, actual_batch.app_metadata); + + let actual_batch = flight_data_to_arrow_batch(&actual_batch, schema.clone()) + .expect("Unable to convert flight data to Arrow batch") + .expect("Unable to convert flight data to Arrow batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + let schema = expected_batch.schema(); + for i in 0..expected_batch.num_columns() { + let field = schema.field(i); + let field_name = field.name(); + + let expected_data = expected_batch.column(i).data(); + let actual_data = actual_batch.column(i).data(); + + assert_eq!(expected_data, actual_data, "Data for field {}", field_name); + } + } + + assert!( + resp.next().await.is_none(), + "Got more batches than the expected: {}", + expected_data.len(), + ); + + Ok(()) +} diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs new file mode 100644 index 00000000000..f26db87b22e --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -0,0 +1,634 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
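+// NOTE: This binary is the server half of the Flight integration tests.
+// It serves one of three FlightService implementations selected by
+// --scenario (the data round-trip service by default, "middleware", or
+// "auth:basic_proto") and prints "Server listening on localhost:{port}"
+// so the test harness can discover the actual port when started with
+// --port=0.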
+ +use std::collections::HashMap; +use std::convert::TryFrom; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +use clap::{App, Arg}; +use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use prost::Message; +use tokio::net::TcpListener; +use tokio::sync::Mutex; +use tonic::transport::Server; +use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; + +use arrow::{datatypes::Schema, record_batch::RecordBatch}; +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, utils::flight_data_to_arrow_batch, + Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, + FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, Location, PutResult, + SchemaResult, Ticket, +}; + +use arrow_integration_testing::{AUTH_PASSWORD, AUTH_USERNAME}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +#[derive(Debug, Clone)] +struct IntegrationDataset { + schema: Schema, + chunks: Vec, +} + +#[derive(Clone, Default)] +pub struct FlightServiceImpl { + server_location: String, + uploaded_chunks: Arc>>, +} + +#[tonic::async_trait] +impl FlightService for FlightServiceImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + let ticket = request.into_inner(); + + let key = String::from_utf8(ticket.ticket.to_vec()) + .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {:?}", e)))?; + + let uploaded_chunks = self.uploaded_chunks.lock().await; + + let flight = uploaded_chunks.get(&key).ok_or_else(|| { + Status::not_found(format!("Could not find flight. {}", key)) + })?; + + let batches: Vec> = flight + .chunks + .iter() + .enumerate() + .map(|(counter, batch)| { + let mut flight_data = FlightData::from(batch); + let metadata = counter.to_string().into_bytes(); + flight_data.app_metadata = metadata; + Ok(flight_data) + }) + .collect(); + + let output = futures::stream::iter(batches); + + Ok(Response::new(Box::pin(output) as Self::DoGetStream)) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let descriptor = request.into_inner(); + + match descriptor.r#type { + t if t == DescriptorType::Path as i32 => { + let path = &descriptor.path; + if path.is_empty() { + return Err(Status::invalid_argument("Invalid path")); + } + + let uploaded_chunks = self.uploaded_chunks.lock().await; + let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| { + Status::not_found(format!("Could not find flight. 
{}", path[0])) + })?; + + let schema_result = SchemaResult::from(&flight.schema); + + let endpoint = FlightEndpoint { + ticket: Some(Ticket { + ticket: path[0].as_bytes().to_vec(), + }), + location: vec![Location { + uri: self.server_location.clone(), + }], + }; + + let total_records: usize = + flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + + let info = FlightInfo { + schema: schema_result.schema, + flight_descriptor: Some(descriptor.clone()), + endpoint: vec![endpoint], + total_records: total_records as i64, + total_bytes: -1, + }; + + Ok(Response::new(info)) + } + other => Err(Status::unimplemented(format!("Request type: {}", other))), + } + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + let mut input_stream = request.into_inner(); + let flight_data = input_stream + .message() + .await? + .ok_or(Status::invalid_argument("Must send some FlightData"))?; + + let descriptor = flight_data + .flight_descriptor + .clone() + .ok_or(Status::invalid_argument("Must have a descriptor"))?; + + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() + { + return Err(Status::invalid_argument("Must specify a path")); + } + + let key = descriptor.path[0].clone(); + + let schema = Schema::try_from(&flight_data) + .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; + let schema_ref = Arc::new(schema.clone()); + + let (mut response_tx, response_rx) = mpsc::channel(10); + + let stream_result = response_tx + .send(Ok(PutResult { + app_metadata: flight_data.app_metadata.clone(), + })) + .await; + if let Err(e) = stream_result { + response_tx + .send(Err(Status::internal(format!( + "Could not send PutResult: {:?}", + e + )))) + .await + .expect("Error sending error"); + } + + let uploaded_chunks = self.uploaded_chunks.clone(); + + tokio::spawn(async move { + let mut chunks = vec![]; + let mut uploaded_chunks = uploaded_chunks.lock().await; + + while let Some(Ok(more_flight_data)) = input_stream.next().await { + let stream_result = response_tx + .send(Ok(PutResult { + app_metadata: more_flight_data.app_metadata.clone(), + })) + .await; + if let Err(e) = stream_result { + response_tx + .send(Err(Status::internal(format!( + "Could not send PutResult: {:?}", + e + )))) + .await + .expect("Error sending error"); + } + + // This `unwrap` is fine because `flight_data_to_arrow_batch` always returns `Some` + let arrow_batch_result = + flight_data_to_arrow_batch(&more_flight_data, schema_ref.clone()) + .expect("flight_data_to_arrow_batch didn't actually return Some"); + + match arrow_batch_result { + Ok(batch) => chunks.push(batch), + Err(e) => response_tx + .send(Err(Status::invalid_argument(format!( + "Could not convert to RecordBatch: {:?}", + e + )))) + .await + .expect("Error sending error"), + } + } + + let dataset = IntegrationDataset { schema, chunks }; + uploaded_chunks.insert(key, dataset); + }); + + Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream)) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +#[derive(Clone, Default)] +pub struct MiddlewareScenarioImpl {} + +#[tonic::async_trait] +impl FlightService for MiddlewareScenarioImpl { + type 
HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let middleware_header = request.metadata().get("x-middleware").cloned(); + + let descriptor = request.into_inner(); + + if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd == b"success" + { + // Return a fake location - the test doesn't read it + let endpoint = FlightEndpoint { + ticket: Some(Ticket { + ticket: b"foo".to_vec(), + }), + location: vec![Location { + uri: "grpc+tcp://localhost:10010".into(), + }], + }; + + let info = FlightInfo { + endpoint: vec![endpoint], + ..Default::default() + }; + + let mut response = Response::new(info); + if let Some(value) = middleware_header { + response.metadata_mut().insert("x-middleware", value); + } + + return Ok(response); + } + + let mut status = Status::unknown("Unknown"); + if let Some(value) = middleware_header { + status.metadata_mut().insert("x-middleware", value); + } + + Err(status) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +struct GrpcServerCallContext { + peer_identity: String, +} + +impl GrpcServerCallContext { + pub fn peer_identity(&self) -> &str { + &self.peer_identity + } +} + +#[derive(Clone)] +pub struct AuthBasicProtoScenarioImpl { + username: Arc, + password: Arc, + peer_identity: Arc>>, +} + +impl AuthBasicProtoScenarioImpl { + async fn check_auth( + &self, + metadata: &MetadataMap, + ) -> Result { + let token = metadata + .get_bin("auth-token-bin") + .and_then(|v| v.to_bytes().ok()) + .and_then(|b| String::from_utf8(b.to_vec()).ok()); + self.is_valid(token).await + } + + async fn is_valid( + &self, + token: Option, + ) -> Result { + match token { + Some(t) if t == &*self.username => Ok(GrpcServerCallContext { + peer_identity: self.username.to_string(), + }), + _ => Err(Status::unauthenticated("Invalid token")), + } + } +} + +#[tonic::async_trait] +impl FlightService for AuthBasicProtoScenarioImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, 
Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + let (tx, rx) = mpsc::channel(10); + + tokio::spawn({ + let username = self.username.clone(); + let password = self.password.clone(); + + async move { + let requests = request.into_inner(); + + requests + .for_each(move |req| { + let mut tx = tx.clone(); + let req = req.expect("Error reading handshake request"); + let HandshakeRequest { payload, .. } = req; + + let auth = BasicAuth::decode(&*payload) + .expect("Error parsing handshake request"); + + let resp = if &*auth.username == &*username + && &*auth.password == &*password + { + Ok(HandshakeResponse { + payload: username.as_bytes().to_vec(), + ..HandshakeResponse::default() + }) + } else { + Err(Status::unauthenticated(format!( + "Don't know user {}", + auth.username + ))) + }; + + async move { + tx.send(resp) + .await + .expect("Error sending handshake response"); + } + }) + .await; + } + }); + + Ok(Response::new(Box::pin(rx))) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + let flight_context = self.check_auth(request.metadata()).await?; + // Respond with the authenticated username. 
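+        // The action body below is just the peer identity bytes; the client
+        // side can compare them against the username it authenticated as.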
+ let buf = flight_context.peer_identity().as_bytes().to_vec(); + let result = arrow_flight::Result { body: buf }; + let output = futures::stream::once(async { Ok(result) }); + Ok(Response::new(Box::pin(output) as Self::DoActionStream)) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +type Error = Box; +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result { + let matches = App::new("rust flight-test-integration-server") + .about("Integration testing server for Flight.") + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let port = matches.value_of("port").unwrap_or("0"); + + match matches.value_of("scenario") { + Some("middleware") => middleware_scenario(port).await?, + Some("auth:basic_proto") => auth_basic_proto_scenario(port).await?, + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + integration_test_scenario(port).await?; + } + } + Ok(()) +} + +async fn integration_test_scenario(port: &str) -> Result { + let (mut listener, addr) = listen_on(port).await?; + + let service = FlightServiceImpl { + server_location: format!("grpc+tcp://{}", addr), + ..Default::default() + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + + Ok(()) +} + +async fn middleware_scenario(port: &str) -> Result { + let (mut listener, _) = listen_on(port).await?; + + let service = MiddlewareScenarioImpl {}; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +async fn auth_basic_proto_scenario(port: &str) -> Result { + let (mut listener, _) = listen_on(port).await?; + + let service = AuthBasicProtoScenarioImpl { + username: AUTH_USERNAME.into(), + password: AUTH_PASSWORD.into(), + peer_identity: Arc::new(Mutex::new(None)), + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +async fn listen_on(port: &str) -> Result<(TcpListener, SocketAddr)> { + let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; + + let listener = TcpListener::bind(addr).await?; + let addr = listener.local_addr()?; + println!("Server listening on localhost:{}", addr.port()); + + Ok((listener, addr)) +} diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs index eb101f2f474..0ce7826a16a 100644 --- a/rust/integration-testing/src/lib.rs +++ b/rust/integration-testing/src/lib.rs @@ -20,7 +20,7 @@ use hex::decode; use serde_json::Value; -use arrow::util::integration_util::{ArrowJsonBatch}; +use arrow::util::integration_util::ArrowJsonBatch; use arrow::array::*; use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; @@ -38,6 +38,11 @@ use std::fs::File; use std::io::BufReader; use std::sync::Arc; +/// The expected username for the basic auth integration test. +pub const AUTH_USERNAME: &str = "arrow"; +/// The expected password for the basic auth integration test. 
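+/// Presenting any other credentials to the auth:basic_proto scenario server
+/// results in an `unauthenticated` status.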
+pub const AUTH_PASSWORD: &str = "flight"; + pub struct ArrowFile { pub schema: Schema, // we can evolve this into a concrete Arrow type From b078a659c7f60c6c32bad033c35db00b4e56989b Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 18 Nov 2020 16:04:55 -0500 Subject: [PATCH 03/37] Don't send ACK for initial DoPut request from server --- .../src/bin/flight-test-integration-server.rs | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index f26db87b22e..33d7ecc68be 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -191,21 +191,6 @@ impl FlightService for FlightServiceImpl { let (mut response_tx, response_rx) = mpsc::channel(10); - let stream_result = response_tx - .send(Ok(PutResult { - app_metadata: flight_data.app_metadata.clone(), - })) - .await; - if let Err(e) = stream_result { - response_tx - .send(Err(Status::internal(format!( - "Could not send PutResult: {:?}", - e - )))) - .await - .expect("Error sending error"); - } - let uploaded_chunks = self.uploaded_chunks.clone(); tokio::spawn(async move { From cb54a90e9d31c911e5006af4b2e100c48c62e2c2 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 19 Nov 2020 09:01:28 -0500 Subject: [PATCH 04/37] Don't send or check for app_metadata in the DoPut request --- .../src/bin/flight-test-integration-client.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 2a299cc401d..23f7642ebd1 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -252,7 +252,6 @@ async fn upload_data( let mut schema_flight_data = FlightData::from(&*schema); schema_flight_data.flight_descriptor = Some(descriptor.clone()); - schema_flight_data.app_metadata = "hello".as_bytes().to_vec(); upload_tx.send(schema_flight_data).await?; let resp = client.do_put(Request::new(upload_rx)).await?; @@ -264,8 +263,6 @@ async fn upload_data( .expect("No response received") .expect("Invalid response received"); - assert_eq!(r.app_metadata, "hello".as_bytes()); - tokio::spawn(async move { for (counter, batch) in original_data.iter().enumerate() { let metadata = counter.to_string().into_bytes(); From bfdfc78887bddcb21e21a4b15a65aa245d333336 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 19 Nov 2020 09:16:29 -0500 Subject: [PATCH 05/37] Remove tokio spawn; a new thread isn't actually needed here --- .../src/bin/flight-test-integration-client.rs | 35 +++++++++---------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 23f7642ebd1..8e5b03add06 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -263,26 +263,23 @@ async fn upload_data( .expect("No response received") .expect("Invalid response received"); - tokio::spawn(async move { - for (counter, batch) in original_data.iter().enumerate() { - let metadata = counter.to_string().into_bytes(); - - let mut batch = 
FlightData::from(batch); - batch.flight_descriptor = Some(descriptor.clone()); - batch.app_metadata = metadata.clone(); - - upload_tx.send(batch).await?; - let r = resp - .next() - .await - .expect("No response received") - .expect("Invalid response received"); - assert_eq!(metadata, r.app_metadata); - } + for (counter, batch) in original_data.iter().enumerate() { + let metadata = counter.to_string().into_bytes(); - Ok(()) - }) - .await? + let mut batch = FlightData::from(batch); + batch.flight_descriptor = Some(descriptor.clone()); + batch.app_metadata = metadata.clone(); + + upload_tx.send(batch).await?; + let r = resp + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + } + + Ok(()) } async fn verify_data( From 355008e4b9c57548bc4a96d8b71ce13ba2a34c15 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 19 Nov 2020 09:17:43 -0500 Subject: [PATCH 06/37] Preload first batch in channel before starting DoPut request --- .../src/bin/flight-test-integration-client.rs | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 8e5b03add06..47d17c725fc 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -254,29 +254,31 @@ async fn upload_data( schema_flight_data.flight_descriptor = Some(descriptor.clone()); upload_tx.send(schema_flight_data).await?; - let resp = client.do_put(Request::new(upload_rx)).await?; - let mut resp = resp.into_inner(); - - let r = resp - .next() - .await - .expect("No response received") - .expect("Invalid response received"); + let mut upload_rx_container = Some(upload_rx); + let mut resp = None; for (counter, batch) in original_data.iter().enumerate() { let metadata = counter.to_string().into_bytes(); let mut batch = FlightData::from(batch); - batch.flight_descriptor = Some(descriptor.clone()); batch.app_metadata = metadata.clone(); upload_tx.send(batch).await?; - let r = resp - .next() - .await - .expect("No response received") - .expect("Invalid response received"); - assert_eq!(metadata, r.app_metadata); + + if let Some(upload_rx) = upload_rx_container.take() { + let outer = client.do_put(Request::new(upload_rx)).await?; + let inner = outer.into_inner(); + resp = Some(inner); + } + + if let Some(inner) = resp.as_mut() { + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + } } Ok(()) From 7c182f92e01a5e499f1a130ad80b01340ec0be56 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 20 Nov 2020 16:57:42 -0500 Subject: [PATCH 07/37] rando debugging --- cpp/src/arrow/flight/client.cc | 2 ++ cpp/src/arrow/ipc/message.cc | 16 ++++++++++++++++ rust/arrow/src/ipc/writer.rs | 5 +++-- .../src/bin/flight-test-integration-server.rs | 12 ++++++++++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 5c56e6409a7..3d32d250556 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -1151,6 +1151,8 @@ class FlightClient::FlightClientImpl { using GrpcStream = grpc::ClientReaderWriter; using StreamWriter = GrpcStreamWriter; + std::cerr << "DoPut called" << std::endl; + auto rpc = std::make_shared(options); 
RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); std::shared_ptr stream = stub_->DoPut(&rpc->context); diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index 6569e71b454..9d0e2577dc8 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "arrow/buffer.h" #include "arrow/device.h" @@ -469,6 +470,8 @@ class MessageDecoder::MessageDecoderImpl { metadata_(nullptr) {} Status ConsumeData(const uint8_t* data, int64_t size) { + std::cerr << "ConsumeData / next_required_size_ " << next_required_size_ << std::endl; + if (buffered_size_ == 0) { while (size > 0 && size >= next_required_size_) { auto used_size = next_required_size_; @@ -505,6 +508,7 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeBuffer(std::shared_ptr buffer) { + std::cerr << "ConsumeBuffer / next_required_size_ " << next_required_size_ << std::endl; if (buffered_size_ == 0) { while (buffer->size() >= next_required_size_) { auto used_size = next_required_size_; @@ -598,15 +602,18 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeInitial(int32_t continuation) { + std::cerr << "ConsumeInitial / continuation " << continuation << std::endl; if (continuation == internal::kIpcContinuationToken) { state_ = State::METADATA_LENGTH; next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength; + std::cerr << "ConsumeInitial / A / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadataLength()); // Valid IPC message, read the message length now return Status::OK(); } else if (continuation == 0) { state_ = State::EOS; next_required_size_ = 0; + std::cerr << "ConsumeInitial / B / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnEOS()); return Status::OK(); } else if (continuation > 0) { @@ -614,6 +621,7 @@ class MessageDecoder::MessageDecoderImpl { // ARROW-6314: Backwards compatibility for reading old IPC // messages produced prior to version 0.15.0 next_required_size_ = continuation; + std::cerr << "ConsumeInitial / C / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadata()); return Status::OK(); } else { @@ -641,11 +649,13 @@ class MessageDecoder::MessageDecoderImpl { if (metadata_length == 0) { state_ = State::EOS; next_required_size_ = 0; + std::cerr << "ConsumeMetadataLength / A /next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnEOS()); return Status::OK(); } else if (metadata_length > 0) { state_ = State::METADATA; next_required_size_ = metadata_length; + std::cerr << "ConsumeMetadataLength / B / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadata()); return Status::OK(); } else { @@ -664,6 +674,8 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeMetadataChunks() { + std::cerr << "ConsumeMetadataChunks / next_required_size_ " << next_required_size_ << std::endl; + if (chunks_[0]->size() >= next_required_size_) { if (chunks_[0]->size() == next_required_size_) { if (chunks_[0]->is_cpu()) { @@ -698,6 +710,7 @@ class MessageDecoder::MessageDecoderImpl { state_ = State::BODY; next_required_size_ = body_length; + std::cerr << "ConsumeMetadata / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnBody()); if (next_required_size_ == 0) { ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_)); @@ -713,6 +726,8 @@ class MessageDecoder::MessageDecoderImpl 
{ } Status ConsumeBodyChunks() { + std::cerr << "ConsumeBodyChunks / next_required_size_ " << next_required_size_ << std::endl; + if (chunks_[0]->size() >= next_required_size_) { auto used_size = next_required_size_; if (chunks_[0]->size() == next_required_size_) { @@ -740,6 +755,7 @@ class MessageDecoder::MessageDecoderImpl { RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message))); state_ = State::INITIAL; next_required_size_ = kMessageDecoderNextRequiredSizeInitial; + std::cerr << "ConsumeBody / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnInitial()); return Status::OK(); } diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index d6a52a62c5d..c0ccbeb07f2 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -391,6 +391,7 @@ pub fn schema_to_bytes(schema: &Schema, write_options: &IpcWriteOptions) -> Enco fbb.finish(data, None); let data = fbb.finished_data(); + dbg!(data.len()); EncodedData { ipc_message: data.to_vec(), arrow_data: vec![], @@ -600,9 +601,9 @@ fn write_continuation( total_len: i32, ) -> Result { let mut written = 8; - + dbg!("write_continuation", write_options); // the version of the writer determines whether continuation markers should be added - match write_options.metadata_version { + match dbg!(write_options.metadata_version) { ipc::MetadataVersion::V1 | ipc::MetadataVersion::V2 | ipc::MetadataVersion::V3 => { diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 33d7ecc68be..371f8a1f23e 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -121,6 +121,7 @@ impl FlightService for FlightServiceImpl { &self, request: Request, ) -> Result, Status> { + eprintln!("Doing get_flight_info..."); let descriptor = request.into_inner(); match descriptor.r#type { @@ -167,6 +168,8 @@ impl FlightService for FlightServiceImpl { &self, request: Request>, ) -> Result, Status> { + eprintln!("Doing put..."); + let mut input_stream = request.into_inner(); let flight_data = input_stream .message() @@ -198,12 +201,14 @@ impl FlightService for FlightServiceImpl { let mut uploaded_chunks = uploaded_chunks.lock().await; while let Some(Ok(more_flight_data)) = input_stream.next().await { + eprintln!("send #1"); let stream_result = response_tx .send(Ok(PutResult { app_metadata: more_flight_data.app_metadata.clone(), })) .await; if let Err(e) = stream_result { + eprintln!("send #2"); response_tx .send(Err(Status::internal(format!( "Could not send PutResult: {:?}", @@ -220,13 +225,16 @@ impl FlightService for FlightServiceImpl { match arrow_batch_result { Ok(batch) => chunks.push(batch), - Err(e) => response_tx + Err(e) => { + eprintln!("send #3"); +response_tx .send(Err(Status::invalid_argument(format!( "Could not convert to RecordBatch: {:?}", e )))) .await - .expect("Error sending error"), + .expect("Error sending error") + }, } } From 908c0a5596933d6e583a72b5c4c31a824f14b11c Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 20 Nov 2020 16:57:55 -0500 Subject: [PATCH 08/37] tracing output --- rust/integration-testing/Cargo.toml | 1 + .../src/bin/flight-test-integration-server.rs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml index 63f9ad6f1ab..003f636cf79 100644 --- a/rust/integration-testing/Cargo.toml +++ 
b/rust/integration-testing/Cargo.toml @@ -38,6 +38,7 @@ serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] } tonic = "0.3" +tracing-subscriber = "*" [[bin]] name = "arrow-file-to-stream" diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 371f8a1f23e..d730573cc00 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -546,6 +546,8 @@ type Result = std::result::Result; #[tokio::main] async fn main() -> Result { + tracing_subscriber::fmt::init(); + let matches = App::new("rust flight-test-integration-server") .about("Integration testing server for Flight.") .arg(Arg::with_name("port").long("port").takes_value(true)) From 12d3138305f3424e7c38f16426dd6ad977672d94 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 20 Nov 2020 16:58:22 -0500 Subject: [PATCH 09/37] [upstream anyway] more generic --- rust/arrow/src/ipc/writer.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index c0ccbeb07f2..afb3cd98744 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -420,8 +420,8 @@ impl<'a> Message<'a> { } /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written -fn write_message( - mut writer: &mut BufWriter, +pub fn write_message( + mut writer: W, message: &Message, write_options: &IpcWriteOptions, ) -> Result<(usize, usize)> { @@ -467,7 +467,7 @@ fn write_message( Ok((aligned_size, body_len)) } -fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Result { +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { let len = data.len() as u32; let pad_len = pad_to_8(len) as u32; let total_len = len + pad_len; @@ -596,7 +596,7 @@ pub fn dictionary_batch_to_bytes( /// Write a record batch to the writer, writing the message size before the message /// if the record batch is being written to a stream fn write_continuation( - writer: &mut BufWriter, + mut writer: W, write_options: &IpcWriteOptions, total_len: i32, ) -> Result { From aeb853923eaf12366a2e2b8fcf75145a6d8aafc2 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 20 Nov 2020 16:58:41 -0500 Subject: [PATCH 10/37] [supahax] it gets to the next failure --- rust/arrow/src/ipc/writer.rs | 2 +- .../src/bin/flight-test-integration-server.rs | 20 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index afb3cd98744..3c75c3aaaed 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -398,7 +398,7 @@ pub fn schema_to_bytes(schema: &Schema, write_options: &IpcWriteOptions) -> Enco } } -enum Message<'a> { +pub enum Message<'a> { Schema(&'a Schema, &'a IpcWriteOptions), RecordBatch(&'a RecordBatch, &'a IpcWriteOptions), DictionaryBatch(i64, &'a ArrayDataRef, &'a IpcWriteOptions), diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index d730573cc00..4609d353ca8 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -136,7 +136,13 @@ impl FlightService for FlightServiceImpl { 
Status::not_found(format!("Could not find flight. {}", path[0])) })?; - let schema_result = SchemaResult::from(&flight.schema); + + +// let schema_result = SchemaResult::from(&flight.schema); + + use arrow::ipc::{writer::IpcWriteOptions, MetadataVersion}; + // use arrow_flight::utils; + // let schema_result = utils::flight_schema_from_arrow_schema(&flight.schema, &IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); let endpoint = FlightEndpoint { ticket: Some(Ticket { @@ -150,8 +156,18 @@ impl FlightService for FlightServiceImpl { let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + //let mut ss = schema_result.schema; + // ss.splice(0..0, vec![u8::MAX, u8::MAX, u8::MAX, u8::MAX]); +// let ss = schema_result.schema; + + let mut ss = vec![]; + + let wo = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); + let msg = arrow::ipc::writer::Message::Schema(&flight.schema, &wo); + arrow::ipc::writer::write_message(&mut ss, &msg, &wo).expect("write_message"); + let info = FlightInfo { - schema: schema_result.schema, + schema: ss, flight_descriptor: Some(descriptor.clone()), endpoint: vec![endpoint], total_records: total_records as i64, From 0fb62e5b31403f97904ceb3dc8f3bb5fd26ef560 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Mon, 30 Nov 2020 14:29:13 -0500 Subject: [PATCH 11/37] Redo sending batches in client, needs cleaned up --- .../src/bin/flight-test-integration-client.rs | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 47d17c725fc..8ce14946030 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -248,37 +248,60 @@ async fn upload_data( descriptor: FlightDescriptor, original_data: Vec, ) -> Result { + eprintln!("In upload_data"); let (mut upload_tx, upload_rx) = mpsc::channel(10); let mut schema_flight_data = FlightData::from(&*schema); schema_flight_data.flight_descriptor = Some(descriptor.clone()); upload_tx.send(schema_flight_data).await?; - let mut upload_rx_container = Some(upload_rx); - let mut resp = None; + let mut original_data_iter = original_data.iter().enumerate(); + + if let Some((counter, first_batch)) = original_data_iter.next() { + eprintln!("Some batches"); - for (counter, batch) in original_data.iter().enumerate() { let metadata = counter.to_string().into_bytes(); + eprintln!("sending batch {:?}", metadata); - let mut batch = FlightData::from(batch); + let mut batch = FlightData::from(first_batch); batch.app_metadata = metadata.clone(); upload_tx.send(batch).await?; + let outer = client.do_put(Request::new(upload_rx)).await?; + let mut inner = outer.into_inner(); - if let Some(upload_rx) = upload_rx_container.take() { - let outer = client.do_put(Request::new(upload_rx)).await?; - let inner = outer.into_inner(); - resp = Some(inner); - } + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + eprintln!("received ack for batch {:?}", metadata); + + for (counter, batch) in original_data_iter { + + let metadata = counter.to_string().into_bytes(); + eprintln!("sending batch {:?}", metadata); - if let Some(inner) = resp.as_mut() { + let mut batch = FlightData::from(batch); + batch.app_metadata = metadata.clone(); + + 
upload_tx.send(batch).await?; let r = inner .next() .await .expect("No response received") .expect("Invalid response received"); assert_eq!(metadata, r.app_metadata); + eprintln!("received ack for batch {:?}", metadata); } + } else { + eprintln!("No batches"); + + let outer = client.do_put(Request::new(upload_rx)).await?; + let inner = outer.into_inner(); + + dbg!(&inner); } Ok(()) From b5b041484d8ba5697fa86a747d88e33ad428f82b Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Mon, 30 Nov 2020 15:17:07 -0500 Subject: [PATCH 12/37] cargo fmt --- .../src/bin/flight-test-integration-client.rs | 1 - .../src/bin/flight-test-integration-server.rs | 25 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 8ce14946030..6037cc77184 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -279,7 +279,6 @@ async fn upload_data( eprintln!("received ack for batch {:?}", metadata); for (counter, batch) in original_data_iter { - let metadata = counter.to_string().into_bytes(); eprintln!("sending batch {:?}", metadata); diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 4609d353ca8..207f8f5749c 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -136,9 +136,7 @@ impl FlightService for FlightServiceImpl { Status::not_found(format!("Could not find flight. {}", path[0])) })?; - - -// let schema_result = SchemaResult::from(&flight.schema); + // let schema_result = SchemaResult::from(&flight.schema); use arrow::ipc::{writer::IpcWriteOptions, MetadataVersion}; // use arrow_flight::utils; @@ -158,13 +156,14 @@ impl FlightService for FlightServiceImpl { //let mut ss = schema_result.schema; // ss.splice(0..0, vec![u8::MAX, u8::MAX, u8::MAX, u8::MAX]); -// let ss = schema_result.schema; + // let ss = schema_result.schema; let mut ss = vec![]; let wo = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); let msg = arrow::ipc::writer::Message::Schema(&flight.schema, &wo); - arrow::ipc::writer::write_message(&mut ss, &msg, &wo).expect("write_message"); + arrow::ipc::writer::write_message(&mut ss, &msg, &wo) + .expect("write_message"); let info = FlightInfo { schema: ss, @@ -242,15 +241,15 @@ impl FlightService for FlightServiceImpl { match arrow_batch_result { Ok(batch) => chunks.push(batch), Err(e) => { - eprintln!("send #3"); -response_tx - .send(Err(Status::invalid_argument(format!( - "Could not convert to RecordBatch: {:?}", - e - )))) - .await + eprintln!("send #3"); + response_tx + .send(Err(Status::invalid_argument(format!( + "Could not convert to RecordBatch: {:?}", + e + )))) + .await .expect("Error sending error") - }, + } } } From 6112e8a1cae10e83a79dd9e8f3adffd48eca48e0 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Mon, 30 Nov 2020 17:10:12 -0500 Subject: [PATCH 13/37] Progress on understanding flight --- rust/arrow-flight/src/utils.rs | 1 + .../src/bin/flight-test-integration-server.rs | 85 ++++++++++++------- 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index ee19f34a7c5..4a665401dad 100644 --- 
a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -132,6 +132,7 @@ pub fn flight_data_to_arrow_batch( ) -> Option> { // check that the data_header is a record batch message let message = arrow::ipc::get_root_as_message(&data.data_header[..]); + // This assumes there are no dictionaries let dictionaries_by_field = Vec::new(); message diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 207f8f5749c..6421fd61207 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -214,41 +214,60 @@ impl FlightService for FlightServiceImpl { tokio::spawn(async move { let mut chunks = vec![]; let mut uploaded_chunks = uploaded_chunks.lock().await; + let mut dictionaries_by_field = vec![]; - while let Some(Ok(more_flight_data)) = input_stream.next().await { - eprintln!("send #1"); - let stream_result = response_tx - .send(Ok(PutResult { - app_metadata: more_flight_data.app_metadata.clone(), - })) - .await; - if let Err(e) = stream_result { - eprintln!("send #2"); - response_tx - .send(Err(Status::internal(format!( - "Could not send PutResult: {:?}", - e - )))) - .await - .expect("Error sending error"); - } + while let Some(Ok(data)) = input_stream.next().await { + let message = arrow::ipc::get_root_as_message(&data.data_header[..]); + + match message.header_type() { + // CAROLTODO: Fix compiler errors here + ipc::MessageHeader::Schema => { + // TODO: send an error to the stream + eprintln!("Not expecting a schema when messages are read"); + } + ipc::MessageHeader::RecordBatch => { + eprintln!("send #1"); + let stream_result = response_tx + .send(Ok(PutResult { + app_metadata: more_flight_data.app_metadata.clone(), + })) + .await; + if let Err(e) = stream_result { + eprintln!("send #2"); + response_tx + .send(Err(Status::internal(format!( + "Could not send PutResult: {:?}", + e + )))) + .await + .expect("Error sending error"); + } - // This `unwrap` is fine because `flight_data_to_arrow_batch` always returns `Some` - let arrow_batch_result = - flight_data_to_arrow_batch(&more_flight_data, schema_ref.clone()) - .expect("flight_data_to_arrow_batch didn't actually return Some"); - - match arrow_batch_result { - Ok(batch) => chunks.push(batch), - Err(e) => { - eprintln!("send #3"); - response_tx - .send(Err(Status::invalid_argument(format!( - "Could not convert to RecordBatch: {:?}", - e - )))) - .await - .expect("Error sending error") + // TODO: handle None which means parse failure + if let Some(ipc_batch) = message.header_as_record_batch() { + let arrow_batch_result = reader::read_record_batch( + &data.data_body, + ipc_batch, + schema_ref.clone(), + &dictionaries_by_field, + ); + match arrow_batch_result { + Ok(batch) => chunks.push(batch), + Err(e) => { + eprintln!("send #3"); + response_tx + .send(Err(Status::invalid_argument(format!( + "Could not convert to RecordBatch: {:?}", + e + )))) + .await + .expect("Error sending error") + } + } + } + } + ipc::MessageHeader::DictionaryBatch => { + // CAROLTODO: And fill in with a call to read_dictionary here } } } From 4e797691b1a91d09aba147a5bea0e2ee9c70c1b2 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 2 Dec 2020 10:46:04 -0500 Subject: [PATCH 14/37] more progress --- .../src/bin/flight-test-integration-server.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git 
a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 6421fd61207..334df8c19cf 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -30,6 +30,7 @@ use tonic::transport::Server; use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; use arrow::{datatypes::Schema, record_batch::RecordBatch}; +use arrow::ipc::{self, reader}; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, flight_service_server::FlightServiceServer, utils::flight_data_to_arrow_batch, @@ -229,7 +230,7 @@ impl FlightService for FlightServiceImpl { eprintln!("send #1"); let stream_result = response_tx .send(Ok(PutResult { - app_metadata: more_flight_data.app_metadata.clone(), + app_metadata: data.app_metadata.clone(), })) .await; if let Err(e) = stream_result { @@ -267,7 +268,16 @@ impl FlightService for FlightServiceImpl { } } ipc::MessageHeader::DictionaryBatch => { - // CAROLTODO: And fill in with a call to read_dictionary here + // TODO: handle None which means parse failure + if let Some(ipc_batch) = message.header_as_dictionary_batch() { + reader::read_dictionary( + &data.data_body, ipc_batch, schema, &self.schema, &mut self.dictionaries_by_field + )?; + } + } + t => { + // TODO: send error to stream + eprintln!("Reading types other than record batches not yet supported, unable to read {:?}", t); } } } From a3894c0c121a227f016962e1baf6919f07fd2cf0 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Fri, 6 Nov 2020 16:29:36 -0500 Subject: [PATCH 15/37] [Rust] Only return dict_id/dict_is_ordered values for Dictionary types I had submitted this before but I think it got lost in a rebase somewhere. I think this is more correct and informative. 
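A minimal sketch of the calling pattern this enables, assuming `Field::dict_id`
now returns `Option<i64>` (the `require_dict_id` helper is illustrative, not
part of this change):

    use arrow::datatypes::Field;
    use arrow::error::{ArrowError, Result};

    // Non-Dictionary fields yield None rather than a meaningless default id,
    // so callers surface a descriptive error instead of proceeding silently.
    fn require_dict_id(field: &Field) -> Result<i64> {
        field.dict_id().ok_or_else(|| {
            ArrowError::JsonError(format!(
                "Unable to find dict_id for field {:?}",
                field
            ))
        })
    }

The `expect` calls in `dictionary_array_from_json` below take the same stance:
a Dictionary field missing these values is a programmer error worth failing
loudly on.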
--- rust/integration-testing/src/lib.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs index 0ce7826a16a..afb39454b2b 100644 --- a/rust/integration-testing/src/lib.rs +++ b/rust/integration-testing/src/lib.rs @@ -479,7 +479,12 @@ fn array_from_json( Ok(Arc::new(array)) } DataType::Dictionary(key_type, value_type) => { - let dict_id = field.dict_id(); + let dict_id = field.dict_id().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find dict_id for field {:?}", + field + )) + })?; // find dictionary let dictionary = dictionaries .ok_or_else(|| { @@ -529,8 +534,12 @@ fn dictionary_array_from_json( "key", dict_key.clone(), field.is_nullable(), - field.dict_id(), - field.dict_is_ordered(), + field + .dict_id() + .expect("Dictionary fields must have a dict_id value"), + field + .dict_is_ordered() + .expect("Dictionary fields must have a dict_is_ordered value"), ); let keys = array_from_json(&key_field, json_col, None)?; // note: not enough info on nullability of dictionary From 83e2dcdefa82cacc53b0c971c797098b7eeba92d Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 2 Dec 2020 14:31:46 -0500 Subject: [PATCH 16/37] Actually really really read dictionaries --- rust/arrow/src/ipc/reader.rs | 2 +- .../src/bin/flight-test-integration-server.rs | 23 +++++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs index 1b4119c9d96..6037fbe2683 100644 --- a/rust/arrow/src/ipc/reader.rs +++ b/rust/arrow/src/ipc/reader.rs @@ -460,7 +460,7 @@ pub fn read_record_batch( /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries_by_field` with the resulting dictionary -fn read_dictionary( +pub fn read_dictionary( buf: &[u8], batch: ipc::DictionaryBatch, schema: &Schema, diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 334df8c19cf..5e3757846a7 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -29,11 +29,11 @@ use tokio::sync::Mutex; use tonic::transport::Server; use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; -use arrow::{datatypes::Schema, record_batch::RecordBatch}; use arrow::ipc::{self, reader}; +use arrow::{datatypes::Schema, record_batch::RecordBatch}; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, utils::flight_data_to_arrow_batch, + flight_service_server::FlightServiceServer, Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, Location, PutResult, SchemaResult, Ticket, @@ -270,9 +270,22 @@ impl FlightService for FlightServiceImpl { ipc::MessageHeader::DictionaryBatch => { // TODO: handle None which means parse failure if let Some(ipc_batch) = message.header_as_dictionary_batch() { - reader::read_dictionary( - &data.data_body, ipc_batch, schema, &self.schema, &mut self.dictionaries_by_field - )?; + let dictionary_batch_result = reader::read_dictionary( + &data.data_body, + ipc_batch, + &schema_ref, + &mut dictionaries_by_field, + ); + if let Err(e) = dictionary_batch_result { + eprintln!("send #4"); + response_tx + 
.send(Err(Status::invalid_argument(format!( + "Could not convert to Dictionary: {:?}", + e + )))) + .await + .expect("Error sending error") + } } } t => { From e8bcb79959c56deb50d477d5e470b513774d9b9e Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 2 Dec 2020 14:38:46 -0500 Subject: [PATCH 17/37] set initial len of dictionaries --- .../src/bin/flight-test-integration-server.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 5e3757846a7..d4c5354a84f 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -215,18 +215,18 @@ impl FlightService for FlightServiceImpl { tokio::spawn(async move { let mut chunks = vec![]; let mut uploaded_chunks = uploaded_chunks.lock().await; - let mut dictionaries_by_field = vec![]; + let mut dictionaries_by_field = vec![None; schema_ref.fields().len()]; while let Some(Ok(data)) = input_stream.next().await { let message = arrow::ipc::get_root_as_message(&data.data_header[..]); match message.header_type() { - // CAROLTODO: Fix compiler errors here ipc::MessageHeader::Schema => { // TODO: send an error to the stream eprintln!("Not expecting a schema when messages are read"); } ipc::MessageHeader::RecordBatch => { + eprintln!("RecordBatch"); eprintln!("send #1"); let stream_result = response_tx .send(Ok(PutResult { @@ -268,6 +268,7 @@ impl FlightService for FlightServiceImpl { } } ipc::MessageHeader::DictionaryBatch => { + eprintln!("DictionaryBatch"); // TODO: handle None which means parse failure if let Some(ipc_batch) = message.header_as_dictionary_batch() { let dictionary_batch_result = reader::read_dictionary( @@ -285,6 +286,8 @@ impl FlightService for FlightServiceImpl { )))) .await .expect("Error sending error") + } else { + dbg!(&dictionaries_by_field); } } } From 1fde810857589fdbbbec655fe02522c1f136c80d Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 2 Dec 2020 15:30:36 -0500 Subject: [PATCH 18/37] Extract a function for Jake's supahax --- .../src/bin/flight-test-integration-server.rs | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index d4c5354a84f..37e291f8daa 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -76,6 +76,7 @@ impl FlightService for FlightServiceImpl { &self, request: Request, ) -> Result, Status> { + eprintln!("Doing do_get..."); let ticket = request.into_inner(); let key = String::from_utf8(ticket.ticket.to_vec()) @@ -137,12 +138,6 @@ impl FlightService for FlightServiceImpl { Status::not_found(format!("Could not find flight. 
{}", path[0])) })?; - // let schema_result = SchemaResult::from(&flight.schema); - - use arrow::ipc::{writer::IpcWriteOptions, MetadataVersion}; - // use arrow_flight::utils; - // let schema_result = utils::flight_schema_from_arrow_schema(&flight.schema, &IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); - let endpoint = FlightEndpoint { ticket: Some(Ticket { ticket: path[0].as_bytes().to_vec(), @@ -155,19 +150,11 @@ impl FlightService for FlightServiceImpl { let total_records: usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); - //let mut ss = schema_result.schema; - // ss.splice(0..0, vec![u8::MAX, u8::MAX, u8::MAX, u8::MAX]); - // let ss = schema_result.schema; - - let mut ss = vec![]; - - let wo = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); - let msg = arrow::ipc::writer::Message::Schema(&flight.schema, &wo); - arrow::ipc::writer::write_message(&mut ss, &msg, &wo) - .expect("write_message"); + let schema = flight_schema(&flight.schema) + .expect("Could not generate schema bytes"); let info = FlightInfo { - schema: ss, + schema, flight_descriptor: Some(descriptor.clone()), endpoint: vec![endpoint], total_records: total_records as i64, @@ -327,6 +314,26 @@ impl FlightService for FlightServiceImpl { } } +fn flight_schema(arrow_schema: &Schema) -> Result> { + // let schema_result = SchemaResult::from(&flight.schema); + + use arrow::ipc::{writer::IpcWriteOptions, MetadataVersion}; + // use arrow_flight::utils; + // let schema_result = utils::flight_schema_from_arrow_schema(&flight.schema, &IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); + + //let mut ss = schema_result.schema; + // ss.splice(0..0, vec![u8::MAX, u8::MAX, u8::MAX, u8::MAX]); + // let ss = schema_result.schema; + + let mut ss = vec![]; + + let wo = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); + let msg = arrow::ipc::writer::Message::Schema(arrow_schema, &wo); + arrow::ipc::writer::write_message(&mut ss, &msg, &wo)?; + + Ok(ss) +} + #[derive(Clone, Default)] pub struct MiddlewareScenarioImpl {} From 2f97515958d3dac9f6f53d4bdde4de1629460a56 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 2 Dec 2020 15:44:06 -0500 Subject: [PATCH 19/37] welp nope --- .../src/bin/flight-test-integration-server.rs | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 37e291f8daa..76ca7487305 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -33,10 +33,9 @@ use arrow::ipc::{self, reader}; use arrow::{datatypes::Schema, record_batch::RecordBatch}; use arrow_flight::{ flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, - Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, - FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, Location, PutResult, - SchemaResult, Ticket, + flight_service_server::FlightServiceServer, Action, ActionType, BasicAuth, Criteria, + Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, Location, PutResult, SchemaResult, Ticket, }; use arrow_integration_testing::{AUTH_PASSWORD, AUTH_USERNAME}; @@ -88,19 +87,23 @@ impl FlightService for FlightServiceImpl { 
Status::not_found(format!("Could not find flight. {}", key)) })?; - let batches: Vec> = flight - .chunks - .iter() - .enumerate() - .map(|(counter, batch)| { - let mut flight_data = FlightData::from(batch); - let metadata = counter.to_string().into_bytes(); - flight_data.app_metadata = metadata; - Ok(flight_data) - }) - .collect(); - - let output = futures::stream::iter(batches); + let schema = std::iter::once( + flight_schema(&flight.schema) + .map(|data_header| FlightData { + data_header, + ..Default::default() + }) + .map_err(|e| Status::internal(format!("Could not generate ipc schema: {}", e))), + ); + + let batches = flight.chunks.iter().enumerate().map(|(counter, batch)| { + let mut flight_data = FlightData::from(batch); + let metadata = counter.to_string().into_bytes(); + flight_data.app_metadata = metadata; + Ok(flight_data) + }); + + let output = futures::stream::iter(schema.chain(batches).collect::>()); Ok(Response::new(Box::pin(output) as Self::DoGetStream)) } From 3af0c24325278e1264a02a76c0381a21afca3c28 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 2 Dec 2020 17:01:05 -0500 Subject: [PATCH 20/37] i know nothing --- cpp/src/arrow/flight/client.cc | 15 +++++++++++++++ cpp/src/arrow/flight/server.cc | 1 + cpp/src/arrow/flight/test_integration_client.cc | 4 ++++ cpp/src/arrow/flight/test_integration_server.cc | 5 +++++ cpp/src/arrow/flight/test_util.cc | 17 ++++++++++++++++- .../src/bin/flight-test-integration-client.rs | 17 +++++++++++++++-- 6 files changed, 56 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 3d32d250556..94dfdeb9f8d 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -498,14 +498,21 @@ class GrpcStreamReader : public FlightStreamReader { app_metadata_(nullptr) {} Status EnsureDataStarted() { + std::cout << "Here i am in GrpcStreamReader EnsureDataStarted" << std::endl; + if (!batch_reader_) { + std::cout << "yes batch_reader_" << std::endl; + bool skipped_to_data = false; { auto guard = TakeGuard(); skipped_to_data = peekable_reader_->SkipToData(); + std::cout << "TakeGuard, SkipToData" << std::endl; } // peek() until we find the first data message; discard metadata if (!skipped_to_data) { + std::cout << "!skipped_to_data" << std::endl; + return OverrideWithServerError(MakeFlightError( FlightStatusCode::Internal, "Server never sent a data message")); } @@ -513,10 +520,16 @@ class GrpcStreamReader : public FlightStreamReader { auto message_reader = std::unique_ptr(new GrpcIpcMessageReader( rpc_, read_mutex_, stream_, peekable_reader_, &app_metadata_)); + std::cout << "yes message_reader" << std::endl; + auto result = ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_); + std::cout << "yes result" << std::endl; + RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_))); } + std::cout << "the end" << std::endl; + return Status::OK(); } arrow::Result> GetSchema() override { @@ -1141,6 +1154,8 @@ class FlightClient::FlightClientImpl { *out = std::unique_ptr( new StreamReader(rpc, nullptr, options.read_options, finishable_stream)); // Eagerly read the schema + std::cout << "Here i am in DoGet" << std::endl; + return static_cast(out->get())->EnsureDataStarted(); } diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 87c96ce4910..1957bba383e 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -622,6 +622,7 @@ class FlightServiceImpl : public 
FlightService::Service { grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, ServerWriter* writer) { + std::cout << "in base DoGet" << std::endl; GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/test_integration_client.cc index 8f331f926ef..1f070197d8b 100644 --- a/cpp/src/arrow/flight/test_integration_client.cc +++ b/cpp/src/arrow/flight/test_integration_client.cc @@ -96,11 +96,15 @@ Status ConsumeFlightLocation( std::unique_ptr stream; RETURN_NOT_OK(read_client->DoGet(ticket, &stream)); + std::cout << "Here i am in ConsumeFlightLocation" << std::endl; + int counter = 0; const int expected = static_cast(retrieved_data.size()); for (const auto& original_batch : retrieved_data) { FlightStreamChunk chunk; RETURN_NOT_OK(stream->Next(&chunk)); + std::cout << "The counter is " << counter << std::endl; + if (chunk.data == nullptr) { return Status::Invalid("Got fewer batches than expected, received so far: ", counter, " expected ", expected); diff --git a/cpp/src/arrow/flight/test_integration_server.cc b/cpp/src/arrow/flight/test_integration_server.cc index 4b904b0eba1..bbecdf769ac 100644 --- a/cpp/src/arrow/flight/test_integration_server.cc +++ b/cpp/src/arrow/flight/test_integration_server.cc @@ -111,8 +111,11 @@ class FlightIntegrationTestServer : public FlightServerBase { Status DoGet(const ServerCallContext& context, const Ticket& request, std::unique_ptr* data_stream) override { + std::cout << "In Server DoGet" << std::endl; auto data = uploaded_chunks.find(request.ticket); if (data == uploaded_chunks.end()) { + std::cout << "Could not find flight" << std::endl; + return Status::KeyError("Could not find flight.", request.ticket); } auto flight = data->second; @@ -121,6 +124,8 @@ class FlightIntegrationTestServer : public FlightServerBase { new NumberingStream(std::unique_ptr(new RecordBatchStream( std::shared_ptr(new RecordBatchListReader(flight)))))); + std::cout << "Returning OK" << std::endl; + return Status::OK(); } diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index f5efa395909..5e754309179 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -460,14 +460,29 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, NumberingStream::NumberingStream(std::unique_ptr stream) : counter_(0), stream_(std::move(stream)) {} -std::shared_ptr NumberingStream::schema() { return stream_->schema(); } +std::shared_ptr NumberingStream::schema() { + std::cout << "In NumberingStream::schema" << std::endl; + + return stream_->schema(); +} Status NumberingStream::GetSchemaPayload(FlightPayload* payload) { + std::cout << "In NumberingStream::GetSchemaPayload" << std::endl; + return stream_->GetSchemaPayload(payload); } Status NumberingStream::Next(FlightPayload* payload) { + std::cout << "In NumberingStream::Next " << counter_ << std::endl; + RETURN_NOT_OK(stream_->Next(payload)); + if (payload) { + std::cout << "yes payload" << std::endl; + if (payload->ipc_message.type != ipc::MessageType::RECORD_BATCH) { + std::cout << "no record batch :(" << std::endl; + + } + } if (payload && payload->ipc_message.type == ipc::MessageType::RECORD_BATCH) { payload->app_metadata = Buffer::FromString(std::to_string(counter_)); counter_++; diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs 
b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 6037cc77184..47b91a37cd9 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -43,6 +43,7 @@ type Client = FlightServiceClient; #[tokio::main] async fn main() -> Result { + tracing_subscriber::fmt::init(); let matches = App::new("rust flight-test-integration-client") .arg(Arg::with_name("host").long("host").takes_value(true)) .arg(Arg::with_name("port").long("port").takes_value(true)) @@ -351,8 +352,20 @@ async fn consume_flight_location( ) -> Result { let mut client = FlightServiceClient::connect(location.uri).await?; - let resp = client.do_get(ticket).await?; - let mut resp = resp.into_inner(); + dbg!(&client); + + let resp = client.do_get(ticket).await; + dbg!(&resp); + + // If i turn on RUST_LOG=h2=debug and run this client against the C++ server, I see this: + // Dec 02 16:46:50.047 DEBUG h2::codec::framed_read: received; frame=Reset { stream_id: StreamId(1), error_code: INTERNAL_ERROR } + // which i think is coming straight from the server, but I don't know why :( + + let mut resp = resp?.into_inner(); + dbg!(&resp); + + let schema_again = resp.next().await.unwrap(); + dbg!(&schema_again); for (counter, expected_batch) in expected_data.iter().enumerate() { let actual_batch = resp.next().await.unwrap_or_else(|| { From 3cff161b69ceeed2149de210ccc7445b19143c5d Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 14:02:26 -0500 Subject: [PATCH 21/37] Extract generating schema EncodedData to a new struct's method --- rust/arrow-flight/src/utils.rs | 8 ++- rust/arrow/src/ipc/writer.rs | 107 +++++++++++++++++++++---------- rust/parquet/src/arrow/schema.rs | 3 +- 3 files changed, 81 insertions(+), 37 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index ee19f34a7c5..bc16692baca 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -67,8 +67,11 @@ pub fn flight_schema_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, ) -> SchemaResult { + let data_gen = writer::IpcDataGenerator::default(); + let schema_bytes = data_gen.schema_to_bytes(schema, &options); + SchemaResult { - schema: writer::schema_to_bytes(schema, &options).ipc_message, + schema: schema_bytes.ipc_message, } } @@ -88,7 +91,8 @@ pub fn flight_data_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, ) -> FlightData { - let schema = writer::schema_to_bytes(schema, &options); + let data_gen = writer::IpcDataGenerator::default(); + let schema = data_gen.schema_to_bytes(schema, &options); FlightData { flight_descriptor: None, app_metadata: vec![], diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index d6a52a62c5d..b845b3e4aa0 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -98,6 +98,38 @@ impl Default for IpcWriteOptions { } } +#[derive(Debug, Default)] +pub struct IpcDataGenerator {} + +impl IpcDataGenerator { + pub fn schema_to_bytes( + &self, + schema: &Schema, + write_options: &IpcWriteOptions, + ) -> EncodedData { + let mut fbb = FlatBufferBuilder::new(); + let schema = { + let fb = ipc::convert::schema_to_fb_offset(&mut fbb, schema); + fb.as_union_value() + }; + + let mut message = ipc::MessageBuilder::new(&mut fbb); + message.add_version(write_options.metadata_version); + message.add_header_type(ipc::MessageHeader::Schema); + message.add_bodyLength(0); + 
message.add_header(schema); + // TODO: custom metadata + let data = message.finish(); + fbb.finish(data, None); + + let data = fbb.finished_data(); + EncodedData { + ipc_message: data.to_vec(), + arrow_data: vec![], + } + } +} + pub struct FileWriter { /// The object to write to writer: BufWriter, @@ -115,6 +147,8 @@ pub struct FileWriter { finished: bool, /// Keeps track of dictionaries that have been written last_written_dictionaries: HashMap, + + data_gen: IpcDataGenerator, } impl FileWriter { @@ -130,6 +164,7 @@ impl FileWriter { schema: &Schema, write_options: IpcWriteOptions, ) -> Result { + let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); // write magic to header writer.write_all(&super::ARROW_MAGIC[..])?; @@ -137,7 +172,8 @@ impl FileWriter { writer.write_all(&[0, 0])?; // write the schema, set the written bytes to the schema + header let message = Message::Schema(schema, &write_options); - let (meta, data) = write_message(&mut writer, &message, &write_options)?; + let (meta, data) = + write_message(&mut writer, &message, &write_options, &data_gen)?; Ok(Self { writer, write_options, @@ -147,6 +183,7 @@ impl FileWriter { record_blocks: vec![], finished: false, last_written_dictionaries: HashMap::new(), + data_gen, }) } @@ -159,8 +196,12 @@ impl FileWriter { } self.write_dictionaries(&batch)?; let message = Message::RecordBatch(batch, &self.write_options); - let (meta, data) = - write_message(&mut self.writer, &message, &self.write_options)?; + let (meta, data) = write_message( + &mut self.writer, + &message, + &self.write_options, + &self.data_gen, + )?; // add a record block for the footer let block = ipc::Block::new( self.block_offsets as i64, @@ -207,8 +248,12 @@ impl FileWriter { let message = Message::DictionaryBatch(dict_id, dict_values, &self.write_options); - let (meta, data) = - write_message(&mut self.writer, &message, &self.write_options)?; + let (meta, data) = write_message( + &mut self.writer, + &message, + &self.write_options, + &self.data_gen, + )?; let block = ipc::Block::new(self.block_offsets as i64, meta as i32, data as i64); @@ -270,6 +315,8 @@ pub struct StreamWriter { finished: bool, /// Keeps track of dictionaries that have been written last_written_dictionaries: HashMap, + + data_gen: IpcDataGenerator, } impl StreamWriter { @@ -284,16 +331,18 @@ impl StreamWriter { schema: &Schema, write_options: IpcWriteOptions, ) -> Result { + let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); // write the schema, set the written bytes to the schema let message = Message::Schema(schema, &write_options); - write_message(&mut writer, &message, &write_options)?; + write_message(&mut writer, &message, &write_options, &data_gen)?; Ok(Self { writer, write_options, schema: schema.clone(), finished: false, last_written_dictionaries: HashMap::new(), + data_gen, }) } @@ -307,7 +356,12 @@ impl StreamWriter { self.write_dictionaries(&batch)?; let message = Message::RecordBatch(batch, &self.write_options); - write_message(&mut self.writer, &message, &self.write_options)?; + write_message( + &mut self.writer, + &message, + &self.write_options, + &self.data_gen, + )?; Ok(()) } @@ -341,7 +395,12 @@ impl StreamWriter { let message = Message::DictionaryBatch(dict_id, dict_values, &self.write_options); - write_message(&mut self.writer, &message, &self.write_options)?; + write_message( + &mut self.writer, + &message, + &self.write_options, + &self.data_gen, + )?; } } Ok(()) @@ -374,29 +433,6 @@ pub struct EncodedData { pub 
arrow_data: Vec, } -pub fn schema_to_bytes(schema: &Schema, write_options: &IpcWriteOptions) -> EncodedData { - let mut fbb = FlatBufferBuilder::new(); - let schema = { - let fb = ipc::convert::schema_to_fb_offset(&mut fbb, schema); - fb.as_union_value() - }; - - let mut message = ipc::MessageBuilder::new(&mut fbb); - message.add_version(write_options.metadata_version); - message.add_header_type(ipc::MessageHeader::Schema); - message.add_bodyLength(0); - message.add_header(schema); - // TODO: custom metadata - let data = message.finish(); - fbb.finish(data, None); - - let data = fbb.finished_data(); - EncodedData { - ipc_message: data.to_vec(), - arrow_data: vec![], - } -} - enum Message<'a> { Schema(&'a Schema, &'a IpcWriteOptions), RecordBatch(&'a RecordBatch, &'a IpcWriteOptions), @@ -405,9 +441,11 @@ enum Message<'a> { impl<'a> Message<'a> { /// Encode message to a ipc::Message and return data as bytes - fn encode(&'a self) -> EncodedData { + fn encode(&'a self, data_gen: &IpcDataGenerator) -> EncodedData { match self { - Message::Schema(schema, options) => schema_to_bytes(*schema, *options), + Message::Schema(schema, options) => { + data_gen.schema_to_bytes(*schema, *options) + } Message::RecordBatch(batch, options) => { record_batch_to_bytes(*batch, *options) } @@ -423,8 +461,9 @@ fn write_message( mut writer: &mut BufWriter, message: &Message, write_options: &IpcWriteOptions, + data_gen: &IpcDataGenerator, ) -> Result<(usize, usize)> { - let encoded = message.encode(); + let encoded = message.encode(data_gen); let arrow_data_len = encoded.arrow_data.len(); if arrow_data_len % 8 != 0 { return Err(ArrowError::MemoryError( diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs index c93325b79b1..0c04704ae0f 100644 --- a/rust/parquet/src/arrow/schema.rs +++ b/rust/parquet/src/arrow/schema.rs @@ -205,7 +205,8 @@ fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Option { /// Encodes the Arrow schema into the IPC format, and base64 encodes it fn encode_arrow_schema(schema: &Schema) -> String { let options = writer::IpcWriteOptions::default(); - let mut serialized_schema = arrow::ipc::writer::schema_to_bytes(&schema, &options); + let data_gen = arrow::ipc::writer::IpcDataGenerator::default(); + let mut serialized_schema = data_gen.schema_to_bytes(&schema, &options); // manually prepending the length to the schema as arrow uses the legacy IPC format // TODO: change after addressing ARROW-9777 From 462d330f1e3309f102d55634bf6befa471adf8af Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 14:07:06 -0500 Subject: [PATCH 22/37] Move record_batch_to_bytes to the new object --- rust/arrow-flight/src/utils.rs | 3 +- rust/arrow/src/ipc/writer.rs | 108 +++++++++++++++++---------------- 2 files changed, 57 insertions(+), 54 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index bc16692baca..77e4092eb64 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -42,7 +42,8 @@ pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, ) -> FlightData { - let data = writer::record_batch_to_bytes(batch, &options); + let data_gen = writer::IpcDataGenerator::default(); + let data = data_gen.record_batch_to_bytes(batch, &options); FlightData { flight_descriptor: None, app_metadata: vec![], diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index b845b3e4aa0..7f0e866122f 100644 --- a/rust/arrow/src/ipc/writer.rs +++ 
b/rust/arrow/src/ipc/writer.rs @@ -128,6 +128,60 @@ impl IpcDataGenerator { arrow_data: vec![], } } + + /// Write a `RecordBatch` into two sets of bytes, one for the header (ipc::Message) and the + /// other for the batch's data + pub fn record_batch_to_bytes( + &self, + batch: &RecordBatch, + write_options: &IpcWriteOptions, + ) -> EncodedData { + let mut fbb = FlatBufferBuilder::new(); + + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + let mut offset = 0; + for array in batch.columns() { + let array_data = array.data(); + offset = write_array_data( + &array_data, + &mut buffers, + &mut arrow_data, + &mut nodes, + offset, + array.len(), + array.null_count(), + ); + } + + // write data + let buffers = fbb.create_vector(&buffers); + let nodes = fbb.create_vector(&nodes); + + let root = { + let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); + batch_builder.add_length(batch.num_rows() as i64); + batch_builder.add_nodes(nodes); + batch_builder.add_buffers(buffers); + let b = batch_builder.finish(); + b.as_union_value() + }; + // create an ipc::Message + let mut message = ipc::MessageBuilder::new(&mut fbb); + message.add_version(write_options.metadata_version); + message.add_header_type(ipc::MessageHeader::RecordBatch); + message.add_bodyLength(arrow_data.len() as i64); + message.add_header(root); + let root = message.finish(); + fbb.finish(root, None); + let finished_data = fbb.finished_data(); + + EncodedData { + ipc_message: finished_data.to_vec(), + arrow_data, + } + } } pub struct FileWriter { @@ -447,7 +501,7 @@ impl<'a> Message<'a> { data_gen.schema_to_bytes(*schema, *options) } Message::RecordBatch(batch, options) => { - record_batch_to_bytes(*batch, *options) + data_gen.record_batch_to_bytes(*batch, *options) } Message::DictionaryBatch(dict_id, array_data, options) => { dictionary_batch_to_bytes(*dict_id, *array_data, *options) @@ -520,58 +574,6 @@ fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Resul Ok(total_len as usize) } -/// Write a `RecordBatch` into a tuple of bytes, one for the header (ipc::Message) and the other for the batch's data -pub fn record_batch_to_bytes( - batch: &RecordBatch, - write_options: &IpcWriteOptions, -) -> EncodedData { - let mut fbb = FlatBufferBuilder::new(); - - let mut nodes: Vec = vec![]; - let mut buffers: Vec = vec![]; - let mut arrow_data: Vec = vec![]; - let mut offset = 0; - for array in batch.columns() { - let array_data = array.data(); - offset = write_array_data( - &array_data, - &mut buffers, - &mut arrow_data, - &mut nodes, - offset, - array.len(), - array.null_count(), - ); - } - - // write data - let buffers = fbb.create_vector(&buffers); - let nodes = fbb.create_vector(&nodes); - - let root = { - let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); - batch_builder.add_length(batch.num_rows() as i64); - batch_builder.add_nodes(nodes); - batch_builder.add_buffers(buffers); - let b = batch_builder.finish(); - b.as_union_value() - }; - // create an ipc::Message - let mut message = ipc::MessageBuilder::new(&mut fbb); - message.add_version(write_options.metadata_version); - message.add_header_type(ipc::MessageHeader::RecordBatch); - message.add_bodyLength(arrow_data.len() as i64); - message.add_header(root); - let root = message.finish(); - fbb.finish(root, None); - let finished_data = fbb.finished_data(); - - EncodedData { - ipc_message: finished_data.to_vec(), - arrow_data, - } -} - /// Write dictionary values into a tuple of bytes, one for the header 
(ipc::Message) and the other for the data pub fn dictionary_batch_to_bytes( dict_id: i64, From 7366dd1593c2c6764e5671aaa616124e17654751 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 14:11:19 -0500 Subject: [PATCH 23/37] Extract dictionary_batch_to_bytes to the new struct --- rust/arrow/src/ipc/writer.rs | 122 ++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 60 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 7f0e866122f..41e04c5d43b 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -182,6 +182,67 @@ impl IpcDataGenerator { arrow_data, } } + + /// Write dictionary values into two sets of bytes, one for the header (ipc::Message) and the + /// other for the data + pub fn dictionary_batch_to_bytes( + &self, + dict_id: i64, + array_data: &ArrayDataRef, + write_options: &IpcWriteOptions, + ) -> EncodedData { + let mut fbb = FlatBufferBuilder::new(); + + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + + write_array_data( + &array_data, + &mut buffers, + &mut arrow_data, + &mut nodes, + 0, + array_data.len(), + array_data.null_count(), + ); + + // write data + let buffers = fbb.create_vector(&buffers); + let nodes = fbb.create_vector(&nodes); + + let root = { + let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); + batch_builder.add_length(array_data.len() as i64); + batch_builder.add_nodes(nodes); + batch_builder.add_buffers(buffers); + batch_builder.finish() + }; + + let root = { + let mut batch_builder = ipc::DictionaryBatchBuilder::new(&mut fbb); + batch_builder.add_id(dict_id); + batch_builder.add_data(root); + batch_builder.finish().as_union_value() + }; + + let root = { + let mut message_builder = ipc::MessageBuilder::new(&mut fbb); + message_builder.add_version(write_options.metadata_version); + message_builder.add_header_type(ipc::MessageHeader::DictionaryBatch); + message_builder.add_bodyLength(arrow_data.len() as i64); + message_builder.add_header(root); + message_builder.finish() + }; + + fbb.finish(root, None); + let finished_data = fbb.finished_data(); + + EncodedData { + ipc_message: finished_data.to_vec(), + arrow_data, + } + } } pub struct FileWriter { @@ -504,7 +565,7 @@ impl<'a> Message<'a> { data_gen.record_batch_to_bytes(*batch, *options) } Message::DictionaryBatch(dict_id, array_data, options) => { - dictionary_batch_to_bytes(*dict_id, *array_data, *options) + data_gen.dictionary_batch_to_bytes(*dict_id, *array_data, *options) } } } @@ -574,65 +635,6 @@ fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Resul Ok(total_len as usize) } -/// Write dictionary values into a tuple of bytes, one for the header (ipc::Message) and the other for the data -pub fn dictionary_batch_to_bytes( - dict_id: i64, - array_data: &ArrayDataRef, - write_options: &IpcWriteOptions, -) -> EncodedData { - let mut fbb = FlatBufferBuilder::new(); - - let mut nodes: Vec = vec![]; - let mut buffers: Vec = vec![]; - let mut arrow_data: Vec = vec![]; - - write_array_data( - &array_data, - &mut buffers, - &mut arrow_data, - &mut nodes, - 0, - array_data.len(), - array_data.null_count(), - ); - - // write data - let buffers = fbb.create_vector(&buffers); - let nodes = fbb.create_vector(&nodes); - - let root = { - let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); - batch_builder.add_length(array_data.len() as i64); - batch_builder.add_nodes(nodes); - batch_builder.add_buffers(buffers); - 
batch_builder.finish() - }; - - let root = { - let mut batch_builder = ipc::DictionaryBatchBuilder::new(&mut fbb); - batch_builder.add_id(dict_id); - batch_builder.add_data(root); - batch_builder.finish().as_union_value() - }; - - let root = { - let mut message_builder = ipc::MessageBuilder::new(&mut fbb); - message_builder.add_version(write_options.metadata_version); - message_builder.add_header_type(ipc::MessageHeader::DictionaryBatch); - message_builder.add_bodyLength(arrow_data.len() as i64); - message_builder.add_header(root); - message_builder.finish() - }; - - fbb.finish(root, None); - let finished_data = fbb.finished_data(); - - EncodedData { - ipc_message: finished_data.to_vec(), - arrow_data, - } -} - /// Write a record batch to the writer, writing the message size before the message /// if the record batch is being written to a stream fn write_continuation( From 38933611618bda9d96ab0227d9bd1ac657f87ca0 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 14:21:36 -0500 Subject: [PATCH 24/37] Move EncodedData generation out of write_message --- rust/arrow/src/ipc/writer.rs | 60 ++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 34 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 41e04c5d43b..eef7d56234b 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -286,9 +286,8 @@ impl FileWriter { // create an 8-byte boundary after the header writer.write_all(&[0, 0])?; // write the schema, set the written bytes to the schema + header - let message = Message::Schema(schema, &write_options); - let (meta, data) = - write_message(&mut writer, &message, &write_options, &data_gen)?; + let encoded_message = data_gen.schema_to_bytes(schema, &write_options); + let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?; Ok(Self { writer, write_options, @@ -310,13 +309,11 @@ impl FileWriter { )); } self.write_dictionaries(&batch)?; - let message = Message::RecordBatch(batch, &self.write_options); - let (meta, data) = write_message( - &mut self.writer, - &message, - &self.write_options, - &self.data_gen, - )?; + let encoded_message = self + .data_gen + .record_batch_to_bytes(batch, &self.write_options); + let (meta, data) = + write_message(&mut self.writer, encoded_message, &self.write_options)?; // add a record block for the footer let block = ipc::Block::new( self.block_offsets as i64, @@ -360,14 +357,16 @@ impl FileWriter { self.last_written_dictionaries .insert(dict_id, column.clone()); - let message = - Message::DictionaryBatch(dict_id, dict_values, &self.write_options); + let encoded_message = self.data_gen.dictionary_batch_to_bytes( + dict_id, + dict_values, + &self.write_options, + ); let (meta, data) = write_message( &mut self.writer, - &message, + encoded_message, &self.write_options, - &self.data_gen, )?; let block = @@ -449,8 +448,8 @@ impl StreamWriter { let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); // write the schema, set the written bytes to the schema - let message = Message::Schema(schema, &write_options); - write_message(&mut writer, &message, &write_options, &data_gen)?; + let encoded_message = data_gen.schema_to_bytes(schema, &write_options); + write_message(&mut writer, encoded_message, &write_options)?; Ok(Self { writer, write_options, @@ -470,13 +469,10 @@ impl StreamWriter { } self.write_dictionaries(&batch)?; - let message = Message::RecordBatch(batch, &self.write_options); - write_message( - 
&mut self.writer, - &message, - &self.write_options, - &self.data_gen, - )?; + let encoded_message = self + .data_gen + .record_batch_to_bytes(batch, &self.write_options); + write_message(&mut self.writer, encoded_message, &self.write_options)?; Ok(()) } @@ -507,15 +503,13 @@ impl StreamWriter { self.last_written_dictionaries .insert(dict_id, column.clone()); - let message = - Message::DictionaryBatch(dict_id, dict_values, &self.write_options); - - write_message( - &mut self.writer, - &message, + let encoded_message = self.data_gen.dictionary_batch_to_bytes( + dict_id, + dict_values, &self.write_options, - &self.data_gen, - )?; + ); + + write_message(&mut self.writer, encoded_message, &self.write_options)?; } } Ok(()) @@ -574,11 +568,9 @@ impl<'a> Message<'a> { /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written fn write_message( mut writer: &mut BufWriter, - message: &Message, + encoded: EncodedData, write_options: &IpcWriteOptions, - data_gen: &IpcDataGenerator, ) -> Result<(usize, usize)> { - let encoded = message.encode(data_gen); let arrow_data_len = encoded.arrow_data.len(); if arrow_data_len % 8 != 0 { return Err(ArrowError::MemoryError( From 51be8463657ba63b31d50992c14b3a780c5137e3 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 14:32:18 -0500 Subject: [PATCH 25/37] Remove the now-unused intermediate Message enum --- rust/arrow/src/ipc/writer.rs | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index eef7d56234b..1408d5a75c8 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -542,29 +542,6 @@ pub struct EncodedData { pub arrow_data: Vec, } -enum Message<'a> { - Schema(&'a Schema, &'a IpcWriteOptions), - RecordBatch(&'a RecordBatch, &'a IpcWriteOptions), - DictionaryBatch(i64, &'a ArrayDataRef, &'a IpcWriteOptions), -} - -impl<'a> Message<'a> { - /// Encode message to a ipc::Message and return data as bytes - fn encode(&'a self, data_gen: &IpcDataGenerator) -> EncodedData { - match self { - Message::Schema(schema, options) => { - data_gen.schema_to_bytes(*schema, *options) - } - Message::RecordBatch(batch, options) => { - data_gen.record_batch_to_bytes(*batch, *options) - } - Message::DictionaryBatch(dict_id, array_data, options) => { - data_gen.dictionary_batch_to_bytes(*dict_id, *array_data, *options) - } - } - } -} - /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written fn write_message( mut writer: &mut BufWriter, From 71b9d8c06073cfe8e358875f1d362786c597260a Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 15:51:10 -0500 Subject: [PATCH 26/37] Extract a DictionaryTracker for logic of whether replacement is an error --- rust/arrow/src/ipc/writer.rs | 148 ++++++++++++++++++++++------------- 1 file changed, 92 insertions(+), 56 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 1408d5a75c8..9187fe3ff62 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -245,6 +245,55 @@ impl IpcDataGenerator { } } +/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary +/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which +/// isn't allowed in the `FileWriter`. 
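+///
+/// A minimal usage sketch (hypothetical caller, not part of this patch):
+/// a writer asks the tracker before emitting a dictionary batch and only
+/// encodes the dictionary when it has not been written yet.
+///
+/// ```ignore
+/// let mut tracker = DictionaryTracker::new(true); // error on replacement
+/// if tracker.insert(dict_id, &column)? {
+///     // first time this dict_id is seen: emit a DictionaryBatch
+/// }
+/// ```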
+pub struct DictionaryTracker { + written: HashMap, + error_on_replacement: bool, +} + +impl DictionaryTracker { + pub fn new(error_on_replacement: bool) -> Self { + Self { + written: HashMap::new(), + error_on_replacement, + } + } + + /// Keep track of the dictionary with the given ID and values. Behavior: + /// + /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate + /// that the dictionary was not actually inserted (because it's already been seen). + /// * If this ID has been written already but with different data, and this tracker is + /// configured to return an error, return an error. + /// * If the tracker has not been configured to error on replacement or this dictionary + /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just + /// inserted. + pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result { + let dict_data = column.data(); + let dict_values = &dict_data.child_data()[0]; + + // If a dictionary with this id was already emitted, check if it was the same. + if let Some(last) = self.written.get(&dict_id) { + if last.data().child_data()[0] == *dict_values { + // Same dictionary values => no need to emit it again + return Ok(false); + } else if self.error_on_replacement { + return Err(ArrowError::InvalidArgumentError( + "Dictionary replacement detected when writing IPC file format. \ + Arrow IPC files only support a single dictionary for a given field \ + across all batches." + .to_string(), + )); + } + } + + self.written.insert(dict_id, column.clone()); + Ok(true) + } +} + pub struct FileWriter { /// The object to write to writer: BufWriter, @@ -261,7 +310,7 @@ pub struct FileWriter { /// Whether the writer footer has been written, and the writer is finished finished: bool, /// Keeps track of dictionaries that have been written - last_written_dictionaries: HashMap, + dictionary_tracker: DictionaryTracker, data_gen: IpcDataGenerator, } @@ -296,7 +345,7 @@ impl FileWriter { dictionary_blocks: vec![], record_blocks: vec![], finished: false, - last_written_dictionaries: HashMap::new(), + dictionary_tracker: DictionaryTracker::new(true), data_gen, }) } @@ -339,40 +388,29 @@ impl FileWriter { let dict_data = column.data(); let dict_values = &dict_data.child_data()[0]; - // If a dictionary with this id was already emitted, check if it was the same. - if let Some(last_dictionary) = - self.last_written_dictionaries.get(&dict_id) - { - if last_dictionary.data().child_data()[0] == *dict_values { - // Same dictionary values => no need to emit it again - continue; - } else { - return Err(ArrowError::InvalidArgumentError( - "Dictionary replacement detected when writing IPC file format. 
\ - Arrow IPC files only support a single dictionary for a given field \ - across all batches.".to_string())); - } + let emit = self.dictionary_tracker.insert(dict_id, column)?; + + if emit { + let encoded_message = self.data_gen.dictionary_batch_to_bytes( + dict_id, + dict_values, + &self.write_options, + ); + + let (meta, data) = write_message( + &mut self.writer, + encoded_message, + &self.write_options, + )?; + + let block = ipc::Block::new( + self.block_offsets as i64, + meta as i32, + data as i64, + ); + self.dictionary_blocks.push(block); + self.block_offsets += meta + data; } - - self.last_written_dictionaries - .insert(dict_id, column.clone()); - - let encoded_message = self.data_gen.dictionary_batch_to_bytes( - dict_id, - dict_values, - &self.write_options, - ); - - let (meta, data) = write_message( - &mut self.writer, - encoded_message, - &self.write_options, - )?; - - let block = - ipc::Block::new(self.block_offsets as i64, meta as i32, data as i64); - self.dictionary_blocks.push(block); - self.block_offsets += meta + data; } } Ok(()) @@ -428,7 +466,7 @@ pub struct StreamWriter { /// Whether the writer footer has been written, and the writer is finished finished: bool, /// Keeps track of dictionaries that have been written - last_written_dictionaries: HashMap, + dictionary_tracker: DictionaryTracker, data_gen: IpcDataGenerator, } @@ -455,7 +493,7 @@ impl StreamWriter { write_options, schema: schema.clone(), finished: false, - last_written_dictionaries: HashMap::new(), + dictionary_tracker: DictionaryTracker::new(false), data_gen, }) } @@ -490,26 +528,24 @@ impl StreamWriter { let dict_data = column.data(); let dict_values = &dict_data.child_data()[0]; - // If a dictionary with this id was already emitted, check if it was the same. - if let Some(last_dictionary) = - self.last_written_dictionaries.get(&dict_id) - { - if last_dictionary.data().child_data()[0] == *dict_values { - // Same dictionary values => no need to emit it again - continue; - } + let emit = self + .dictionary_tracker + .insert(dict_id, column) + .expect("StreamWriter is configured to not error on replacement"); + + if emit { + let encoded_message = self.data_gen.dictionary_batch_to_bytes( + dict_id, + dict_values, + &self.write_options, + ); + + write_message( + &mut self.writer, + encoded_message, + &self.write_options, + )?; } - - self.last_written_dictionaries - .insert(dict_id, column.clone()); - - let encoded_message = self.data_gen.dictionary_batch_to_bytes( - dict_id, - dict_values, - &self.write_options, - ); - - write_message(&mut self.writer, encoded_message, &self.write_options)?; } } Ok(()) From ab7ccc3e1c30b30ace6e9765436732df860d1a84 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 14:54:04 -0500 Subject: [PATCH 27/37] Extract shared dictionary code to the new struct --- rust/arrow/src/ipc/writer.rs | 146 +++++++++++++++-------------------- 1 file changed, 61 insertions(+), 85 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 9187fe3ff62..f28b1cf6994 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -129,6 +129,43 @@ impl IpcDataGenerator { } } + pub fn encoded_batch( + &self, + batch: &RecordBatch, + dictionary_tracker: &mut DictionaryTracker, + write_options: &IpcWriteOptions, + ) -> Result<(Vec, EncodedData)> { + // TODO: handle nested dictionaries + let schema = batch.schema(); + let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); + + for (i, field) in 
schema.fields().iter().enumerate() { + let column = batch.column(i); + + if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { + let dict_id = field + .dict_id() + .expect("All Dictionary types have `dict_id`"); + let dict_data = column.data(); + let dict_values = &dict_data.child_data()[0]; + + let emit = dictionary_tracker.insert(dict_id, column)?; + + if emit { + encoded_dictionaries.push(self.dictionary_batch_to_bytes( + dict_id, + dict_values, + write_options, + )); + } + } + } + + let encoded_message = self.record_batch_to_bytes(batch, write_options); + + Ok((encoded_dictionaries, encoded_message)) + } + /// Write a `RecordBatch` into two sets of bytes, one for the header (ipc::Message) and the /// other for the batch's data pub fn record_batch_to_bytes( @@ -357,10 +394,23 @@ impl FileWriter { "Cannot write record batch to file writer as it is closed".to_string(), )); } - self.write_dictionaries(&batch)?; - let encoded_message = self - .data_gen - .record_batch_to_bytes(batch, &self.write_options); + + let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch( + batch, + &mut self.dictionary_tracker, + &self.write_options, + )?; + + for encoded_dictionary in encoded_dictionaries { + let (meta, data) = + write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; + + let block = + ipc::Block::new(self.block_offsets as i64, meta as i32, data as i64); + self.dictionary_blocks.push(block); + self.block_offsets += meta + data; + } + let (meta, data) = write_message(&mut self.writer, encoded_message, &self.write_options)?; // add a record block for the footer @@ -374,48 +424,6 @@ impl FileWriter { Ok(()) } - fn write_dictionaries(&mut self, batch: &RecordBatch) -> Result<()> { - // TODO: handle nested dictionaries - - let schema = batch.schema(); - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); - - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { - let dict_id = field - .dict_id() - .expect("All Dictionary types have `dict_id`"); - let dict_data = column.data(); - let dict_values = &dict_data.child_data()[0]; - - let emit = self.dictionary_tracker.insert(dict_id, column)?; - - if emit { - let encoded_message = self.data_gen.dictionary_batch_to_bytes( - dict_id, - dict_values, - &self.write_options, - ); - - let (meta, data) = write_message( - &mut self.writer, - encoded_message, - &self.write_options, - )?; - - let block = ipc::Block::new( - self.block_offsets as i64, - meta as i32, - data as i64, - ); - self.dictionary_blocks.push(block); - self.block_offsets += meta + data; - } - } - } - Ok(()) - } - /// Write footer and closing tag, then mark the writer as done pub fn finish(&mut self) -> Result<()> { // write EOS @@ -505,49 +513,17 @@ impl StreamWriter { "Cannot write record batch to stream writer as it is closed".to_string(), )); } - self.write_dictionaries(&batch)?; - let encoded_message = self + let (encoded_dictionaries, encoded_message) = self .data_gen - .record_batch_to_bytes(batch, &self.write_options); - write_message(&mut self.writer, encoded_message, &self.write_options)?; - Ok(()) - } - - fn write_dictionaries(&mut self, batch: &RecordBatch) -> Result<()> { - // TODO: handle nested dictionaries - - let schema = batch.schema(); - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); - - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { - let dict_id = field - .dict_id() - .expect("All Dictionary types 
have `dict_id`"); - let dict_data = column.data(); - let dict_values = &dict_data.child_data()[0]; - - let emit = self - .dictionary_tracker - .insert(dict_id, column) - .expect("StreamWriter is configured to not error on replacement"); + .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options) + .expect("StreamWriter is configured to not error on dictionary replacement"); - if emit { - let encoded_message = self.data_gen.dictionary_batch_to_bytes( - dict_id, - dict_values, - &self.write_options, - ); - - write_message( - &mut self.writer, - encoded_message, - &self.write_options, - )?; - } - } + for encoded_dictionary in encoded_dictionaries { + write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; } + + write_message(&mut self.writer, encoded_message, &self.write_options)?; Ok(()) } From dca9d09a0f554b1a8583b32007ba4c2856e5f3b3 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Thu, 3 Dec 2020 16:32:00 -0500 Subject: [PATCH 28/37] Always create FlightData for dictionaries when given a RecordBatch --- rust/arrow-flight/src/utils.rs | 40 ++++++++++++++--------- rust/arrow/src/ipc/writer.rs | 4 +-- rust/datafusion/examples/flight_server.rs | 6 +++- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index 77e4092eb64..c2fcfa92884 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -26,30 +26,40 @@ use arrow::error::{ArrowError, Result}; use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions}; use arrow::record_batch::RecordBatch; -/// Convert a `RecordBatch` to `FlightData` by converting the header and body to bytes +/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries +/// and values. This can't be a `From` implementation because neither `RecordBatch` nor `Vec` are +/// implemented in this crate. /// /// Note: This implicitly uses the default `IpcWriteOptions`. 
To configure options, /// use `flight_data_from_arrow_batch()` -impl From<&RecordBatch> for FlightData { - fn from(batch: &RecordBatch) -> Self { - let options = IpcWriteOptions::default(); - flight_data_from_arrow_batch(batch, &options) - } +pub fn convert_to_flight_data(batch: &RecordBatch) -> Vec { + let options = IpcWriteOptions::default(); + flight_data_from_arrow_batch(batch, &options) } -/// Convert a `RecordBatch` to `FlightData` by converting the header and body to bytes +/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries +/// and values pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, -) -> FlightData { +) -> Vec { let data_gen = writer::IpcDataGenerator::default(); - let data = data_gen.record_batch_to_bytes(batch, &options); - FlightData { - flight_descriptor: None, - app_metadata: vec![], - data_header: data.ipc_message, - data_body: data.arrow_data, - } + let mut dictionary_tracker = writer::DictionaryTracker::new(false); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, &options) + .expect("DictionaryTracker configured above to not error on replacement"); + + encoded_dictionaries + .into_iter() + .chain(std::iter::once(encoded_batch)) + .map(|data| FlightData { + flight_descriptor: None, + app_metadata: vec![], + data_header: data.ipc_message, + data_body: data.arrow_data, + }) + .collect() } /// Convert a `Schema` to `SchemaResult` by converting to an IPC message diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index f28b1cf6994..361edb8f510 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -168,7 +168,7 @@ impl IpcDataGenerator { /// Write a `RecordBatch` into two sets of bytes, one for the header (ipc::Message) and the /// other for the batch's data - pub fn record_batch_to_bytes( + fn record_batch_to_bytes( &self, batch: &RecordBatch, write_options: &IpcWriteOptions, @@ -222,7 +222,7 @@ impl IpcDataGenerator { /// Write dictionary values into two sets of bytes, one for the header (ipc::Message) and the /// other for the data - pub fn dictionary_batch_to_bytes( + fn dictionary_batch_to_bytes( &self, dict_id: i64, array_data: &ArrayDataRef, diff --git a/rust/datafusion/examples/flight_server.rs b/rust/datafusion/examples/flight_server.rs index a601b7cafdd..a5e4aee6017 100644 --- a/rust/datafusion/examples/flight_server.rs +++ b/rust/datafusion/examples/flight_server.rs @@ -114,7 +114,11 @@ impl FlightService for FlightServiceImpl { let mut batches: Vec> = results .iter() - .map(|batch| Ok(FlightData::from(batch))) + .flat_map(|batch| { + let flight_data = + arrow_flight::utils::convert_to_flight_data(batch); + flight_data.into_iter().map(Ok) + }) .collect(); // append batch vector to schema vector, so that the first message sent is the schema From 6eb0614b7d6eca1e21c8668fe8aafe85f77c10f2 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 4 Dec 2020 09:13:48 -0500 Subject: [PATCH 29/37] cleaner --- .../src/bin/flight-test-integration-server.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 76ca7487305..4a4e8c1be31 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -318,23 +318,15 @@ impl FlightService 
for FlightServiceImpl { } fn flight_schema(arrow_schema: &Schema) -> Result> { - // let schema_result = SchemaResult::from(&flight.schema); - use arrow::ipc::{writer::IpcWriteOptions, MetadataVersion}; - // use arrow_flight::utils; - // let schema_result = utils::flight_schema_from_arrow_schema(&flight.schema, &IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); - - //let mut ss = schema_result.schema; - // ss.splice(0..0, vec![u8::MAX, u8::MAX, u8::MAX, u8::MAX]); - // let ss = schema_result.schema; - let mut ss = vec![]; + let mut schema = vec![]; let wo = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); let msg = arrow::ipc::writer::Message::Schema(arrow_schema, &wo); - arrow::ipc::writer::write_message(&mut ss, &msg, &wo)?; + arrow::ipc::writer::write_message(&mut schema, &msg, &wo)?; - Ok(ss) + Ok(schema) } #[derive(Clone, Default)] From 5d3b3335925c815d83adbde6cc160d623f51cb13 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 4 Dec 2020 11:21:20 -0500 Subject: [PATCH 30/37] empty fix --- .../src/bin/flight-test-integration-client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 47b91a37cd9..cb10030dc66 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -297,7 +297,7 @@ async fn upload_data( } } else { eprintln!("No batches"); - + drop(upload_tx); let outer = client.do_put(Request::new(upload_rx)).await?; let inner = outer.into_inner(); From 0b7b241e7d5b1df2479101b34126a81b076fae78 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 4 Dec 2020 11:22:47 -0500 Subject: [PATCH 31/37] kindawork --- .../src/bin/flight-test-integration-client.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index cb10030dc66..89293780ec1 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -350,6 +350,10 @@ async fn consume_flight_location( expected_data: &[RecordBatch], schema: SchemaRef, ) -> Result { + let mut location = location; + location.uri = location.uri.replace("grpc+tcp://", "grpc://"); + + dbg!(&location.uri); let mut client = FlightServiceClient::connect(location.uri).await?; dbg!(&client); From d436206e1c6b463332149d5481cc1d81ec69abd4 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 4 Dec 2020 13:13:48 -0500 Subject: [PATCH 32/37] [upstream] print both values --- cpp/src/arrow/ipc/reader.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 92f2b70f294..f0dc3093ecf 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -860,8 +860,7 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { } if (message->type() != MessageType::DICTIONARY_BATCH) { - return Status::Invalid("IPC stream did not have the expected number (", num_dicts, - ") of dictionaries at the start of the stream"); + return Status::Invalid("IPC stream had (", i, ") dictionaries at the start of the stream, but (", num_dicts, ") were expected"); } RETURN_NOT_OK(ReadDictionary(*message)); } From 87b367831d591c1d2f17baba788eac8c6a53a00c Mon Sep 17 00:00:00 2001 
From: Jake Goulding Date: Fri, 4 Dec 2020 14:21:20 -0500 Subject: [PATCH 33/37] reduce logging --- cpp/src/arrow/ipc/message.cc | 24 ++++++++++++------------ rust/arrow/src/ipc/writer.rs | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index 9d0e2577dc8..669307b58a1 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -470,7 +470,7 @@ class MessageDecoder::MessageDecoderImpl { metadata_(nullptr) {} Status ConsumeData(const uint8_t* data, int64_t size) { - std::cerr << "ConsumeData / next_required_size_ " << next_required_size_ << std::endl; + // std::cerr << "ConsumeData / next_required_size_ " << next_required_size_ << std::endl; if (buffered_size_ == 0) { while (size > 0 && size >= next_required_size_) { @@ -508,7 +508,7 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeBuffer(std::shared_ptr buffer) { - std::cerr << "ConsumeBuffer / next_required_size_ " << next_required_size_ << std::endl; + // std::cerr << "ConsumeBuffer / next_required_size_ " << next_required_size_ << std::endl; if (buffered_size_ == 0) { while (buffer->size() >= next_required_size_) { auto used_size = next_required_size_; @@ -602,18 +602,18 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeInitial(int32_t continuation) { - std::cerr << "ConsumeInitial / continuation " << continuation << std::endl; + // std::cerr << "ConsumeInitial / continuation " << continuation << std::endl; if (continuation == internal::kIpcContinuationToken) { state_ = State::METADATA_LENGTH; next_required_size_ = kMessageDecoderNextRequiredSizeMetadataLength; - std::cerr << "ConsumeInitial / A / next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeInitial / A / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadataLength()); // Valid IPC message, read the message length now return Status::OK(); } else if (continuation == 0) { state_ = State::EOS; next_required_size_ = 0; - std::cerr << "ConsumeInitial / B / next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeInitial / B / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnEOS()); return Status::OK(); } else if (continuation > 0) { @@ -621,7 +621,7 @@ class MessageDecoder::MessageDecoderImpl { // ARROW-6314: Backwards compatibility for reading old IPC // messages produced prior to version 0.15.0 next_required_size_ = continuation; - std::cerr << "ConsumeInitial / C / next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeInitial / C / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadata()); return Status::OK(); } else { @@ -649,13 +649,13 @@ class MessageDecoder::MessageDecoderImpl { if (metadata_length == 0) { state_ = State::EOS; next_required_size_ = 0; - std::cerr << "ConsumeMetadataLength / A /next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeMetadataLength / A /next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnEOS()); return Status::OK(); } else if (metadata_length > 0) { state_ = State::METADATA; next_required_size_ = metadata_length; - std::cerr << "ConsumeMetadataLength / B / next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeMetadataLength / B / next_required_size_ = " << next_required_size_ << std::endl; 
RETURN_NOT_OK(listener_->OnMetadata()); return Status::OK(); } else { @@ -674,7 +674,7 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeMetadataChunks() { - std::cerr << "ConsumeMetadataChunks / next_required_size_ " << next_required_size_ << std::endl; + // std::cerr << "ConsumeMetadataChunks / next_required_size_ " << next_required_size_ << std::endl; if (chunks_[0]->size() >= next_required_size_) { if (chunks_[0]->size() == next_required_size_) { @@ -710,7 +710,7 @@ class MessageDecoder::MessageDecoderImpl { state_ = State::BODY; next_required_size_ = body_length; - std::cerr << "ConsumeMetadata / next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeMetadata / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnBody()); if (next_required_size_ == 0) { ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_)); @@ -726,7 +726,7 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeBodyChunks() { - std::cerr << "ConsumeBodyChunks / next_required_size_ " << next_required_size_ << std::endl; + // std::cerr << "ConsumeBodyChunks / next_required_size_ " << next_required_size_ << std::endl; if (chunks_[0]->size() >= next_required_size_) { auto used_size = next_required_size_; @@ -755,7 +755,7 @@ class MessageDecoder::MessageDecoderImpl { RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message))); state_ = State::INITIAL; next_required_size_ = kMessageDecoderNextRequiredSizeInitial; - std::cerr << "ConsumeBody / next_required_size_ = " << next_required_size_ << std::endl; + // std::cerr << "ConsumeBody / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnInitial()); return Status::OK(); } diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index d1b34cc0d5c..aad9bc39a6f 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -623,9 +623,9 @@ fn write_continuation( total_len: i32, ) -> Result { let mut written = 8; - dbg!("write_continuation", write_options); + // the version of the writer determines whether continuation markers should be added - match dbg!(write_options.metadata_version) { + match write_options.metadata_version { ipc::MetadataVersion::V1 | ipc::MetadataVersion::V2 | ipc::MetadataVersion::V3 => { From dd91f4f59857baaeaf4a3d8054c84c7bbb830103 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 4 Dec 2020 15:25:13 -0500 Subject: [PATCH 34/37] preload the authentication handshake --- .../src/bin/flight-test-integration-client.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 9a727f43dba..11ac4994291 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -30,7 +30,7 @@ use arrow_flight::{ use arrow_flight::{utils::flight_data_to_arrow_batch, FlightDescriptor}; use clap::{App, Arg}; -use futures::{channel::mpsc, sink::SinkExt, StreamExt}; +use futures::{channel::mpsc, stream, sink::SinkExt, StreamExt}; use prost::Message; use tonic::{metadata::MetadataValue, Request, Status}; @@ -191,24 +191,20 @@ async fn authenticate( username: &str, password: &str, ) -> Result { - let (mut tx, rx) = mpsc::channel(10); - let rx = client.handshake(Request::new(rx)).await?; - let mut rx = rx.into_inner(); - let auth = BasicAuth { username: 
username.into(), password: password.into(), }; - let mut payload = vec![]; auth.encode(&mut payload)?; - tx.send(HandshakeRequest { + let req = stream::once(async { HandshakeRequest { payload, ..HandshakeRequest::default() - }) - .await?; - drop(tx); + }}); + + let rx = client.handshake(Request::new(req)).await?; + let mut rx = rx.into_inner(); let r = rx.next().await.expect("must respond from handshake")?; assert!(rx.next().await.is_none(), "must not respond a second time"); From 987fd609c578e5cb8d8bba0bcd49fad728103400 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Fri, 4 Dec 2020 15:30:58 -0500 Subject: [PATCH 35/37] Actually use the dictionaries instead of throwing them away :P --- rust/arrow-flight/src/utils.rs | 4 +- .../src/bin/flight-test-integration-client.rs | 74 ++++++++++++++++--- .../src/bin/flight-test-integration-server.rs | 35 +++++++-- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index 570529747ec..5b1a272958c 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -21,6 +21,7 @@ use std::convert::TryFrom; use crate::{FlightData, SchemaResult}; +use arrow::array::ArrayRef; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result}; use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions}; @@ -144,11 +145,10 @@ impl TryFrom<&SchemaResult> for Schema { pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, + dictionaries_by_field: &[Option], ) -> Option> { // check that the data_header is a record batch message let message = arrow::ipc::get_root_as_message(&data.data_header[..]); - // This assumes there are no dictionaries - let dictionaries_by_field = Vec::new(); message .header_as_record_batch() diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 11ac4994291..b5c209d3e49 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -20,6 +20,7 @@ use arrow_integration_testing::{ }; use arrow::datatypes::SchemaRef; +use arrow::ipc::{self, reader}; use arrow::record_batch::RecordBatch; use arrow_flight::flight_service_client::FlightServiceClient; @@ -260,10 +261,21 @@ async fn upload_data( let metadata = counter.to_string().into_bytes(); eprintln!("sending batch {:?}", metadata); - let mut batch = arrow_flight::utils::convert_to_flight_data(first_batch).pop().unwrap(); - batch.app_metadata = metadata.clone(); + let mut record_batch_flight_datas: Vec = + arrow_flight::utils::convert_to_flight_data(first_batch); + + let mut batch = record_batch_flight_datas + .pop() + .expect("At least one FlightData should be created for every RecordBatch"); + for dictionary_flight_data in record_batch_flight_datas { + upload_tx.send(dictionary_flight_data).await?; + } + + // Only the record batch's FlightData gets app_metadata + batch.app_metadata = metadata.clone(); upload_tx.send(batch).await?; + let outer = client.do_put(Request::new(upload_rx)).await?; let mut inner = outer.into_inner(); @@ -279,10 +291,21 @@ async fn upload_data( let metadata = counter.to_string().into_bytes(); eprintln!("sending batch {:?}", metadata); - let mut batch = arrow_flight::utils::convert_to_flight_data(batch).pop().unwrap(); - batch.app_metadata = metadata.clone(); + let mut record_batch_flight_datas: Vec = + 
arrow_flight::utils::convert_to_flight_data(batch); + let mut batch = record_batch_flight_datas.pop().expect( + "At least one FlightData should be created for every RecordBatch", + ); + + for dictionary_flight_data in record_batch_flight_datas { + upload_tx.send(dictionary_flight_data).await?; + } + + // Only the record batch's FlightData gets app_metadata + batch.app_metadata = metadata.clone(); upload_tx.send(batch).await?; + let r = inner .next() .await @@ -364,24 +387,55 @@ async fn consume_flight_location( let mut resp = resp?.into_inner(); dbg!(&resp); - let schema_again = resp.next().await.unwrap(); - dbg!(&schema_again); + let _schema_again = resp.next().await.unwrap(); + let mut dictionaries_by_field = vec![None; schema.fields().len()]; for (counter, expected_batch) in expected_data.iter().enumerate() { - let actual_batch = resp.next().await.unwrap_or_else(|| { + let mut actual_batch = resp.next().await.unwrap_or_else(|| { panic!( "Got fewer batches than expected, received so far: {} expected: {}", counter, expected_data.len(), ) })?; + let mut message = arrow::ipc::get_root_as_message(&actual_batch.data_header[..]); + dbg!(message.header_type()); + while message.header_type() == ipc::MessageHeader::DictionaryBatch { + // TODO: handle None which means parse failure + if let Some(ipc_batch) = message.header_as_dictionary_batch() { + let dictionary_batch_result = reader::read_dictionary( + &actual_batch.data_body, + ipc_batch, + &schema, + &mut dictionaries_by_field, + ); + if let Err(e) = dictionary_batch_result { + panic!("Error reading dictionary: {:?}", e); + } else { + dbg!(&dictionaries_by_field); + } + } + + actual_batch = resp.next().await.unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + })?; + message = arrow::ipc::get_root_as_message(&actual_batch.data_header[..]); + } let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, actual_batch.app_metadata); - let actual_batch = flight_data_to_arrow_batch(&actual_batch, schema.clone()) - .expect("Unable to convert flight data to Arrow batch") - .expect("Unable to convert flight data to Arrow batch"); + let actual_batch = flight_data_to_arrow_batch( + &actual_batch, + schema.clone(), + &dictionaries_by_field, + ) + .expect("Unable to convert flight data to Arrow batch") + .expect("Unable to convert flight data to Arrow batch"); assert_eq!(expected_batch.schema(), actual_batch.schema()); assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 8f79acba71f..15cb00f9616 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -93,15 +93,31 @@ impl FlightService for FlightServiceImpl { data_header, ..Default::default() }) - .map_err(|e| Status::internal(format!("Could not generate ipc schema: {}", e))), + .map_err(|e| { + Status::internal(format!("Could not generate ipc schema: {}", e)) + }), ); - let batches = flight.chunks.iter().enumerate().map(|(counter, batch)| { - let mut flight_data = arrow_flight::utils::convert_to_flight_data(batch).pop().unwrap(); - let metadata = counter.to_string().into_bytes(); - flight_data.app_metadata = metadata; - Ok(flight_data) - }); + let batches = flight + .chunks + .iter() + .enumerate() + .flat_map(|(counter, batch)| { + let mut 
record_batch_flight_datas: Vec = + arrow_flight::utils::convert_to_flight_data(batch); + + let mut flight_data = record_batch_flight_datas.pop().expect( + "At least one FlightData should be created for every RecordBatch", + ); + + let metadata = counter.to_string().into_bytes(); + flight_data.app_metadata = metadata; + + record_batch_flight_datas + .into_iter() + .chain(std::iter::once(flight_data)) + .map(Ok) + }); let output = futures::stream::iter(schema.chain(batches).collect::>()); @@ -318,7 +334,10 @@ impl FlightService for FlightServiceImpl { } fn flight_schema(arrow_schema: &Schema) -> Result> { - use arrow::ipc::{writer::{IpcWriteOptions, IpcDataGenerator}, MetadataVersion}; + use arrow::ipc::{ + writer::{IpcDataGenerator, IpcWriteOptions}, + MetadataVersion, + }; let mut schema = vec![]; From 2574ce5731211af5ec456e0d324d7f636bfdd5a6 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Fri, 4 Dec 2020 15:59:51 -0500 Subject: [PATCH 36/37] Clear up what's being returned from converting a record batch to flight data --- rust/arrow-flight/src/utils.rs | 26 +++++++----- .../src/bin/flight-test-integration-client.rs | 40 ++++++++----------- .../src/bin/flight-test-integration-server.rs | 13 +++--- 3 files changed, 37 insertions(+), 42 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index 5b1a272958c..17223d4c3cb 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -24,7 +24,7 @@ use crate::{FlightData, SchemaResult}; use arrow::array::ArrayRef; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result}; -use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions}; +use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteOptions}; use arrow::record_batch::RecordBatch; /// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries @@ -33,7 +33,7 @@ use arrow::record_batch::RecordBatch; /// /// Note: This implicitly uses the default `IpcWriteOptions`. 
To configure options, /// use `flight_data_from_arrow_batch()` -pub fn convert_to_flight_data(batch: &RecordBatch) -> Vec { +pub fn convert_to_flight_data(batch: &RecordBatch) -> (Vec, FlightData) { let options = IpcWriteOptions::default(); flight_data_from_arrow_batch(batch, &options) } @@ -43,7 +43,7 @@ pub fn convert_to_flight_data(batch: &RecordBatch) -> Vec { pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, -) -> Vec { +) -> (Vec, FlightData) { let data_gen = writer::IpcDataGenerator::default(); let mut dictionary_tracker = writer::DictionaryTracker::new(false); @@ -51,16 +51,20 @@ pub fn flight_data_from_arrow_batch( .encoded_batch(batch, &mut dictionary_tracker, &options) .expect("DictionaryTracker configured above to not error on replacement"); - encoded_dictionaries - .into_iter() - .chain(std::iter::once(encoded_batch)) - .map(|data| FlightData { - flight_descriptor: None, - app_metadata: vec![], + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + (flight_dictionaries, flight_batch) +} + +impl From for FlightData { + fn from(data: EncodedData) -> Self { + FlightData { data_header: data.ipc_message, data_body: data.arrow_data, - }) - .collect() + ..Default::default() + } + } } /// Convert a `Schema` to `SchemaResult` by converting to an IPC message diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index b5c209d3e49..7aa10128fb2 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -31,7 +31,7 @@ use arrow_flight::{ use arrow_flight::{utils::flight_data_to_arrow_batch, FlightDescriptor}; use clap::{App, Arg}; -use futures::{channel::mpsc, stream, sink::SinkExt, StreamExt}; +use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use prost::Message; use tonic::{metadata::MetadataValue, Request, Status}; @@ -199,10 +199,12 @@ async fn authenticate( let mut payload = vec![]; auth.encode(&mut payload)?; - let req = stream::once(async { HandshakeRequest { - payload, - ..HandshakeRequest::default() - }}); + let req = stream::once(async { + HandshakeRequest { + payload, + ..HandshakeRequest::default() + } + }); let rx = client.handshake(Request::new(req)).await?; let mut rx = rx.into_inner(); @@ -261,20 +263,16 @@ async fn upload_data( let metadata = counter.to_string().into_bytes(); eprintln!("sending batch {:?}", metadata); - let mut record_batch_flight_datas: Vec = + let (dictionary_flight_data, mut batch_flight_data) = arrow_flight::utils::convert_to_flight_data(first_batch); - let mut batch = record_batch_flight_datas - .pop() - .expect("At least one FlightData should be created for every RecordBatch"); - - for dictionary_flight_data in record_batch_flight_datas { - upload_tx.send(dictionary_flight_data).await?; + for dictionary in dictionary_flight_data { + upload_tx.send(dictionary).await?; } // Only the record batch's FlightData gets app_metadata - batch.app_metadata = metadata.clone(); - upload_tx.send(batch).await?; + batch_flight_data.app_metadata = metadata.clone(); + upload_tx.send(batch_flight_data).await?; let outer = client.do_put(Request::new(upload_rx)).await?; let mut inner = outer.into_inner(); @@ -291,20 +289,16 @@ async fn upload_data( let metadata = counter.to_string().into_bytes(); eprintln!("sending batch {:?}", metadata); - let mut 
record_batch_flight_datas: Vec = + let (dictionary_flight_data, mut batch_flight_data) = arrow_flight::utils::convert_to_flight_data(batch); - let mut batch = record_batch_flight_datas.pop().expect( - "At least one FlightData should be created for every RecordBatch", - ); - - for dictionary_flight_data in record_batch_flight_datas { - upload_tx.send(dictionary_flight_data).await?; + for dictionary in dictionary_flight_data { + upload_tx.send(dictionary).await?; } // Only the record batch's FlightData gets app_metadata - batch.app_metadata = metadata.clone(); - upload_tx.send(batch).await?; + batch_flight_data.app_metadata = metadata.clone(); + upload_tx.send(batch_flight_data).await?; let r = inner .next() diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs index 15cb00f9616..b8a40aec4c5 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-server.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -103,19 +103,16 @@ impl FlightService for FlightServiceImpl { .iter() .enumerate() .flat_map(|(counter, batch)| { - let mut record_batch_flight_datas: Vec = + let (dictionary_flight_data, mut batch_flight_data) = arrow_flight::utils::convert_to_flight_data(batch); - let mut flight_data = record_batch_flight_datas.pop().expect( - "At least one FlightData should be created for every RecordBatch", - ); - + // Only the record batch's FlightData gets app_metadata let metadata = counter.to_string().into_bytes(); - flight_data.app_metadata = metadata; + batch_flight_data.app_metadata = metadata; - record_batch_flight_datas + dictionary_flight_data .into_iter() - .chain(std::iter::once(flight_data)) + .chain(std::iter::once(batch_flight_data)) .map(Ok) }); From 46f76d762f1a8ae51ec960b9999e559b6813dd71 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 4 Dec 2020 16:50:04 -0500 Subject: [PATCH 37/37] send_all --- .../src/bin/flight-test-integration-client.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs index 7aa10128fb2..7f38e6d8ed9 100644 --- a/rust/integration-testing/src/bin/flight-test-integration-client.rs +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -266,9 +266,7 @@ async fn upload_data( let (dictionary_flight_data, mut batch_flight_data) = arrow_flight::utils::convert_to_flight_data(first_batch); - for dictionary in dictionary_flight_data { - upload_tx.send(dictionary).await?; - } + upload_tx.send_all(&mut stream::iter(dictionary_flight_data).map(Ok)).await?; // Only the record batch's FlightData gets app_metadata batch_flight_data.app_metadata = metadata.clone(); @@ -292,9 +290,7 @@ async fn upload_data( let (dictionary_flight_data, mut batch_flight_data) = arrow_flight::utils::convert_to_flight_data(batch); - for dictionary in dictionary_flight_data { - upload_tx.send(dictionary).await?; - } + upload_tx.send_all(&mut stream::iter(dictionary_flight_data).map(Ok)).await?; // Only the record batch's FlightData gets app_metadata batch_flight_data.app_metadata = metadata.clone();
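
Note on the pattern this series converges to: every upload path now encodes a
RecordBatch once into (dictionary FlightData messages, batch FlightData), sends
the dictionaries first, and attaches app_metadata only to the batch message. A
condensed sketch of that send loop, assuming the post-PATCH-36
`flight_data_from_arrow_batch` signature shown above and a hypothetical
`tx: futures::channel::mpsc::Sender<FlightData>` (names here are illustrative,
not part of the patches):

    use arrow::ipc::writer::IpcWriteOptions;
    use arrow::record_batch::RecordBatch;
    use arrow_flight::FlightData;
    use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};

    async fn send_batch(
        tx: &mut mpsc::Sender<FlightData>,
        batch: &RecordBatch,
        options: &IpcWriteOptions,
        app_metadata: Vec<u8>,
    ) -> Result<(), mpsc::SendError> {
        // One RecordBatch becomes zero or more dictionary messages plus
        // exactly one record batch message.
        let (dictionaries, mut batch_data) =
            arrow_flight::utils::flight_data_from_arrow_batch(batch, options);

        // Dictionaries must be on the wire before the batch that references them.
        tx.send_all(&mut stream::iter(dictionaries).map(Ok)).await?;

        // Only the record batch's FlightData carries app_metadata.
        batch_data.app_metadata = app_metadata;
        tx.send(batch_data).await
    }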