diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 2b7b90f2f2e..7fef847c9f2 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -162,7 +162,7 @@ static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) { return FileBlock(block->offset(), block->metaDataLength(), block->bodyLength()); } -static inline std::string message_type_name(Message::Type type) { +static inline std::string FormatMessageType(Message::Type type) { switch (type) { case Message::SCHEMA: return "schema"; @@ -188,14 +188,22 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { return ReadSchema(); } - Status ReadNextMessage(Message::Type expected_type, std::shared_ptr* message) { + Status ReadNextMessage(Message::Type expected_type, bool allow_null, + std::shared_ptr* message) { RETURN_NOT_OK(ReadMessage(stream_.get(), message)); + if (!(*message) && !allow_null) { + std::stringstream ss; + ss << "Expected " << FormatMessageType(expected_type) + << " message in stream, was null or length 0"; + return Status::Invalid(ss.str()); + } + if ((*message) == nullptr) { return Status::OK(); } if ((*message)->type() != expected_type) { std::stringstream ss; - ss << "Message not expected type: " << message_type_name(expected_type) + ss << "Message not expected type: " << FormatMessageType(expected_type) << ", was: " << (*message)->type(); return Status::IOError(ss.str()); } @@ -213,7 +221,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { Status ReadNextDictionary() { std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, &message)); + RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, false, &message)); std::shared_ptr batch_body; RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body)) @@ -227,7 +235,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { Status ReadSchema() { std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message)); + RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, false, &message)); RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_)); @@ -243,7 +251,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl { Status GetNextRecordBatch(std::shared_ptr* batch) { std::shared_ptr message; - RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, &message)); + RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, true, &message)); if (message == nullptr) { // End of stream diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index eeea39ab194..47ef75602bc 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -72,6 +72,11 @@ class TestFile(MessagingTest, unittest.TestCase): def _get_writer(self, sink, schema): return pa.RecordBatchFileWriter(sink, schema) + def test_empty_file(self): + buf = io.BytesIO(b'') + with pytest.raises(pa.ArrowInvalid): + pa.open_file(buf) + def test_simple_roundtrip(self): batches = self.write_batches() file_contents = self._get_source() @@ -101,6 +106,11 @@ class TestStream(MessagingTest, unittest.TestCase): def _get_writer(self, sink, schema): return pa.RecordBatchStreamWriter(sink, schema) + def test_empty_stream(self): + buf = io.BytesIO(b'') + with pytest.raises(pa.ArrowInvalid): + pa.open_stream(buf) + def test_simple_roundtrip(self): batches = self.write_batches() file_contents = self._get_source()