diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 5c56e6409a7..94dfdeb9f8d 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -498,14 +498,21 @@ class GrpcStreamReader : public FlightStreamReader { app_metadata_(nullptr) {} Status EnsureDataStarted() { + std::cout << "Here i am in GrpcStreamReader EnsureDataStarted" << std::endl; + if (!batch_reader_) { + std::cout << "yes batch_reader_" << std::endl; + bool skipped_to_data = false; { auto guard = TakeGuard(); skipped_to_data = peekable_reader_->SkipToData(); + std::cout << "TakeGuard, SkipToData" << std::endl; } // peek() until we find the first data message; discard metadata if (!skipped_to_data) { + std::cout << "!skipped_to_data" << std::endl; + return OverrideWithServerError(MakeFlightError( FlightStatusCode::Internal, "Server never sent a data message")); } @@ -513,10 +520,16 @@ class GrpcStreamReader : public FlightStreamReader { auto message_reader = std::unique_ptr(new GrpcIpcMessageReader( rpc_, read_mutex_, stream_, peekable_reader_, &app_metadata_)); + std::cout << "yes message_reader" << std::endl; + auto result = ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_); + std::cout << "yes result" << std::endl; + RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_))); } + std::cout << "the end" << std::endl; + return Status::OK(); } arrow::Result> GetSchema() override { @@ -1141,6 +1154,8 @@ class FlightClient::FlightClientImpl { *out = std::unique_ptr( new StreamReader(rpc, nullptr, options.read_options, finishable_stream)); // Eagerly read the schema + std::cout << "Here i am in DoGet" << std::endl; + return static_cast(out->get())->EnsureDataStarted(); } @@ -1151,6 +1166,8 @@ class FlightClient::FlightClientImpl { using GrpcStream = grpc::ClientReaderWriter; using StreamWriter = GrpcStreamWriter; + std::cerr << "DoPut called" << std::endl; + auto rpc = std::make_shared(options); RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); std::shared_ptr stream = stub_->DoPut(&rpc->context); diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 87c96ce4910..1957bba383e 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -622,6 +622,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, ServerWriter* writer) { + std::cout << "in base DoGet" << std::endl; GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); diff --git a/cpp/src/arrow/flight/test_integration_client.cc b/cpp/src/arrow/flight/test_integration_client.cc index 8f331f926ef..1f070197d8b 100644 --- a/cpp/src/arrow/flight/test_integration_client.cc +++ b/cpp/src/arrow/flight/test_integration_client.cc @@ -96,11 +96,15 @@ Status ConsumeFlightLocation( std::unique_ptr stream; RETURN_NOT_OK(read_client->DoGet(ticket, &stream)); + std::cout << "Here i am in ConsumeFlightLocation" << std::endl; + int counter = 0; const int expected = static_cast(retrieved_data.size()); for (const auto& original_batch : retrieved_data) { FlightStreamChunk chunk; RETURN_NOT_OK(stream->Next(&chunk)); + std::cout << "The counter is " << counter << std::endl; + if (chunk.data == nullptr) { return Status::Invalid("Got fewer batches than expected, received so far: ", counter, " expected ", expected); diff --git a/cpp/src/arrow/flight/test_integration_server.cc 
b/cpp/src/arrow/flight/test_integration_server.cc index 4b904b0eba1..bbecdf769ac 100644 --- a/cpp/src/arrow/flight/test_integration_server.cc +++ b/cpp/src/arrow/flight/test_integration_server.cc @@ -111,8 +111,11 @@ class FlightIntegrationTestServer : public FlightServerBase { Status DoGet(const ServerCallContext& context, const Ticket& request, std::unique_ptr* data_stream) override { + std::cout << "In Server DoGet" << std::endl; auto data = uploaded_chunks.find(request.ticket); if (data == uploaded_chunks.end()) { + std::cout << "Could not find flight" << std::endl; + return Status::KeyError("Could not find flight.", request.ticket); } auto flight = data->second; @@ -121,6 +124,8 @@ class FlightIntegrationTestServer : public FlightServerBase { new NumberingStream(std::unique_ptr(new RecordBatchStream( std::shared_ptr(new RecordBatchListReader(flight)))))); + std::cout << "Returning OK" << std::endl; + return Status::OK(); } diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index f5efa395909..5e754309179 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -460,14 +460,29 @@ Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, NumberingStream::NumberingStream(std::unique_ptr stream) : counter_(0), stream_(std::move(stream)) {} -std::shared_ptr NumberingStream::schema() { return stream_->schema(); } +std::shared_ptr NumberingStream::schema() { + std::cout << "In NumberingStream::schema" << std::endl; + + return stream_->schema(); +} Status NumberingStream::GetSchemaPayload(FlightPayload* payload) { + std::cout << "In NumberingStream::GetSchemaPayload" << std::endl; + return stream_->GetSchemaPayload(payload); } Status NumberingStream::Next(FlightPayload* payload) { + std::cout << "In NumberingStream::Next " << counter_ << std::endl; + RETURN_NOT_OK(stream_->Next(payload)); + if (payload) { + std::cout << "yes payload" << std::endl; + if (payload->ipc_message.type != ipc::MessageType::RECORD_BATCH) { + std::cout << "no record batch :(" << std::endl; + + } + } if (payload && payload->ipc_message.type == ipc::MessageType::RECORD_BATCH) { payload->app_metadata = Buffer::FromString(std::to_string(counter_)); counter_++; diff --git a/cpp/src/arrow/ipc/message.cc b/cpp/src/arrow/ipc/message.cc index 6569e71b454..669307b58a1 100644 --- a/cpp/src/arrow/ipc/message.cc +++ b/cpp/src/arrow/ipc/message.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "arrow/buffer.h" #include "arrow/device.h" @@ -469,6 +470,8 @@ class MessageDecoder::MessageDecoderImpl { metadata_(nullptr) {} Status ConsumeData(const uint8_t* data, int64_t size) { + // std::cerr << "ConsumeData / next_required_size_ " << next_required_size_ << std::endl; + if (buffered_size_ == 0) { while (size > 0 && size >= next_required_size_) { auto used_size = next_required_size_; @@ -505,6 +508,7 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeBuffer(std::shared_ptr buffer) { + // std::cerr << "ConsumeBuffer / next_required_size_ " << next_required_size_ << std::endl; if (buffered_size_ == 0) { while (buffer->size() >= next_required_size_) { auto used_size = next_required_size_; @@ -598,15 +602,18 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeInitial(int32_t continuation) { + // std::cerr << "ConsumeInitial / continuation " << continuation << std::endl; if (continuation == internal::kIpcContinuationToken) { state_ = State::METADATA_LENGTH; next_required_size_ = 
kMessageDecoderNextRequiredSizeMetadataLength; + // std::cerr << "ConsumeInitial / A / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadataLength()); // Valid IPC message, read the message length now return Status::OK(); } else if (continuation == 0) { state_ = State::EOS; next_required_size_ = 0; + // std::cerr << "ConsumeInitial / B / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnEOS()); return Status::OK(); } else if (continuation > 0) { @@ -614,6 +621,7 @@ class MessageDecoder::MessageDecoderImpl { // ARROW-6314: Backwards compatibility for reading old IPC // messages produced prior to version 0.15.0 next_required_size_ = continuation; + // std::cerr << "ConsumeInitial / C / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadata()); return Status::OK(); } else { @@ -641,11 +649,13 @@ class MessageDecoder::MessageDecoderImpl { if (metadata_length == 0) { state_ = State::EOS; next_required_size_ = 0; + // std::cerr << "ConsumeMetadataLength / A /next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnEOS()); return Status::OK(); } else if (metadata_length > 0) { state_ = State::METADATA; next_required_size_ = metadata_length; + // std::cerr << "ConsumeMetadataLength / B / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnMetadata()); return Status::OK(); } else { @@ -664,6 +674,8 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeMetadataChunks() { + // std::cerr << "ConsumeMetadataChunks / next_required_size_ " << next_required_size_ << std::endl; + if (chunks_[0]->size() >= next_required_size_) { if (chunks_[0]->size() == next_required_size_) { if (chunks_[0]->is_cpu()) { @@ -698,6 +710,7 @@ class MessageDecoder::MessageDecoderImpl { state_ = State::BODY; next_required_size_ = body_length; + // std::cerr << "ConsumeMetadata / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnBody()); if (next_required_size_ == 0) { ARROW_ASSIGN_OR_RAISE(auto body, AllocateBuffer(0, pool_)); @@ -713,6 +726,8 @@ class MessageDecoder::MessageDecoderImpl { } Status ConsumeBodyChunks() { + // std::cerr << "ConsumeBodyChunks / next_required_size_ " << next_required_size_ << std::endl; + if (chunks_[0]->size() >= next_required_size_) { auto used_size = next_required_size_; if (chunks_[0]->size() == next_required_size_) { @@ -740,6 +755,7 @@ class MessageDecoder::MessageDecoderImpl { RETURN_NOT_OK(listener_->OnMessageDecoded(std::move(message))); state_ = State::INITIAL; next_required_size_ = kMessageDecoderNextRequiredSizeInitial; + // std::cerr << "ConsumeBody / next_required_size_ = " << next_required_size_ << std::endl; RETURN_NOT_OK(listener_->OnInitial()); return Status::OK(); } diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 92f2b70f294..f0dc3093ecf 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -860,8 +860,7 @@ class RecordBatchStreamReaderImpl : public RecordBatchStreamReader { } if (message->type() != MessageType::DICTIONARY_BATCH) { - return Status::Invalid("IPC stream did not have the expected number (", num_dicts, - ") of dictionaries at the start of the stream"); + return Status::Invalid("IPC stream had (", i, ") dictionaries at the start of the stream, but (", num_dicts, ") were expected"); } RETURN_NOT_OK(ReadDictionary(*message)); } diff --git a/dev/archery/archery/integration/tester_rust.py 
b/dev/archery/archery/integration/tester_rust.py index 23c2d37386a..bca80ebae3c 100644 --- a/dev/archery/archery/integration/tester_rust.py +++ b/dev/archery/archery/integration/tester_rust.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +import contextlib import os +import subprocess from .tester import Tester from .util import run_cmd, ARROW_ROOT_DEFAULT, log @@ -24,8 +26,8 @@ class RustTester(Tester): PRODUCER = True CONSUMER = True - # FLIGHT_SERVER = True - # FLIGHT_CLIENT = True + FLIGHT_SERVER = True + FLIGHT_CLIENT = True EXE_PATH = os.path.join(ARROW_ROOT_DEFAULT, 'rust/target/debug') @@ -34,11 +36,11 @@ class RustTester(Tester): STREAM_TO_FILE = os.path.join(EXE_PATH, 'arrow-stream-to-file') FILE_TO_STREAM = os.path.join(EXE_PATH, 'arrow-file-to-stream') - # FLIGHT_SERVER_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-server')] - # FLIGHT_CLIENT_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-client'), - # "-host", "localhost"] + FLIGHT_SERVER_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-server')] + FLIGHT_CLIENT_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-client'), + "--host", "localhost"] name = 'Rust' @@ -72,34 +74,42 @@ def file_to_stream(self, file_path, stream_path): cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path] self.run_shell_command(cmd) - # @contextlib.contextmanager - # def flight_server(self): - # cmd = self.FLIGHT_SERVER_CMD + ['-port=0'] - # if self.debug: - # log(' '.join(cmd)) - # server = subprocess.Popen(cmd, - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE) - # try: - # output = server.stdout.readline().decode() - # if not output.startswith("Server listening on localhost:"): - # server.kill() - # out, err = server.communicate() - # raise RuntimeError( - # "Flight-C++ server did not start properly, " - # "stdout:\n{}\n\nstderr:\n{}\n" - # .format(output + out.decode(), err.decode())) - # port = int(output.split(":")[1]) - # yield port - # finally: - # server.kill() - # server.wait(5) - - # def flight_request(self, port, json_path): - # cmd = self.FLIGHT_CLIENT_CMD + [ - # '-port=' + str(port), - # '-path=' + json_path, - # ] - # if self.debug: - # log(' '.join(cmd)) - # run_cmd(cmd) + @contextlib.contextmanager + def flight_server(self, scenario_name=None): + cmd = self.FLIGHT_SERVER_CMD + ['--port=0'] + if scenario_name: + cmd = cmd + ["--scenario", scenario_name] + if self.debug: + log(' '.join(cmd)) + server = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + try: + output = server.stdout.readline().decode() + if not output.startswith("Server listening on localhost:"): + server.kill() + out, err = server.communicate() + raise RuntimeError( + "Flight-Rust server did not start properly, " + "stdout:\n{}\n\nstderr:\n{}\n" + .format(output + out.decode(), err.decode())) + port = int(output.split(":")[1]) + yield port + finally: + server.kill() + server.wait(5) + + def flight_request(self, port, json_path=None, scenario_name=None): + cmd = self.FLIGHT_CLIENT_CMD + [ + '--port=' + str(port), + ] + if json_path: + cmd.extend(('--path', json_path)) + elif scenario_name: + cmd.extend(('--scenario', scenario_name)) + else: + raise TypeError("Must provide one of json_path or scenario_name") + + if self.debug: + log(' '.join(cmd)) + run_cmd(cmd) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index ee19f34a7c5..17223d4c3cb 100644 --- a/rust/arrow-flight/src/utils.rs +++ 
diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs
index ee19f34a7c5..17223d4c3cb 100644
--- a/rust/arrow-flight/src/utils.rs
+++ b/rust/arrow-flight/src/utils.rs
@@ -21,33 +21,49 @@
 use std::convert::TryFrom;

 use crate::{FlightData, SchemaResult};

+use arrow::array::ArrayRef;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::{ArrowError, Result};
-use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions};
+use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteOptions};
 use arrow::record_batch::RecordBatch;

-/// Convert a `RecordBatch` to `FlightData` by converting the header and body to bytes
+/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries
+/// and values. This can't be a `From` implementation because neither `RecordBatch` nor
+/// `Vec<FlightData>` is defined in this crate.
 ///
 /// Note: This implicitly uses the default `IpcWriteOptions`. To configure options,
 /// use `flight_data_from_arrow_batch()`
-impl From<&RecordBatch> for FlightData {
-    fn from(batch: &RecordBatch) -> Self {
-        let options = IpcWriteOptions::default();
-        flight_data_from_arrow_batch(batch, &options)
-    }
+pub fn convert_to_flight_data(batch: &RecordBatch) -> (Vec<FlightData>, FlightData) {
+    let options = IpcWriteOptions::default();
+    flight_data_from_arrow_batch(batch, &options)
 }

-/// Convert a `RecordBatch` to `FlightData` by converting the header and body to bytes
+/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries
+/// and values
 pub fn flight_data_from_arrow_batch(
     batch: &RecordBatch,
     options: &IpcWriteOptions,
-) -> FlightData {
-    let data = writer::record_batch_to_bytes(batch, &options);
-    FlightData {
-        flight_descriptor: None,
-        app_metadata: vec![],
-        data_header: data.ipc_message,
-        data_body: data.arrow_data,
+) -> (Vec<FlightData>, FlightData) {
+    let data_gen = writer::IpcDataGenerator::default();
+    let mut dictionary_tracker = writer::DictionaryTracker::new(false);
+
+    let (encoded_dictionaries, encoded_batch) = data_gen
+        .encoded_batch(batch, &mut dictionary_tracker, &options)
+        .expect("DictionaryTracker configured above to not error on replacement");
+
+    let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
+    let flight_batch = encoded_batch.into();
+
+    (flight_dictionaries, flight_batch)
+}
+
+impl From<EncodedData> for FlightData {
+    fn from(data: EncodedData) -> Self {
+        FlightData {
+            data_header: data.ipc_message,
+            data_body: data.arrow_data,
+            ..Default::default()
+        }
     }
 }

@@ -67,8 +83,11 @@ pub fn flight_schema_from_arrow_schema(
     schema: &Schema,
     options: &IpcWriteOptions,
 ) -> SchemaResult {
+    let data_gen = writer::IpcDataGenerator::default();
+    let schema_bytes = data_gen.schema_to_bytes(schema, &options);
+
     SchemaResult {
-        schema: writer::schema_to_bytes(schema, &options).ipc_message,
+        schema: schema_bytes.ipc_message,
     }
 }

@@ -88,7 +107,8 @@ pub fn flight_data_from_arrow_schema(
     schema: &Schema,
     options: &IpcWriteOptions,
 ) -> FlightData {
-    let schema = writer::schema_to_bytes(schema, &options);
+    let data_gen = writer::IpcDataGenerator::default();
+    let schema = data_gen.schema_to_bytes(schema, &options);
     FlightData {
         flight_descriptor: None,
         app_metadata: vec![],
@@ -129,10 +149,10 @@ impl TryFrom<&SchemaResult> for Schema {
 pub fn flight_data_to_arrow_batch(
     data: &FlightData,
     schema: SchemaRef,
+    dictionaries_by_field: &[Option<ArrayRef>],
 ) -> Option<Result<RecordBatch>> {
     // check that the data_header is a record batch message
     let message = arrow::ipc::get_root_as_message(&data.data_header[..]);
-    let dictionaries_by_field = Vec::new();

     message
         .header_as_record_batch()
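The utils.rs hunk above changes the send-side contract: `flight_data_from_arrow_batch` now returns the dictionary messages separately from the record batch message, and callers are expected to put the dictionaries on the wire first. A minimal sketch of that calling pattern (this helper is illustrative, not part of the patch):

use arrow::ipc::writer::IpcWriteOptions;
use arrow::record_batch::RecordBatch;
use arrow_flight::utils::flight_data_from_arrow_batch;
use arrow_flight::FlightData;

/// Illustrative helper: flatten one batch into the ordered FlightData
/// messages a DoGet/DoPut stream should carry.
fn encode_batch_for_flight(batch: &RecordBatch) -> Vec<FlightData> {
    let options = IpcWriteOptions::default();
    let (flight_dictionaries, flight_batch) =
        flight_data_from_arrow_batch(batch, &options);

    // Dictionary batches must precede the record batch that references them.
    let mut messages = flight_dictionaries;
    messages.push(flight_batch);
    messages
}

On the receiving side the symmetric change is that callers now maintain their own `dictionaries_by_field` (filled in from dictionary messages, e.g. via the newly public `arrow::ipc::reader::read_dictionary`) and pass it to `flight_data_to_arrow_batch`.

diff --git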
a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs index 1b4119c9d96..6037fbe2683 100644 --- a/rust/arrow/src/ipc/reader.rs +++ b/rust/arrow/src/ipc/reader.rs @@ -460,7 +460,7 @@ pub fn read_record_batch( /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries_by_field` with the resulting dictionary -fn read_dictionary( +pub fn read_dictionary( buf: &[u8], batch: ipc::DictionaryBatch, schema: &Schema, diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index d6a52a62c5d..aad9bc39a6f 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -98,6 +98,239 @@ impl Default for IpcWriteOptions { } } +#[derive(Debug, Default)] +pub struct IpcDataGenerator {} + +impl IpcDataGenerator { + pub fn schema_to_bytes( + &self, + schema: &Schema, + write_options: &IpcWriteOptions, + ) -> EncodedData { + let mut fbb = FlatBufferBuilder::new(); + let schema = { + let fb = ipc::convert::schema_to_fb_offset(&mut fbb, schema); + fb.as_union_value() + }; + + let mut message = ipc::MessageBuilder::new(&mut fbb); + message.add_version(write_options.metadata_version); + message.add_header_type(ipc::MessageHeader::Schema); + message.add_bodyLength(0); + message.add_header(schema); + // TODO: custom metadata + let data = message.finish(); + fbb.finish(data, None); + + let data = fbb.finished_data(); + EncodedData { + ipc_message: data.to_vec(), + arrow_data: vec![], + } + } + + pub fn encoded_batch( + &self, + batch: &RecordBatch, + dictionary_tracker: &mut DictionaryTracker, + write_options: &IpcWriteOptions, + ) -> Result<(Vec, EncodedData)> { + // TODO: handle nested dictionaries + let schema = batch.schema(); + let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); + + for (i, field) in schema.fields().iter().enumerate() { + let column = batch.column(i); + + if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { + let dict_id = field + .dict_id() + .expect("All Dictionary types have `dict_id`"); + let dict_data = column.data(); + let dict_values = &dict_data.child_data()[0]; + + let emit = dictionary_tracker.insert(dict_id, column)?; + + if emit { + encoded_dictionaries.push(self.dictionary_batch_to_bytes( + dict_id, + dict_values, + write_options, + )); + } + } + } + + let encoded_message = self.record_batch_to_bytes(batch, write_options); + + Ok((encoded_dictionaries, encoded_message)) + } + + /// Write a `RecordBatch` into two sets of bytes, one for the header (ipc::Message) and the + /// other for the batch's data + fn record_batch_to_bytes( + &self, + batch: &RecordBatch, + write_options: &IpcWriteOptions, + ) -> EncodedData { + let mut fbb = FlatBufferBuilder::new(); + + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + let mut offset = 0; + for array in batch.columns() { + let array_data = array.data(); + offset = write_array_data( + &array_data, + &mut buffers, + &mut arrow_data, + &mut nodes, + offset, + array.len(), + array.null_count(), + ); + } + + // write data + let buffers = fbb.create_vector(&buffers); + let nodes = fbb.create_vector(&nodes); + + let root = { + let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); + batch_builder.add_length(batch.num_rows() as i64); + batch_builder.add_nodes(nodes); + batch_builder.add_buffers(buffers); + let b = batch_builder.finish(); + b.as_union_value() + }; + // create an ipc::Message + let mut message = ipc::MessageBuilder::new(&mut fbb); + 
message.add_version(write_options.metadata_version); + message.add_header_type(ipc::MessageHeader::RecordBatch); + message.add_bodyLength(arrow_data.len() as i64); + message.add_header(root); + let root = message.finish(); + fbb.finish(root, None); + let finished_data = fbb.finished_data(); + + EncodedData { + ipc_message: finished_data.to_vec(), + arrow_data, + } + } + + /// Write dictionary values into two sets of bytes, one for the header (ipc::Message) and the + /// other for the data + fn dictionary_batch_to_bytes( + &self, + dict_id: i64, + array_data: &ArrayDataRef, + write_options: &IpcWriteOptions, + ) -> EncodedData { + let mut fbb = FlatBufferBuilder::new(); + + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + + write_array_data( + &array_data, + &mut buffers, + &mut arrow_data, + &mut nodes, + 0, + array_data.len(), + array_data.null_count(), + ); + + // write data + let buffers = fbb.create_vector(&buffers); + let nodes = fbb.create_vector(&nodes); + + let root = { + let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); + batch_builder.add_length(array_data.len() as i64); + batch_builder.add_nodes(nodes); + batch_builder.add_buffers(buffers); + batch_builder.finish() + }; + + let root = { + let mut batch_builder = ipc::DictionaryBatchBuilder::new(&mut fbb); + batch_builder.add_id(dict_id); + batch_builder.add_data(root); + batch_builder.finish().as_union_value() + }; + + let root = { + let mut message_builder = ipc::MessageBuilder::new(&mut fbb); + message_builder.add_version(write_options.metadata_version); + message_builder.add_header_type(ipc::MessageHeader::DictionaryBatch); + message_builder.add_bodyLength(arrow_data.len() as i64); + message_builder.add_header(root); + message_builder.finish() + }; + + fbb.finish(root, None); + let finished_data = fbb.finished_data(); + + EncodedData { + ipc_message: finished_data.to_vec(), + arrow_data, + } + } +} + +/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary +/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which +/// isn't allowed in the `FileWriter`. +pub struct DictionaryTracker { + written: HashMap, + error_on_replacement: bool, +} + +impl DictionaryTracker { + pub fn new(error_on_replacement: bool) -> Self { + Self { + written: HashMap::new(), + error_on_replacement, + } + } + + /// Keep track of the dictionary with the given ID and values. Behavior: + /// + /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate + /// that the dictionary was not actually inserted (because it's already been seen). + /// * If this ID has been written already but with different data, and this tracker is + /// configured to return an error, return an error. + /// * If the tracker has not been configured to error on replacement or this dictionary + /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just + /// inserted. + pub fn insert(&mut self, dict_id: i64, column: &ArrayRef) -> Result { + let dict_data = column.data(); + let dict_values = &dict_data.child_data()[0]; + + // If a dictionary with this id was already emitted, check if it was the same. 
+ if let Some(last) = self.written.get(&dict_id) { + if last.data().child_data()[0] == *dict_values { + // Same dictionary values => no need to emit it again + return Ok(false); + } else if self.error_on_replacement { + return Err(ArrowError::InvalidArgumentError( + "Dictionary replacement detected when writing IPC file format. \ + Arrow IPC files only support a single dictionary for a given field \ + across all batches." + .to_string(), + )); + } + } + + self.written.insert(dict_id, column.clone()); + Ok(true) + } +} + pub struct FileWriter { /// The object to write to writer: BufWriter, @@ -114,7 +347,9 @@ pub struct FileWriter { /// Whether the writer footer has been written, and the writer is finished finished: bool, /// Keeps track of dictionaries that have been written - last_written_dictionaries: HashMap, + dictionary_tracker: DictionaryTracker, + + data_gen: IpcDataGenerator, } impl FileWriter { @@ -130,14 +365,15 @@ impl FileWriter { schema: &Schema, write_options: IpcWriteOptions, ) -> Result { + let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); // write magic to header writer.write_all(&super::ARROW_MAGIC[..])?; // create an 8-byte boundary after the header writer.write_all(&[0, 0])?; // write the schema, set the written bytes to the schema + header - let message = Message::Schema(schema, &write_options); - let (meta, data) = write_message(&mut writer, &message, &write_options)?; + let encoded_message = data_gen.schema_to_bytes(schema, &write_options); + let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?; Ok(Self { writer, write_options, @@ -146,7 +382,8 @@ impl FileWriter { dictionary_blocks: vec![], record_blocks: vec![], finished: false, - last_written_dictionaries: HashMap::new(), + dictionary_tracker: DictionaryTracker::new(true), + data_gen, }) } @@ -157,10 +394,25 @@ impl FileWriter { "Cannot write record batch to file writer as it is closed".to_string(), )); } - self.write_dictionaries(&batch)?; - let message = Message::RecordBatch(batch, &self.write_options); + + let (encoded_dictionaries, encoded_message) = self.data_gen.encoded_batch( + batch, + &mut self.dictionary_tracker, + &self.write_options, + )?; + + for encoded_dictionary in encoded_dictionaries { + let (meta, data) = + write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; + + let block = + ipc::Block::new(self.block_offsets as i64, meta as i32, data as i64); + self.dictionary_blocks.push(block); + self.block_offsets += meta + data; + } + let (meta, data) = - write_message(&mut self.writer, &message, &self.write_options)?; + write_message(&mut self.writer, encoded_message, &self.write_options)?; // add a record block for the footer let block = ipc::Block::new( self.block_offsets as i64, @@ -172,53 +424,6 @@ impl FileWriter { Ok(()) } - fn write_dictionaries(&mut self, batch: &RecordBatch) -> Result<()> { - // TODO: handle nested dictionaries - - let schema = batch.schema(); - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); - - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { - let dict_id = field - .dict_id() - .expect("All Dictionary types have `dict_id`"); - let dict_data = column.data(); - let dict_values = &dict_data.child_data()[0]; - - // If a dictionary with this id was already emitted, check if it was the same. 
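To make the behavior documented on `DictionaryTracker` above concrete, this is the calling pattern the rewritten writers rely on; the helper below is illustrative and not part of the patch:

use arrow::array::ArrayRef;
use arrow::error::Result;
use arrow::ipc::writer::DictionaryTracker;

/// Illustrative helper: returns whether a dictionary batch still needs to be
/// emitted for this column before the record batch that references it.
fn dictionary_needs_emitting(
    tracker: &mut DictionaryTracker,
    dict_id: i64,
    column: &ArrayRef,
) -> Result<bool> {
    // Ok(true)  -> first time these values are seen for `dict_id`: emit them.
    // Ok(false) -> identical values were already written: skip.
    // Err(..)   -> values changed and the tracker was created with
    //              `DictionaryTracker::new(true)` (the FileWriter configuration).
    tracker.insert(dict_id, column)
}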
- if let Some(last_dictionary) = - self.last_written_dictionaries.get(&dict_id) - { - if last_dictionary.data().child_data()[0] == *dict_values { - // Same dictionary values => no need to emit it again - continue; - } else { - return Err(ArrowError::InvalidArgumentError( - "Dictionary replacement detected when writing IPC file format. \ - Arrow IPC files only support a single dictionary for a given field \ - across all batches.".to_string())); - } - } - - self.last_written_dictionaries - .insert(dict_id, column.clone()); - - let message = - Message::DictionaryBatch(dict_id, dict_values, &self.write_options); - - let (meta, data) = - write_message(&mut self.writer, &message, &self.write_options)?; - - let block = - ipc::Block::new(self.block_offsets as i64, meta as i32, data as i64); - self.dictionary_blocks.push(block); - self.block_offsets += meta + data; - } - } - Ok(()) - } - /// Write footer and closing tag, then mark the writer as done pub fn finish(&mut self) -> Result<()> { // write EOS @@ -269,7 +474,9 @@ pub struct StreamWriter { /// Whether the writer footer has been written, and the writer is finished finished: bool, /// Keeps track of dictionaries that have been written - last_written_dictionaries: HashMap, + dictionary_tracker: DictionaryTracker, + + data_gen: IpcDataGenerator, } impl StreamWriter { @@ -284,16 +491,18 @@ impl StreamWriter { schema: &Schema, write_options: IpcWriteOptions, ) -> Result { + let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); // write the schema, set the written bytes to the schema - let message = Message::Schema(schema, &write_options); - write_message(&mut writer, &message, &write_options)?; + let encoded_message = data_gen.schema_to_bytes(schema, &write_options); + write_message(&mut writer, encoded_message, &write_options)?; Ok(Self { writer, write_options, schema: schema.clone(), finished: false, - last_written_dictionaries: HashMap::new(), + dictionary_tracker: DictionaryTracker::new(false), + data_gen, }) } @@ -304,46 +513,17 @@ impl StreamWriter { "Cannot write record batch to stream writer as it is closed".to_string(), )); } - self.write_dictionaries(&batch)?; - - let message = Message::RecordBatch(batch, &self.write_options); - write_message(&mut self.writer, &message, &self.write_options)?; - Ok(()) - } - - fn write_dictionaries(&mut self, batch: &RecordBatch) -> Result<()> { - // TODO: handle nested dictionaries - - let schema = batch.schema(); - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); - - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { - let dict_id = field - .dict_id() - .expect("All Dictionary types have `dict_id`"); - let dict_data = column.data(); - let dict_values = &dict_data.child_data()[0]; - // If a dictionary with this id was already emitted, check if it was the same. 
- if let Some(last_dictionary) = - self.last_written_dictionaries.get(&dict_id) - { - if last_dictionary.data().child_data()[0] == *dict_values { - // Same dictionary values => no need to emit it again - continue; - } - } - - self.last_written_dictionaries - .insert(dict_id, column.clone()); - - let message = - Message::DictionaryBatch(dict_id, dict_values, &self.write_options); + let (encoded_dictionaries, encoded_message) = self + .data_gen + .encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options) + .expect("StreamWriter is configured to not error on dictionary replacement"); - write_message(&mut self.writer, &message, &self.write_options)?; - } + for encoded_dictionary in encoded_dictionaries { + write_message(&mut self.writer, encoded_dictionary, &self.write_options)?; } + + write_message(&mut self.writer, encoded_message, &self.write_options)?; Ok(()) } @@ -373,58 +553,12 @@ pub struct EncodedData { /// Arrow buffers to be written, should be an empty vec for schema messages pub arrow_data: Vec, } - -pub fn schema_to_bytes(schema: &Schema, write_options: &IpcWriteOptions) -> EncodedData { - let mut fbb = FlatBufferBuilder::new(); - let schema = { - let fb = ipc::convert::schema_to_fb_offset(&mut fbb, schema); - fb.as_union_value() - }; - - let mut message = ipc::MessageBuilder::new(&mut fbb); - message.add_version(write_options.metadata_version); - message.add_header_type(ipc::MessageHeader::Schema); - message.add_bodyLength(0); - message.add_header(schema); - // TODO: custom metadata - let data = message.finish(); - fbb.finish(data, None); - - let data = fbb.finished_data(); - EncodedData { - ipc_message: data.to_vec(), - arrow_data: vec![], - } -} - -enum Message<'a> { - Schema(&'a Schema, &'a IpcWriteOptions), - RecordBatch(&'a RecordBatch, &'a IpcWriteOptions), - DictionaryBatch(i64, &'a ArrayDataRef, &'a IpcWriteOptions), -} - -impl<'a> Message<'a> { - /// Encode message to a ipc::Message and return data as bytes - fn encode(&'a self) -> EncodedData { - match self { - Message::Schema(schema, options) => schema_to_bytes(*schema, *options), - Message::RecordBatch(batch, options) => { - record_batch_to_bytes(*batch, *options) - } - Message::DictionaryBatch(dict_id, array_data, options) => { - dictionary_batch_to_bytes(*dict_id, *array_data, *options) - } - } - } -} - /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written -fn write_message( - mut writer: &mut BufWriter, - message: &Message, +pub fn write_message( + mut writer: W, + encoded: EncodedData, write_options: &IpcWriteOptions, ) -> Result<(usize, usize)> { - let encoded = message.encode(); let arrow_data_len = encoded.arrow_data.len(); if arrow_data_len % 8 != 0 { return Err(ArrowError::MemoryError( @@ -466,7 +600,7 @@ fn write_message( Ok((aligned_size, body_len)) } -fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Result { +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { let len = data.len() as u32; let pad_len = pad_to_8(len) as u32; let total_len = len + pad_len; @@ -481,121 +615,10 @@ fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Resul Ok(total_len as usize) } -/// Write a `RecordBatch` into a tuple of bytes, one for the header (ipc::Message) and the other for the batch's data -pub fn record_batch_to_bytes( - batch: &RecordBatch, - write_options: &IpcWriteOptions, -) -> EncodedData { - let mut fbb = FlatBufferBuilder::new(); - - let mut nodes: Vec = vec![]; - let mut buffers: Vec = vec![]; - let mut arrow_data: Vec = 
vec![]; - let mut offset = 0; - for array in batch.columns() { - let array_data = array.data(); - offset = write_array_data( - &array_data, - &mut buffers, - &mut arrow_data, - &mut nodes, - offset, - array.len(), - array.null_count(), - ); - } - - // write data - let buffers = fbb.create_vector(&buffers); - let nodes = fbb.create_vector(&nodes); - - let root = { - let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); - batch_builder.add_length(batch.num_rows() as i64); - batch_builder.add_nodes(nodes); - batch_builder.add_buffers(buffers); - let b = batch_builder.finish(); - b.as_union_value() - }; - // create an ipc::Message - let mut message = ipc::MessageBuilder::new(&mut fbb); - message.add_version(write_options.metadata_version); - message.add_header_type(ipc::MessageHeader::RecordBatch); - message.add_bodyLength(arrow_data.len() as i64); - message.add_header(root); - let root = message.finish(); - fbb.finish(root, None); - let finished_data = fbb.finished_data(); - - EncodedData { - ipc_message: finished_data.to_vec(), - arrow_data, - } -} - -/// Write dictionary values into a tuple of bytes, one for the header (ipc::Message) and the other for the data -pub fn dictionary_batch_to_bytes( - dict_id: i64, - array_data: &ArrayDataRef, - write_options: &IpcWriteOptions, -) -> EncodedData { - let mut fbb = FlatBufferBuilder::new(); - - let mut nodes: Vec = vec![]; - let mut buffers: Vec = vec![]; - let mut arrow_data: Vec = vec![]; - - write_array_data( - &array_data, - &mut buffers, - &mut arrow_data, - &mut nodes, - 0, - array_data.len(), - array_data.null_count(), - ); - - // write data - let buffers = fbb.create_vector(&buffers); - let nodes = fbb.create_vector(&nodes); - - let root = { - let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); - batch_builder.add_length(array_data.len() as i64); - batch_builder.add_nodes(nodes); - batch_builder.add_buffers(buffers); - batch_builder.finish() - }; - - let root = { - let mut batch_builder = ipc::DictionaryBatchBuilder::new(&mut fbb); - batch_builder.add_id(dict_id); - batch_builder.add_data(root); - batch_builder.finish().as_union_value() - }; - - let root = { - let mut message_builder = ipc::MessageBuilder::new(&mut fbb); - message_builder.add_version(write_options.metadata_version); - message_builder.add_header_type(ipc::MessageHeader::DictionaryBatch); - message_builder.add_bodyLength(arrow_data.len() as i64); - message_builder.add_header(root); - message_builder.finish() - }; - - fbb.finish(root, None); - let finished_data = fbb.finished_data(); - - EncodedData { - ipc_message: finished_data.to_vec(), - arrow_data, - } -} - /// Write a record batch to the writer, writing the message size before the message /// if the record batch is being written to a stream fn write_continuation( - writer: &mut BufWriter, + mut writer: W, write_options: &IpcWriteOptions, total_len: i32, ) -> Result { diff --git a/rust/datafusion/examples/flight_server.rs b/rust/datafusion/examples/flight_server.rs index a601b7cafdd..a5e4aee6017 100644 --- a/rust/datafusion/examples/flight_server.rs +++ b/rust/datafusion/examples/flight_server.rs @@ -114,7 +114,11 @@ impl FlightService for FlightServiceImpl { let mut batches: Vec> = results .iter() - .map(|batch| Ok(FlightData::from(batch))) + .flat_map(|batch| { + let flight_data = + arrow_flight::utils::convert_to_flight_data(batch); + flight_data.into_iter().map(Ok) + }) .collect(); // append batch vector to schema vector, so that the first message sent is the schema diff --git 
a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml index 1c2687086fb..003f636cf79 100644 --- a/rust/integration-testing/Cargo.toml +++ b/rust/integration-testing/Cargo.toml @@ -27,11 +27,18 @@ edition = "2018" [dependencies] arrow = { path = "../arrow" } +arrow-flight = { path = "../arrow-flight" } +async-trait = "0.1.41" clap = "2.33" +futures = "0.3" +hex = "0.4" +prost = "0.6" serde = { version = "1.0", features = ["rc"] } serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } -hex = "0.4" +tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] } +tonic = "0.3" +tracing-subscriber = "*" [[bin]] name = "arrow-file-to-stream" @@ -44,3 +51,11 @@ path = "src/bin/arrow-stream-to-file.rs" [[bin]] name = "arrow-json-integration-test" path = "src/bin/arrow-json-integration-test.rs" + +[[bin]] +name = "flight-test-integration-server" +path = "src/bin/flight-test-integration-server.rs" + +[[bin]] +name = "flight-test-integration-client" +path = "src/bin/flight-test-integration-client.rs" diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs index b1bec677cf1..cd89a8edf1d 100644 --- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,27 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; use std::fs::File; -use std::io::BufReader; -use std::sync::Arc; use clap::{App, Arg}; -use hex::decode; -use serde_json::Value; -use arrow::array::*; -use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow::record_batch::RecordBatch; -use arrow::{ - buffer::Buffer, - buffer::MutableBuffer, - datatypes::ToByteSlice, - util::{bit_util, integration_util::*}, -}; +use arrow::util::integration_util::*; +use arrow_integration_testing::read_json_file; fn main() -> Result<()> { let matches = App::new("rust arrow-json-integration-test") @@ -93,520 +81,6 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> Ok(()) } -fn record_batch_from_json( - schema: &Schema, - json_batch: ArrowJsonBatch, - json_dictionaries: Option<&HashMap>, -) -> Result { - let mut columns = vec![]; - - for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { - let col = array_from_json(field, json_col, json_dictionaries)?; - columns.push(col); - } - - RecordBatch::try_new(Arc::new(schema.clone()), columns) -} - -/// Construct an Arrow array from a partially typed JSON column -fn array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dictionaries: Option<&HashMap>, -) -> Result { - match field.data_type() { - DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), - DataType::Boolean => { - let mut b = BooleanBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_bool().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int8 => { - let mut b = Int8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => 
b.append_value(value.as_i64().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to get {:?} as int64", - value - )) - })? as i8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int16 => { - let mut b = Int16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int32 - | DataType::Date32(DateUnit::Day) - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - let mut b = Int32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i32), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::Int64 - | DataType::Date64(DateUnit::Millisecond) - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - let mut b = Int64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(match value { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } - _ => panic!("Unable to parse {:?} as number", value), - }), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::UInt8 => { - let mut b = UInt8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt16 => { - let mut b = UInt16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt32 => { - let mut b = UInt32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt64 => { - let mut b = UInt64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value( - value - .as_str() - .unwrap() - .parse() - .expect("Unable to parse string as u64"), - ), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float32 => { - let mut b = Float32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap() as f32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float64 => { - let mut b = 
Float64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Binary => { - let mut b = BinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeBinary => { - let mut b = LargeBinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Utf8 => { - let mut b = StringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeUtf8 => { - let mut b = LargeStringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::FixedSizeBinary(len) => { - let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = hex::decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::List(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| v.as_i64().unwrap() as i32) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(ListArray::from(list_data))) - } - DataType::LargeList(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| match v { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => s.parse::().unwrap(), - _ => panic!("64-bit offset must be either string or number"), - }) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(LargeListArray::from(list_data))) - } - DataType::FixedSizeList(child_field, _) => { - let children = 
json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let null_buf = create_null_buf(&json_col); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(FixedSizeListArray::from(list_data))) - } - DataType::Struct(fields) => { - // construct struct with null data - let null_buf = create_null_buf(&json_col); - let mut array_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .null_bit_buffer(null_buf); - - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { - let array = array_from_json(field, col, dictionaries)?; - array_data = array_data.add_child_data(array.data()); - } - - let array = StructArray::from(array_data.build()); - Ok(Arc::new(array)) - } - DataType::Dictionary(key_type, value_type) => { - let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {:?}", - field - )) - })?; - // find dictionary - let dictionary = dictionaries - .ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {:?}", - field - )) - })? - .get(&dict_id); - match dictionary { - Some(dictionary) => dictionary_array_from_json( - field, json_col, key_type, value_type, dictionary, - ), - None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {:?}", - field - ))), - } - } - t => Err(ArrowError::JsonError(format!( - "data type {:?} not supported", - t - ))), - } -} - -fn dictionary_array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dict_key: &DataType, - dict_value: &DataType, - dictionary: &ArrowJsonDictionaryBatch, -) -> Result { - match dict_key { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => { - let null_buf = create_null_buf(&json_col); - - // build the key data into a buffer, then construct values separately - let key_field = Field::new_dict( - "key", - dict_key.clone(), - field.is_nullable(), - field - .dict_id() - .expect("Dictionary fields must have a dict_id value"), - field - .dict_is_ordered() - .expect("Dictionary fields must have a dict_is_ordered value"), - ); - let keys = array_from_json(&key_field, json_col, None)?; - // note: not enough info on nullability of dictionary - let value_field = Field::new("value", dict_value.clone(), true); - println!("dictionary value type: {:?}", dict_value); - let values = - array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; - - // convert key and value to dictionary data - let dict_data = ArrayData::builder(field.data_type().clone()) - .len(keys.len()) - .add_buffer(keys.data().buffers()[0].clone()) - .null_bit_buffer(null_buf) - .add_child_data(values.data()) - .build(); - - let array = match dict_key { - DataType::Int8 => { - Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } - DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), - DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), - DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), - DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), - DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), - DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), - DataType::UInt64 => 
Arc::new(UInt64DictionaryArray::from(dict_data)), - _ => unreachable!(), - }; - Ok(array) - } - _ => Err(ArrowError::JsonError(format!( - "Dictionary key type {:?} not supported", - dict_key - ))), - } -} - -/// A helper to create a null buffer from a Vec -fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { - let num_bytes = bit_util::ceil(json_col.count, 8); - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); - json_col - .validity - .clone() - .unwrap() - .iter() - .enumerate() - .for_each(|(i, v)| { - let null_slice = null_buf.data_mut(); - if *v != 0 { - bit_util::set_bit(null_slice, i); - } - }); - null_buf.freeze() -} - fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { eprintln!("Converting {} to {}", arrow_name, json_name); @@ -702,43 +176,3 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { Ok(()) } - -struct ArrowFile { - schema: Schema, - // we can evolve this into a concrete Arrow type - // this is temporarily not being read from - _dictionaries: HashMap, - batches: Vec, -} - -fn read_json_file(json_name: &str) -> Result { - let json_file = File::open(json_name)?; - let reader = BufReader::new(json_file); - let arrow_json: Value = serde_json::from_reader(reader).unwrap(); - let schema = Schema::from(&arrow_json["schema"])?; - // read dictionaries - let mut dictionaries = HashMap::new(); - if let Some(dicts) = arrow_json.get("dictionaries") { - for d in dicts - .as_array() - .expect("Unable to get dictionaries as array") - { - let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) - .expect("Unable to get dictionary from JSON"); - // TODO: convert to a concrete Arrow type - dictionaries.insert(json_dict.id, json_dict); - } - } - - let mut batches = vec![]; - for b in arrow_json["batches"].as_array().unwrap() { - let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); - let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; - batches.push(batch); - } - Ok(ArrowFile { - schema, - _dictionaries: dictionaries, - batches, - }) -} diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs new file mode 100644 index 00000000000..7f38e6d8ed9 --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -0,0 +1,452 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_integration_testing::{ + read_json_file, ArrowFile, AUTH_PASSWORD, AUTH_USERNAME, +}; + +use arrow::datatypes::SchemaRef; +use arrow::ipc::{self, reader}; +use arrow::record_batch::RecordBatch; + +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::{ + flight_descriptor::DescriptorType, BasicAuth, FlightData, HandshakeRequest, Location, + Ticket, +}; +use arrow_flight::{utils::flight_data_to_arrow_batch, FlightDescriptor}; + +use clap::{App, Arg}; +use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; +use prost::Message; +use tonic::{metadata::MetadataValue, Request, Status}; + +use std::sync::Arc; + +type Error = Box; +type Result = std::result::Result; + +type Client = FlightServiceClient; + +#[tokio::main] +async fn main() -> Result { + tracing_subscriber::fmt::init(); + let matches = App::new("rust flight-test-integration-client") + .arg(Arg::with_name("host").long("host").takes_value(true)) + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg(Arg::with_name("path").long("path").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let host = matches.value_of("host").expect("Host is required"); + let port = matches.value_of("port").expect("Port is required"); + + match matches.value_of("scenario") { + Some("middleware") => middleware_scenario(host, port).await?, + Some("auth:basic_proto") => auth_basic_proto_scenario(host, port).await?, + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + let path = matches + .value_of("path") + .expect("Path is required if scenario is not specified"); + integration_test_scenario(host, port, path).await?; + } + } + + Ok(()) +} + +async fn middleware_scenario(host: &str, port: &str) -> Result { + let url = format!("http://{}:{}", host, port); + let conn = tonic::transport::Endpoint::new(url)?.connect().await?; + let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor); + + let mut descriptor = FlightDescriptor::default(); + descriptor.set_type(DescriptorType::Cmd); + descriptor.cmd = b"".to_vec(); + + // This call is expected to fail. 
+ let resp = client + .get_flight_info(Request::new(descriptor.clone())) + .await; + match resp { + Ok(_) => return Err(Box::new(Status::internal("Expected call to fail"))), + Err(e) => { + let headers = e.metadata(); + let middleware_header = headers.get("x-middleware"); + let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or(""); + + if value != "expected value" { + let msg = format!( + "Expected to receive header 'x-middleware: expected value', \ + but instead got: '{}'", + value + ); + return Err(Box::new(Status::internal(msg))); + } + + eprintln!("Headers received successfully on failing call."); + } + } + + // This call should succeed + descriptor.cmd = b"success".to_vec(); + let resp = client.get_flight_info(Request::new(descriptor)).await?; + + let headers = resp.metadata(); + let middleware_header = headers.get("x-middleware"); + let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or(""); + + if value != "expected value" { + let msg = format!( + "Expected to receive header 'x-middleware: expected value', \ + but instead got: '{}'", + value + ); + return Err(Box::new(Status::internal(msg))); + } + + eprintln!("Headers received successfully on passing call."); + + Ok(()) +} + +fn middleware_interceptor(mut req: Request<()>) -> Result, Status> { + let metadata = req.metadata_mut(); + metadata.insert("x-middleware", "expected value".parse().unwrap()); + Ok(req) +} + +async fn auth_basic_proto_scenario(host: &str, port: &str) -> Result { + let url = format!("http://{}:{}", host, port); + let mut client = FlightServiceClient::connect(url).await?; + + let action = arrow_flight::Action::default(); + + let resp = client.do_action(Request::new(action.clone())).await; + // This client is unauthenticated and should fail. + match resp { + Err(e) => { + if e.code() != tonic::Code::Unauthenticated { + return Err(Box::new(Status::internal(format!( + "Expected UNAUTHENTICATED but got {:?}", + e + )))); + } + } + Ok(other) => { + return Err(Box::new(Status::internal(format!( + "Expected UNAUTHENTICATED but got {:?}", + other + )))); + } + } + + let token = authenticate(&mut client, AUTH_USERNAME, AUTH_PASSWORD) + .await + .expect("must respond successfully from handshake"); + + let mut request = Request::new(action); + let metadata = request.metadata_mut(); + metadata.insert_bin( + "auth-token-bin", + MetadataValue::from_bytes(token.as_bytes()), + ); + + let resp = client.do_action(request).await?; + let mut resp = resp.into_inner(); + + let r = resp + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + + let body = String::from_utf8(r.body).unwrap(); + assert_eq!(body, AUTH_USERNAME); + + Ok(()) +} + +// TODO: should this be extended, abstracted, and moved out of test code and into production code? 
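+// Editorial sketch of the flow implemented by `authenticate` below, added for
+// clarity and not part of the original patch: the client encodes a BasicAuth
+// protobuf into HandshakeRequest.payload, sends it as a one-element stream,
+// expects exactly one HandshakeResponse, and uses its payload as the token to
+// attach as "auth-token-bin" metadata on later calls.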
+async fn authenticate( + client: &mut Client, + username: &str, + password: &str, +) -> Result { + let auth = BasicAuth { + username: username.into(), + password: password.into(), + }; + let mut payload = vec![]; + auth.encode(&mut payload)?; + + let req = stream::once(async { + HandshakeRequest { + payload, + ..HandshakeRequest::default() + } + }); + + let rx = client.handshake(Request::new(req)).await?; + let mut rx = rx.into_inner(); + + let r = rx.next().await.expect("must respond from handshake")?; + assert!(rx.next().await.is_none(), "must not respond a second time"); + + Ok(String::from_utf8(r.payload).unwrap()) +} + +async fn integration_test_scenario(host: &str, port: &str, path: &str) -> Result { + let url = format!("http://{}:{}", host, port); + + let client = FlightServiceClient::connect(url).await?; + + let ArrowFile { + schema, batches, .. + } = read_json_file(path)?; + + let schema = Arc::new(schema); + + let mut descriptor = FlightDescriptor::default(); + descriptor.set_type(DescriptorType::Path); + descriptor.path = vec![path.to_string()]; + + upload_data( + client.clone(), + schema.clone(), + descriptor.clone(), + batches.clone(), + ) + .await?; + verify_data(client, descriptor, schema, &batches).await?; + + Ok(()) +} + +async fn upload_data( + mut client: Client, + schema: SchemaRef, + descriptor: FlightDescriptor, + original_data: Vec, +) -> Result { + eprintln!("In upload_data"); + let (mut upload_tx, upload_rx) = mpsc::channel(10); + + let mut schema_flight_data = FlightData::from(&*schema); + schema_flight_data.flight_descriptor = Some(descriptor.clone()); + upload_tx.send(schema_flight_data).await?; + + let mut original_data_iter = original_data.iter().enumerate(); + + if let Some((counter, first_batch)) = original_data_iter.next() { + eprintln!("Some batches"); + + let metadata = counter.to_string().into_bytes(); + eprintln!("sending batch {:?}", metadata); + + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::convert_to_flight_data(first_batch); + + upload_tx.send_all(&mut stream::iter(dictionary_flight_data).map(Ok)).await?; + + // Only the record batch's FlightData gets app_metadata + batch_flight_data.app_metadata = metadata.clone(); + upload_tx.send(batch_flight_data).await?; + + let outer = client.do_put(Request::new(upload_rx)).await?; + let mut inner = outer.into_inner(); + + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + eprintln!("received ack for batch {:?}", metadata); + + for (counter, batch) in original_data_iter { + let metadata = counter.to_string().into_bytes(); + eprintln!("sending batch {:?}", metadata); + + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::convert_to_flight_data(batch); + + upload_tx.send_all(&mut stream::iter(dictionary_flight_data).map(Ok)).await?; + + // Only the record batch's FlightData gets app_metadata + batch_flight_data.app_metadata = metadata.clone(); + upload_tx.send(batch_flight_data).await?; + + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + eprintln!("received ack for batch {:?}", metadata); + } + } else { + eprintln!("No batches"); + drop(upload_tx); + let outer = client.do_put(Request::new(upload_rx)).await?; + let inner = outer.into_inner(); + + dbg!(&inner); + } + + Ok(()) +} + +async fn verify_data( + mut client: Client, + descriptor: FlightDescriptor, + 
expected_schema: SchemaRef, + expected_data: &[RecordBatch], +) -> Result { + let resp = client.get_flight_info(Request::new(descriptor)).await?; + let info = resp.into_inner(); + + assert!( + !info.endpoint.is_empty(), + "No endpoints returned from Flight server", + ); + for endpoint in info.endpoint { + let ticket = endpoint + .ticket + .expect("No ticket returned from Flight server"); + + assert!( + !endpoint.location.is_empty(), + "No locations returned from Flight server", + ); + for location in endpoint.location { + println!("Verifying location {:?}", location); + consume_flight_location( + location, + ticket.clone(), + &expected_data, + expected_schema.clone(), + ) + .await?; + } + } + + Ok(()) +} + +async fn consume_flight_location( + location: Location, + ticket: Ticket, + expected_data: &[RecordBatch], + schema: SchemaRef, +) -> Result { + let mut location = location; + location.uri = location.uri.replace("grpc+tcp://", "grpc://"); + + dbg!(&location.uri); + let mut client = FlightServiceClient::connect(location.uri).await?; + + dbg!(&client); + + let resp = client.do_get(ticket).await; + dbg!(&resp); + + // If i turn on RUST_LOG=h2=debug and run this client against the C++ server, I see this: + // Dec 02 16:46:50.047 DEBUG h2::codec::framed_read: received; frame=Reset { stream_id: StreamId(1), error_code: INTERNAL_ERROR } + // which i think is coming straight from the server, but I don't know why :( + + let mut resp = resp?.into_inner(); + dbg!(&resp); + + let _schema_again = resp.next().await.unwrap(); + let mut dictionaries_by_field = vec![None; schema.fields().len()]; + + for (counter, expected_batch) in expected_data.iter().enumerate() { + let mut actual_batch = resp.next().await.unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + })?; + let mut message = arrow::ipc::get_root_as_message(&actual_batch.data_header[..]); + dbg!(message.header_type()); + while message.header_type() == ipc::MessageHeader::DictionaryBatch { + // TODO: handle None which means parse failure + if let Some(ipc_batch) = message.header_as_dictionary_batch() { + let dictionary_batch_result = reader::read_dictionary( + &actual_batch.data_body, + ipc_batch, + &schema, + &mut dictionaries_by_field, + ); + if let Err(e) = dictionary_batch_result { + panic!("Error reading dictionary: {:?}", e); + } else { + dbg!(&dictionaries_by_field); + } + } + + actual_batch = resp.next().await.unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + })?; + message = arrow::ipc::get_root_as_message(&actual_batch.data_header[..]); + } + + let metadata = counter.to_string().into_bytes(); + assert_eq!(metadata, actual_batch.app_metadata); + + let actual_batch = flight_data_to_arrow_batch( + &actual_batch, + schema.clone(), + &dictionaries_by_field, + ) + .expect("Unable to convert flight data to Arrow batch") + .expect("Unable to convert flight data to Arrow batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + let schema = expected_batch.schema(); + for i in 0..expected_batch.num_columns() { + let field = schema.field(i); + let field_name = field.name(); + + let expected_data = expected_batch.column(i).data(); + let actual_data = actual_batch.column(i).data(); + + assert_eq!(expected_data, 
actual_data, "Data for field {}", field_name); + } + } + + assert!( + resp.next().await.is_none(), + "Got more batches than the expected: {}", + expected_data.len(), + ); + + Ok(()) +} diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs new file mode 100644 index 00000000000..b8a40aec4c5 --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -0,0 +1,709 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::convert::TryFrom; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +use clap::{App, Arg}; +use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use prost::Message; +use tokio::net::TcpListener; +use tokio::sync::Mutex; +use tonic::transport::Server; +use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; + +use arrow::ipc::{self, reader}; +use arrow::{datatypes::Schema, record_batch::RecordBatch}; +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, Action, ActionType, BasicAuth, Criteria, + Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, Location, PutResult, SchemaResult, Ticket, +}; + +use arrow_integration_testing::{AUTH_PASSWORD, AUTH_USERNAME}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +#[derive(Debug, Clone)] +struct IntegrationDataset { + schema: Schema, + chunks: Vec, +} + +#[derive(Clone, Default)] +pub struct FlightServiceImpl { + server_location: String, + uploaded_chunks: Arc>>, +} + +#[tonic::async_trait] +impl FlightService for FlightServiceImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + eprintln!("Doing do_get..."); + let ticket = request.into_inner(); + + let key = String::from_utf8(ticket.ticket.to_vec()) + .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {:?}", e)))?; + + let uploaded_chunks = self.uploaded_chunks.lock().await; + + let flight = uploaded_chunks.get(&key).ok_or_else(|| { + Status::not_found(format!("Could not find flight. 
{}", key)) + })?; + + let schema = std::iter::once( + flight_schema(&flight.schema) + .map(|data_header| FlightData { + data_header, + ..Default::default() + }) + .map_err(|e| { + Status::internal(format!("Could not generate ipc schema: {}", e)) + }), + ); + + let batches = flight + .chunks + .iter() + .enumerate() + .flat_map(|(counter, batch)| { + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::convert_to_flight_data(batch); + + // Only the record batch's FlightData gets app_metadata + let metadata = counter.to_string().into_bytes(); + batch_flight_data.app_metadata = metadata; + + dictionary_flight_data + .into_iter() + .chain(std::iter::once(batch_flight_data)) + .map(Ok) + }); + + let output = futures::stream::iter(schema.chain(batches).collect::>()); + + Ok(Response::new(Box::pin(output) as Self::DoGetStream)) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + eprintln!("Doing get_flight_info..."); + let descriptor = request.into_inner(); + + match descriptor.r#type { + t if t == DescriptorType::Path as i32 => { + let path = &descriptor.path; + if path.is_empty() { + return Err(Status::invalid_argument("Invalid path")); + } + + let uploaded_chunks = self.uploaded_chunks.lock().await; + let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| { + Status::not_found(format!("Could not find flight. {}", path[0])) + })?; + + let endpoint = FlightEndpoint { + ticket: Some(Ticket { + ticket: path[0].as_bytes().to_vec(), + }), + location: vec![Location { + uri: self.server_location.clone(), + }], + }; + + let total_records: usize = + flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + + let schema = flight_schema(&flight.schema) + .expect("Could not generate schema bytes"); + + let info = FlightInfo { + schema, + flight_descriptor: Some(descriptor.clone()), + endpoint: vec![endpoint], + total_records: total_records as i64, + total_bytes: -1, + }; + + Ok(Response::new(info)) + } + other => Err(Status::unimplemented(format!("Request type: {}", other))), + } + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + eprintln!("Doing put..."); + + let mut input_stream = request.into_inner(); + let flight_data = input_stream + .message() + .await? 
+ .ok_or(Status::invalid_argument("Must send some FlightData"))?; + + let descriptor = flight_data + .flight_descriptor + .clone() + .ok_or(Status::invalid_argument("Must have a descriptor"))?; + + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() + { + return Err(Status::invalid_argument("Must specify a path")); + } + + let key = descriptor.path[0].clone(); + + let schema = Schema::try_from(&flight_data) + .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; + let schema_ref = Arc::new(schema.clone()); + + let (mut response_tx, response_rx) = mpsc::channel(10); + + let uploaded_chunks = self.uploaded_chunks.clone(); + + tokio::spawn(async move { + let mut chunks = vec![]; + let mut uploaded_chunks = uploaded_chunks.lock().await; + let mut dictionaries_by_field = vec![None; schema_ref.fields().len()]; + + while let Some(Ok(data)) = input_stream.next().await { + let message = arrow::ipc::get_root_as_message(&data.data_header[..]); + + match message.header_type() { + ipc::MessageHeader::Schema => { + // TODO: send an error to the stream + eprintln!("Not expecting a schema when messages are read"); + } + ipc::MessageHeader::RecordBatch => { + eprintln!("RecordBatch"); + eprintln!("send #1"); + let stream_result = response_tx + .send(Ok(PutResult { + app_metadata: data.app_metadata.clone(), + })) + .await; + if let Err(e) = stream_result { + eprintln!("send #2"); + response_tx + .send(Err(Status::internal(format!( + "Could not send PutResult: {:?}", + e + )))) + .await + .expect("Error sending error"); + } + + // TODO: handle None which means parse failure + if let Some(ipc_batch) = message.header_as_record_batch() { + let arrow_batch_result = reader::read_record_batch( + &data.data_body, + ipc_batch, + schema_ref.clone(), + &dictionaries_by_field, + ); + match arrow_batch_result { + Ok(batch) => chunks.push(batch), + Err(e) => { + eprintln!("send #3"); + response_tx + .send(Err(Status::invalid_argument(format!( + "Could not convert to RecordBatch: {:?}", + e + )))) + .await + .expect("Error sending error") + } + } + } + } + ipc::MessageHeader::DictionaryBatch => { + eprintln!("DictionaryBatch"); + // TODO: handle None which means parse failure + if let Some(ipc_batch) = message.header_as_dictionary_batch() { + let dictionary_batch_result = reader::read_dictionary( + &data.data_body, + ipc_batch, + &schema_ref, + &mut dictionaries_by_field, + ); + if let Err(e) = dictionary_batch_result { + eprintln!("send #4"); + response_tx + .send(Err(Status::invalid_argument(format!( + "Could not convert to Dictionary: {:?}", + e + )))) + .await + .expect("Error sending error") + } else { + dbg!(&dictionaries_by_field); + } + } + } + t => { + // TODO: send error to stream + eprintln!("Reading types other than record batches not yet supported, unable to read {:?}", t); + } + } + } + + let dataset = IntegrationDataset { schema, chunks }; + uploaded_chunks.insert(key, dataset); + }); + + Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream)) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +fn flight_schema(arrow_schema: &Schema) -> Result> { + use arrow::ipc::{ + 
writer::{IpcDataGenerator, IpcWriteOptions}, + MetadataVersion, + }; + + let mut schema = vec![]; + + let wo = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(); + + let data_gen = IpcDataGenerator::default(); + let encoded_message = data_gen.schema_to_bytes(arrow_schema, &wo); + arrow::ipc::writer::write_message(&mut schema, encoded_message, &wo)?; + + Ok(schema) +} + +#[derive(Clone, Default)] +pub struct MiddlewareScenarioImpl {} + +#[tonic::async_trait] +impl FlightService for MiddlewareScenarioImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let middleware_header = request.metadata().get("x-middleware").cloned(); + + let descriptor = request.into_inner(); + + if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd == b"success" + { + // Return a fake location - the test doesn't read it + let endpoint = FlightEndpoint { + ticket: Some(Ticket { + ticket: b"foo".to_vec(), + }), + location: vec![Location { + uri: "grpc+tcp://localhost:10010".into(), + }], + }; + + let info = FlightInfo { + endpoint: vec![endpoint], + ..Default::default() + }; + + let mut response = Response::new(info); + if let Some(value) = middleware_header { + response.metadata_mut().insert("x-middleware", value); + } + + return Ok(response); + } + + let mut status = Status::unknown("Unknown"); + if let Some(value) = middleware_header { + status.metadata_mut().insert("x-middleware", value); + } + + Err(status) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +struct GrpcServerCallContext { + peer_identity: String, +} + +impl GrpcServerCallContext { + pub fn peer_identity(&self) -> &str { + &self.peer_identity + } +} + +#[derive(Clone)] +pub struct AuthBasicProtoScenarioImpl { + username: Arc, + password: Arc, + peer_identity: Arc>>, +} + +impl AuthBasicProtoScenarioImpl { + async fn check_auth( + &self, + metadata: &MetadataMap, + ) -> Result { + let token = metadata + .get_bin("auth-token-bin") + .and_then(|v| v.to_bytes().ok()) + .and_then(|b| String::from_utf8(b.to_vec()).ok()); + self.is_valid(token).await + } + + async fn is_valid( + &self, + token: Option, + ) -> Result { + match token { + Some(t) if t == &*self.username => Ok(GrpcServerCallContext { + 
peer_identity: self.username.to_string(), + }), + _ => Err(Status::unauthenticated("Invalid token")), + } + } +} + +#[tonic::async_trait] +impl FlightService for AuthBasicProtoScenarioImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + let (tx, rx) = mpsc::channel(10); + + tokio::spawn({ + let username = self.username.clone(); + let password = self.password.clone(); + + async move { + let requests = request.into_inner(); + + requests + .for_each(move |req| { + let mut tx = tx.clone(); + let req = req.expect("Error reading handshake request"); + let HandshakeRequest { payload, .. } = req; + + let auth = BasicAuth::decode(&*payload) + .expect("Error parsing handshake request"); + + let resp = if &*auth.username == &*username + && &*auth.password == &*password + { + Ok(HandshakeResponse { + payload: username.as_bytes().to_vec(), + ..HandshakeResponse::default() + }) + } else { + Err(Status::unauthenticated(format!( + "Don't know user {}", + auth.username + ))) + }; + + async move { + tx.send(resp) + .await + .expect("Error sending handshake response"); + } + }) + .await; + } + }); + + Ok(Response::new(Box::pin(rx))) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + let flight_context = self.check_auth(request.metadata()).await?; + // Respond with the authenticated username. 
+ let buf = flight_context.peer_identity().as_bytes().to_vec(); + let result = arrow_flight::Result { body: buf }; + let output = futures::stream::once(async { Ok(result) }); + Ok(Response::new(Box::pin(output) as Self::DoActionStream)) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +type Error = Box; +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result { + tracing_subscriber::fmt::init(); + + let matches = App::new("rust flight-test-integration-server") + .about("Integration testing server for Flight.") + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let port = matches.value_of("port").unwrap_or("0"); + + match matches.value_of("scenario") { + Some("middleware") => middleware_scenario(port).await?, + Some("auth:basic_proto") => auth_basic_proto_scenario(port).await?, + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + integration_test_scenario(port).await?; + } + } + Ok(()) +} + +async fn integration_test_scenario(port: &str) -> Result { + let (mut listener, addr) = listen_on(port).await?; + + let service = FlightServiceImpl { + server_location: format!("grpc+tcp://{}", addr), + ..Default::default() + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + + Ok(()) +} + +async fn middleware_scenario(port: &str) -> Result { + let (mut listener, _) = listen_on(port).await?; + + let service = MiddlewareScenarioImpl {}; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +async fn auth_basic_proto_scenario(port: &str) -> Result { + let (mut listener, _) = listen_on(port).await?; + + let service = AuthBasicProtoScenarioImpl { + username: AUTH_USERNAME.into(), + password: AUTH_PASSWORD.into(), + peer_identity: Arc::new(Mutex::new(None)), + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +async fn listen_on(port: &str) -> Result<(TcpListener, SocketAddr)> { + let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; + + let listener = TcpListener::bind(addr).await?; + let addr = listener.local_addr()?; + println!("Server listening on localhost:{}", addr.port()); + + Ok((listener, addr)) +} diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs index 596017a79bd..afb39454b2b 100644 --- a/rust/integration-testing/src/lib.rs +++ b/rust/integration-testing/src/lib.rs @@ -16,3 +16,583 @@ // under the License. //! 
Common code used in the integration test binaries + +use hex::decode; +use serde_json::Value; + +use arrow::util::integration_util::ArrowJsonBatch; + +use arrow::array::*; +use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; +use arrow::error::{ArrowError, Result}; +use arrow::record_batch::RecordBatch; +use arrow::{ + buffer::Buffer, + buffer::MutableBuffer, + datatypes::ToByteSlice, + util::{bit_util, integration_util::*}, +}; + +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; + +/// The expected username for the basic auth integration test. +pub const AUTH_USERNAME: &str = "arrow"; +/// The expected password for the basic auth integration test. +pub const AUTH_PASSWORD: &str = "flight"; + +pub struct ArrowFile { + pub schema: Schema, + // we can evolve this into a concrete Arrow type + // this is temporarily not being read from + pub _dictionaries: HashMap, + pub batches: Vec, +} + +pub fn read_json_file(json_name: &str) -> Result { + let json_file = File::open(json_name)?; + let reader = BufReader::new(json_file); + let arrow_json: Value = serde_json::from_reader(reader).unwrap(); + let schema = Schema::from(&arrow_json["schema"])?; + // read dictionaries + let mut dictionaries = HashMap::new(); + if let Some(dicts) = arrow_json.get("dictionaries") { + for d in dicts + .as_array() + .expect("Unable to get dictionaries as array") + { + let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) + .expect("Unable to get dictionary from JSON"); + // TODO: convert to a concrete Arrow type + dictionaries.insert(json_dict.id, json_dict); + } + } + + let mut batches = vec![]; + for b in arrow_json["batches"].as_array().unwrap() { + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; + batches.push(batch); + } + Ok(ArrowFile { + schema, + _dictionaries: dictionaries, + batches, + }) +} + +fn record_batch_from_json( + schema: &Schema, + json_batch: ArrowJsonBatch, + json_dictionaries: Option<&HashMap>, +) -> Result { + let mut columns = vec![]; + + for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { + let col = array_from_json(field, json_col, json_dictionaries)?; + columns.push(col); + } + + RecordBatch::try_new(Arc::new(schema.clone()), columns) +} + +/// Construct an Arrow array from a partially typed JSON column +fn array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dictionaries: Option<&HashMap>, +) -> Result { + match field.data_type() { + DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), + DataType::Boolean => { + let mut b = BooleanBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_bool().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int8 => { + let mut b = Int8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to get {:?} as int64", + value + )) + })? 
as i8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int16 => { + let mut b = Int16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int32 + | DataType::Date32(DateUnit::Day) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + let mut b = Int32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i32), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::Int64 + | DataType::Date64(DateUnit::Millisecond) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + let mut b = Int64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => { + s.parse().expect("Unable to parse string as i64") + } + _ => panic!("Unable to parse {:?} as number", value), + }), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::UInt8 => { + let mut b = UInt8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt16 => { + let mut b = UInt16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt32 => { + let mut b = UInt32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt64 => { + let mut b = UInt64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value( + value + .as_str() + .unwrap() + .parse() + .expect("Unable to parse string as u64"), + ), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float32 => { + let mut b = Float32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_f64().unwrap() as f32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float64 => { + let mut b = Float64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match 
is_valid { + 1 => b.append_value(value.as_f64().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Binary => { + let mut b = BinaryBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeBinary => { + let mut b = LargeBinaryBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Utf8 => { + let mut b = StringBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeUtf8 => { + let mut b = LargeStringBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::FixedSizeBinary(len) => { + let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = hex::decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::List(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| v.as_i64().unwrap() as i32) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(ListArray::from(list_data))) + } + DataType::LargeList(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| match v { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => s.parse::().unwrap(), + _ => panic!("64-bit offset must be either string or number"), + }) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(LargeListArray::from(list_data))) + } + DataType::FixedSizeList(child_field, _) => { + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let null_buf 
= create_null_buf(&json_col); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(FixedSizeListArray::from(list_data))) + } + DataType::Struct(fields) => { + // construct struct with null data + let null_buf = create_null_buf(&json_col); + let mut array_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .null_bit_buffer(null_buf); + + for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let array = array_from_json(field, col, dictionaries)?; + array_data = array_data.add_child_data(array.data()); + } + + let array = StructArray::from(array_data.build()); + Ok(Arc::new(array)) + } + DataType::Dictionary(key_type, value_type) => { + let dict_id = field.dict_id().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find dict_id for field {:?}", + field + )) + })?; + // find dictionary + let dictionary = dictionaries + .ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find any dictionaries for field {:?}", + field + )) + })? + .get(&dict_id); + match dictionary { + Some(dictionary) => dictionary_array_from_json( + field, json_col, key_type, value_type, dictionary, + ), + None => Err(ArrowError::JsonError(format!( + "Unable to find dictionary for field {:?}", + field + ))), + } + } + t => Err(ArrowError::JsonError(format!( + "data type {:?} not supported", + t + ))), + } +} + +fn dictionary_array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dict_key: &DataType, + dict_value: &DataType, + dictionary: &ArrowJsonDictionaryBatch, +) -> Result { + match dict_key { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + let null_buf = create_null_buf(&json_col); + + // build the key data into a buffer, then construct values separately + let key_field = Field::new_dict( + "key", + dict_key.clone(), + field.is_nullable(), + field + .dict_id() + .expect("Dictionary fields must have a dict_id value"), + field + .dict_is_ordered() + .expect("Dictionary fields must have a dict_is_ordered value"), + ); + let keys = array_from_json(&key_field, json_col, None)?; + // note: not enough info on nullability of dictionary + let value_field = Field::new("value", dict_value.clone(), true); + println!("dictionary value type: {:?}", dict_value); + let values = + array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; + + // convert key and value to dictionary data + let dict_data = ArrayData::builder(field.data_type().clone()) + .len(keys.len()) + .add_buffer(keys.data().buffers()[0].clone()) + .null_bit_buffer(null_buf) + .add_child_data(values.data()) + .build(); + + let array = match dict_key { + DataType::Int8 => { + Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef + } + DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), + DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), + DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), + DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), + DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), + DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), + DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), + _ => unreachable!(), + }; + Ok(array) + } + _ => Err(ArrowError::JsonError(format!( + "Dictionary key type {:?} not 
supported", + dict_key + ))), + } +} + +/// A helper to create a null buffer from a Vec +fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { + let num_bytes = bit_util::ceil(json_col.count, 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + json_col + .validity + .clone() + .unwrap() + .iter() + .enumerate() + .for_each(|(i, v)| { + let null_slice = null_buf.data_mut(); + if *v != 0 { + bit_util::set_bit(null_slice, i); + } + }); + null_buf.freeze() +} diff --git a/rust/parquet/src/arrow/schema.rs b/rust/parquet/src/arrow/schema.rs index c93325b79b1..0c04704ae0f 100644 --- a/rust/parquet/src/arrow/schema.rs +++ b/rust/parquet/src/arrow/schema.rs @@ -205,7 +205,8 @@ fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Option { /// Encodes the Arrow schema into the IPC format, and base64 encodes it fn encode_arrow_schema(schema: &Schema) -> String { let options = writer::IpcWriteOptions::default(); - let mut serialized_schema = arrow::ipc::writer::schema_to_bytes(&schema, &options); + let data_gen = arrow::ipc::writer::IpcDataGenerator::default(); + let mut serialized_schema = data_gen.schema_to_bytes(&schema, &options); // manually prepending the length to the schema as arrow uses the legacy IPC format // TODO: change after addressing ARROW-9777
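// Editorial note, not part of the original patch: this hunk replaces the former
// free function arrow::ipc::writer::schema_to_bytes with the method of the same
// name on IpcDataGenerator, matching the IpcDataGenerator usage in the Rust
// integration server's flight_schema() helper above.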