Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1029,14 +1029,16 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
cdef cppclass CRecordBatchStreamWriter \
" arrow::ipc::RecordBatchStreamWriter"(CRecordBatchWriter):
@staticmethod
CStatus Open(OutputStream* sink, const shared_ptr[CSchema]& schema,
shared_ptr[CRecordBatchWriter]* out)
CResult[shared_ptr[CRecordBatchWriter]] Open(
OutputStream* sink, const shared_ptr[CSchema]& schema,
CIpcOptions& options)

cdef cppclass CRecordBatchFileWriter \
" arrow::ipc::RecordBatchFileWriter"(CRecordBatchWriter):
@staticmethod
CStatus Open(OutputStream* sink, const shared_ptr[CSchema]& schema,
shared_ptr[CRecordBatchWriter]* out)
CResult[shared_ptr[CRecordBatchWriter]] Open(
OutputStream* sink, const shared_ptr[CSchema]& schema,
CIpcOptions& options)

cdef cppclass CRecordBatchFileReader \
" arrow::ipc::RecordBatchFileReader":
Expand Down
33 changes: 22 additions & 11 deletions python/pyarrow/ipc.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ cdef class _CRecordBatchWriter:
cdef class _RecordBatchStreamWriter(_CRecordBatchWriter):
cdef:
shared_ptr[OutputStream] sink
CIpcOptions options
bint closed

def __cinit__(self):
Expand All @@ -256,14 +257,20 @@ cdef class _RecordBatchStreamWriter(_CRecordBatchWriter):
def __dealloc__(self):
pass

def _open(self, sink, Schema schema):
get_writer(sink, &self.sink)
@property
def _use_legacy_format(self):
return self.options.write_legacy_ipc_format

def _open(self, sink, Schema schema, use_legacy_format=False):
cdef:
CResult[shared_ptr[CRecordBatchWriter]] result

self.options.write_legacy_ipc_format = use_legacy_format
get_writer(sink, &self.sink)
with nogil:
check_status(
CRecordBatchStreamWriter.Open(self.sink.get(),
schema.sp_schema,
&self.writer))
result = CRecordBatchStreamWriter.Open(
self.sink.get(), schema.sp_schema, self.options)
self.writer = GetResultValue(result)


cdef _get_input_stream(object source, shared_ptr[InputStream]* out):
Expand Down Expand Up @@ -341,13 +348,17 @@ cdef class _RecordBatchStreamReader(_CRecordBatchReader):

cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter):

def _open(self, sink, Schema schema):
get_writer(sink, &self.sink)
def _open(self, sink, Schema schema, use_legacy_format=False):
cdef:
CResult[shared_ptr[CRecordBatchWriter]] result

self.options.write_legacy_ipc_format = use_legacy_format
get_writer(sink, &self.sink)
with nogil:
check_status(
CRecordBatchFileWriter.Open(self.sink.get(), schema.sp_schema,
&self.writer))
result = CRecordBatchFileWriter.Open(self.sink.get(),
schema.sp_schema,
self.options)
self.writer = GetResultValue(result)


cdef class _RecordBatchFileReader:
Expand Down
55 changes: 33 additions & 22 deletions python/pyarrow/ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,26 @@ def __init__(self, source):
self._open(source)


_ipc_writer_class_doc = """\
Parameters
----------
sink : str, pyarrow.NativeFile, or file-like Python object
Either a file path, or a writable file object
schema : pyarrow.Schema
The Arrow schema for data to be written to the file
use_legacy_format : boolean, default None
If None, use True unless overridden by ARROW_PRE_0_15_IPC_FORMAT=1
environment variable"""


class RecordBatchStreamWriter(lib._RecordBatchStreamWriter):
"""
Writer for the Arrow streaming binary format
__doc__ = """Writer for the Arrow streaming binary format

Parameters
----------
sink : str, pyarrow.NativeFile, or file-like Python object
Either a file path, or a writable file object
schema : pyarrow.Schema
The Arrow schema for data to be written to the file
"""
def __init__(self, sink, schema):
self._open(sink, schema)
{}""".format(_ipc_writer_class_doc)

    def __init__(self, sink, schema, use_legacy_format=None):
        """Open a stream-format IPC writer on *sink* for *schema*.

        use_legacy_format=None resolves via _get_legacy_format_default:
        the ARROW_PRE_0_15_IPC_FORMAT environment variable decides, with
        unset meaning the new (non-legacy) format.
        """
        # Resolve None -> environment-variable default before handing the
        # concrete boolean to the Cython-level _open.
        use_legacy_format = _get_legacy_format_default(use_legacy_format)
        self._open(sink, schema, use_legacy_format=use_legacy_format)


