diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 33f92ca244f..8cef315ed0b 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1029,14 +1029,16 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: cdef cppclass CRecordBatchStreamWriter \ " arrow::ipc::RecordBatchStreamWriter"(CRecordBatchWriter): @staticmethod - CStatus Open(OutputStream* sink, const shared_ptr[CSchema]& schema, - shared_ptr[CRecordBatchWriter]* out) + CResult[shared_ptr[CRecordBatchWriter]] Open( + OutputStream* sink, const shared_ptr[CSchema]& schema, + CIpcOptions& options) cdef cppclass CRecordBatchFileWriter \ " arrow::ipc::RecordBatchFileWriter"(CRecordBatchWriter): @staticmethod - CStatus Open(OutputStream* sink, const shared_ptr[CSchema]& schema, - shared_ptr[CRecordBatchWriter]* out) + CResult[shared_ptr[CRecordBatchWriter]] Open( + OutputStream* sink, const shared_ptr[CSchema]& schema, + CIpcOptions& options) cdef cppclass CRecordBatchFileReader \ " arrow::ipc::RecordBatchFileReader": diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index 6710f639b00..0e52c6f1ae3 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -248,6 +248,7 @@ cdef class _CRecordBatchWriter: cdef class _RecordBatchStreamWriter(_CRecordBatchWriter): cdef: shared_ptr[OutputStream] sink + CIpcOptions options bint closed def __cinit__(self): @@ -256,14 +257,20 @@ cdef class _RecordBatchStreamWriter(_CRecordBatchWriter): def __dealloc__(self): pass - def _open(self, sink, Schema schema): - get_writer(sink, &self.sink) + @property + def _use_legacy_format(self): + return self.options.write_legacy_ipc_format + + def _open(self, sink, Schema schema, use_legacy_format=False): + cdef: + CResult[shared_ptr[CRecordBatchWriter]] result + self.options.write_legacy_ipc_format = use_legacy_format + get_writer(sink, &self.sink) with nogil: - check_status( - CRecordBatchStreamWriter.Open(self.sink.get(), 
- schema.sp_schema, - &self.writer)) + result = CRecordBatchStreamWriter.Open( + self.sink.get(), schema.sp_schema, self.options) + self.writer = GetResultValue(result) cdef _get_input_stream(object source, shared_ptr[InputStream]* out): @@ -341,13 +348,17 @@ cdef class _RecordBatchStreamReader(_CRecordBatchReader): cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): - def _open(self, sink, Schema schema): - get_writer(sink, &self.sink) + def _open(self, sink, Schema schema, use_legacy_format=False): + cdef: + CResult[shared_ptr[CRecordBatchWriter]] result + self.options.write_legacy_ipc_format = use_legacy_format + get_writer(sink, &self.sink) with nogil: - check_status( - CRecordBatchFileWriter.Open(self.sink.get(), schema.sp_schema, - &self.writer)) + result = CRecordBatchFileWriter.Open(self.sink.get(), + schema.sp_schema, + self.options) + self.writer = GetResultValue(result) cdef class _RecordBatchFileReader: diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py index 378a92b2fb2..664f000977b 100644 --- a/python/pyarrow/ipc.py +++ b/python/pyarrow/ipc.py @@ -60,19 +60,26 @@ def __init__(self, source): self._open(source) +_ipc_writer_class_doc = """\ +Parameters +---------- +sink : str, pyarrow.NativeFile, or file-like Python object + Either a file path, or a writable file object +schema : pyarrow.Schema + The Arrow schema for data to be written to the file +use_legacy_format : boolean, default None + If None, use False unless overridden by ARROW_PRE_0_15_IPC_FORMAT=1 + environment variable""" + + class RecordBatchStreamWriter(lib._RecordBatchStreamWriter): - """ - Writer for the Arrow streaming binary format + __doc__ = """Writer for the Arrow streaming binary format - Parameters - ---------- - sink : str, pyarrow.NativeFile, or file-like Python object - Either a file path, or a writable file object - schema : pyarrow.Schema - The Arrow schema for data to be written to the file - """ - def __init__(self, sink, schema): - self._open(sink, schema) 
+{}""".format(_ipc_writer_class_doc) + + def __init__(self, sink, schema, use_legacy_format=None): + use_legacy_format = _get_legacy_format_default(use_legacy_format) + self._open(sink, schema, use_legacy_format=use_legacy_format) class RecordBatchFileReader(lib._RecordBatchFileReader, _ReadPandasOption): @@ -92,18 +99,22 @@ def __init__(self, source, footer_offset=None): class RecordBatchFileWriter(lib._RecordBatchFileWriter): - """ - Writer to create the Arrow binary file format - Parameters - ---------- - sink : str, pyarrow.NativeFile, or file-like Python object - Either a file path, or a writable file object - schema : pyarrow.Schema - The Arrow schema for data to be written to the file - """ - def __init__(self, sink, schema): - self._open(sink, schema) + __doc__ = """Writer to create the Arrow binary file format + +{}""".format(_ipc_writer_class_doc) + + def __init__(self, sink, schema, use_legacy_format=None): + use_legacy_format = _get_legacy_format_default(use_legacy_format) + self._open(sink, schema, use_legacy_format=use_legacy_format) + + +def _get_legacy_format_default(use_legacy_format): + if use_legacy_format is None: + import os + return bool(int(os.environ.get('ARROW_PRE_0_15_IPC_FORMAT', '0'))) + else: + return use_legacy_format def open_stream(source): diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index a94151fbeed..5f1a9320e8d 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -102,8 +102,15 @@ def _check_roundtrip(self, as_table=False): class StreamFormatFixture(IpcFixture): + # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix + use_legacy_ipc_format = False + def _get_writer(self, sink, schema): - return pa.RecordBatchStreamWriter(sink, schema) + return pa.RecordBatchStreamWriter( + sink, + schema, + use_legacy_format=self.use_legacy_ipc_format + ) class MessageFixture(IpcFixture): @@ -289,7 +296,9 @@ def test_stream_write_table_batches(stream_fixture): 
ignore_index=True)) -def test_stream_simple_roundtrip(stream_fixture): +@pytest.mark.parametrize('use_legacy_ipc_format', [False, True]) +def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format): + stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format _, batches = stream_fixture.write_batches() file_contents = pa.BufferReader(stream_fixture.get_source()) reader = pa.ipc.open_stream(file_contents) @@ -307,6 +316,24 @@ def test_stream_simple_roundtrip(stream_fixture): reader.read_next_batch() +def test_envvar_set_legacy_ipc_format(): + schema = pa.schema([pa.field('foo', pa.int32())]) + + writer = pa.RecordBatchStreamWriter(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + writer = pa.RecordBatchFileWriter(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + + import os + os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1' + writer = pa.RecordBatchStreamWriter(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + writer = pa.RecordBatchFileWriter(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + + del os.environ['ARROW_PRE_0_15_IPC_FORMAT'] + + def test_stream_read_all(stream_fixture): _, batches = stream_fixture.write_batches() file_contents = pa.BufferReader(stream_fixture.get_source())