From bc3ec207575a4a9df9ef546c3fc9828217766ef1 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 22 Jun 2017 18:32:09 -0400 Subject: [PATCH 1/2] Add null checks for invalid streams Change-Id: Ib3c578529974bcd6a80bbcb543d345c38bea0cbd --- cpp/src/arrow/ipc/reader.cc | 9 +++++++++ python/pyarrow/tests/test_ipc.py | 10 ++++++++++ 2 files changed, 19 insertions(+) diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 2b7b90f2f2e..d4d12dda0bf 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -215,6 +215,11 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { std::shared_ptr message; RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, &message)); + if (!message) { + return Status::Invalid( + "Expected DICTIONARY message in stream, was null or length 0"); + } + std::shared_ptr batch_body; RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body)) io::BufferReader reader(batch_body); @@ -229,6 +234,10 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { std::shared_ptr message; RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message)); + if (!message) { + return Status::Invalid("Expected SCHEMA message in stream, was null or length 0"); + } + RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_)); // TODO(wesm): In future, we may want to reconcile the ids in the stream with diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index eeea39ab194..47ef75602bc 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -72,6 +72,11 @@ class TestFile(MessagingTest, unittest.TestCase): def _get_writer(self, sink, schema): return pa.RecordBatchFileWriter(sink, schema) + def test_empty_file(self): + buf = io.BytesIO(b'') + with pytest.raises(pa.ArrowInvalid): + pa.open_file(buf) + def test_simple_roundtrip(self): batches = self.write_batches() file_contents = self._get_source() @@ -101,6 +106,11 @@ class TestStream(MessagingTest, unittest.TestCase): def _get_writer(self, sink, schema): return pa.RecordBatchStreamWriter(sink, schema) + def test_empty_stream(self): + buf = io.BytesIO(b'') + with pytest.raises(pa.ArrowInvalid): + pa.open_stream(buf) + def test_simple_roundtrip(self): batches = self.write_batches() file_contents = self._get_source() From 6ae5cd82907a56036b97d76c6dc3ae6782819fcd Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 22 Jun 2017 18:40:07 -0400 Subject: [PATCH 2/2] Centralize null checking Change-Id: Ida88028ad5707a7c8c2b4c1bd3ce4bfb1dbf062a --- cpp/src/arrow/ipc/reader.cc | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index d4d12dda0bf..7fef847c9f2 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -162,7 +162,7 @@ static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) { return FileBlock(block->offset(), block->metaDataLength(), block->bodyLength()); } -static inline std::string message_type_name(Message::Type type) { +static inline std::string FormatMessageType(Message::Type type) { switch (type) { case Message::SCHEMA: return "schema"; @@ -188,14 +188,22 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { return ReadSchema(); } - Status ReadNextMessage(Message::Type expected_type, std::shared_ptr* message) { + Status ReadNextMessage(Message::Type expected_type, bool allow_null, + std::shared_ptr* message) { RETURN_NOT_OK(ReadMessage(stream_.get(), message)); + if (!(*message) && !allow_null) { + std::stringstream ss; + ss << "Expected " << FormatMessageType(expected_type) + << " message in stream, was null or length 0"; + return Status::Invalid(ss.str()); + } + if ((*message) == nullptr) { return Status::OK(); } if ((*message)->type() != expected_type) { std::stringstream ss; - ss << "Message not expected type: " << message_type_name(expected_type) + ss << "Message not expected type: " << FormatMessageType(expected_type) << ", was: " << (*message)->type(); return Status::IOError(ss.str()); } @@ -213,12 +221,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { Status ReadNextDictionary() { std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, &message)); - - if (!message) { - return Status::Invalid( - "Expected DICTIONARY message in stream, was null or length 0"); - } + RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, false, &message)); std::shared_ptr batch_body; RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body)) @@ -232,11 +235,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { Status ReadSchema() { std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message)); - - if (!message) { - return Status::Invalid("Expected SCHEMA message in stream, was null or length 0"); - } + RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, false, &message)); RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_)); @@ -252,7 +251,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { Status GetNextRecordBatch(std::shared_ptr* batch) { std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, &message)); + RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, true, &message)); if (message == nullptr) { // End of stream