class RecordBatchFileReader(lib._RecordBatchFileReader, _ReadPandasOption):
Expand All @@ -92,18 +99,22 @@ def __init__(self, source, footer_offset=None):


class RecordBatchFileWriter(lib._RecordBatchFileWriter):
"""
Writer to create the Arrow binary file format

Parameters
----------
sink : str, pyarrow.NativeFile, or file-like Python object
Either a file path, or a writable file object
schema : pyarrow.Schema
The Arrow schema for data to be written to the file
"""
def __init__(self, sink, schema):
self._open(sink, schema)
__doc__ = """Writer to create the Arrow binary file format

{}""".format(_ipc_writer_class_doc)

    def __init__(self, sink, schema, use_legacy_format=None):
        """Open a file-format IPC writer on *sink* for *schema*.

        use_legacy_format=None resolves via _get_legacy_format_default:
        the ARROW_PRE_0_15_IPC_FORMAT environment variable decides, with
        unset meaning the new (non-legacy) format.
        """
        # Resolve None -> environment-variable default before handing the
        # concrete boolean to the Cython-level _open.
        use_legacy_format = _get_legacy_format_default(use_legacy_format)
        self._open(sink, schema, use_legacy_format=use_legacy_format)


def _get_legacy_format_default(use_legacy_format):
if use_legacy_format is None:
import os
return bool(int(os.environ.get('ARROW_PRE_0_15_IPC_FORMAT', '0')))
else:
return use_legacy_format


def open_stream(source):
Expand Down
31 changes: 29 additions & 2 deletions python/pyarrow/tests/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,15 @@ def _check_roundtrip(self, as_table=False):

class StreamFormatFixture(IpcFixture):

# ARROW-6474, for testing writing old IPC protocol with 4-byte prefix
use_legacy_ipc_format = False

def _get_writer(self, sink, schema):
return pa.RecordBatchStreamWriter(sink, schema)
return pa.RecordBatchStreamWriter(
sink,
schema,
use_legacy_format=self.use_legacy_ipc_format
)


class MessageFixture(IpcFixture):
Expand Down Expand Up @@ -289,7 +296,9 @@ def test_stream_write_table_batches(stream_fixture):
ignore_index=True))


def test_stream_simple_roundtrip(stream_fixture):
@pytest.mark.parametrize('use_legacy_ipc_format', [False, True])
def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format):
stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format
_, batches = stream_fixture.write_batches()
file_contents = pa.BufferReader(stream_fixture.get_source())
reader = pa.ipc.open_stream(file_contents)
Expand All @@ -307,6 +316,24 @@ def test_stream_simple_roundtrip(stream_fixture):
reader.read_next_batch()


def test_envvar_set_legacy_ipc_format():
    """ARROW-6474: with use_legacy_format=None, writers must follow the
    ARROW_PRE_0_15_IPC_FORMAT environment variable (unset -> new format).
    """
    import os

    schema = pa.schema([pa.field('foo', pa.int32())])

    # Variable unset in the test environment: both writers pick the
    # new (non-legacy) IPC format by default.
    writer = pa.RecordBatchStreamWriter(pa.BufferOutputStream(), schema)
    assert not writer._use_legacy_format
    writer = pa.RecordBatchFileWriter(pa.BufferOutputStream(), schema)
    assert not writer._use_legacy_format

    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
    # try/finally so a failing assertion cannot leak the variable into
    # the environment of every subsequently-run test.
    try:
        writer = pa.RecordBatchStreamWriter(pa.BufferOutputStream(), schema)
        assert writer._use_legacy_format
        writer = pa.RecordBatchFileWriter(pa.BufferOutputStream(), schema)
        assert writer._use_legacy_format
    finally:
        del os.environ['ARROW_PRE_0_15_IPC_FORMAT']


def test_stream_read_all(stream_fixture):
_, batches = stream_fixture.write_batches()
file_contents = pa.BufferReader(stream_fixture.get_source())
Expand Down