diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc
index 3c1db06159e..fb766a9a759 100644
--- a/cpp/src/arrow/ipc/writer.cc
+++ b/cpp/src/arrow/ipc/writer.cc
@@ -651,9 +651,13 @@ Status GetTensorSize(const Tensor& tensor, int64_t* size) {
 
 RecordBatchWriter::~RecordBatchWriter() {}
 
-Status RecordBatchWriter::WriteTable(const Table& table) {
+Status RecordBatchWriter::WriteTable(const Table& table, int64_t max_chunksize) {
   TableBatchReader reader(table);
 
+  if (max_chunksize > 0) {
+    reader.set_chunksize(max_chunksize);
+  }
+
   std::shared_ptr<RecordBatch> batch;
   while (true) {
     RETURN_NOT_OK(reader.ReadNext(&batch));
@@ -666,6 +670,8 @@ Status RecordBatchWriter::WriteTable(const Table& table) {
   return Status::OK();
 }
 
+Status RecordBatchWriter::WriteTable(const Table& table) { return WriteTable(table, -1); }
+
 // ----------------------------------------------------------------------
 // Stream writer implementation
 
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index cedac45e712..457dcb4ec6a 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -65,6 +65,12 @@ class ARROW_EXPORT RecordBatchWriter {
   /// \return Status
   Status WriteTable(const Table& table);
 
+  /// \brief Write Table with a particular chunksize
+  /// \param[in] table table to write
+  /// \param[in] max_chunksize maximum chunk size for table chunks
+  /// \return Status
+  Status WriteTable(const Table& table, int64_t max_chunksize);
+
   /// \brief Perform any logic necessary to finish the stream
   ///
   /// \return Status
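As a usage illustration of the writer-side change above, a minimal C++ sketch of driving the new two-argument WriteTable overload through a stream writer. The helper function, its parameters, and the RecordBatchStreamWriter::Open call reflect the C++ API of roughly this era and are assumptions for illustration, not part of the patch.

#include <memory>

#include "arrow/api.h"
#include "arrow/io/interfaces.h"
#include "arrow/ipc/api.h"

// Illustrative helper (not part of the patch): stream `table` to `sink`,
// emitting record batches of at most `max_chunksize` rows each.
arrow::Status WriteTableChunked(const arrow::Table& table,
                                arrow::io::OutputStream* sink,
                                int64_t max_chunksize) {
  std::shared_ptr<arrow::ipc::RecordBatchWriter> writer;
  RETURN_NOT_OK(
      arrow::ipc::RecordBatchStreamWriter::Open(sink, table.schema(), &writer));
  // New overload added above; values <= 0 fall back to unbounded chunks.
  RETURN_NOT_OK(writer->WriteTable(table, max_chunksize));
  return writer->Close();
}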
diff --git a/cpp/src/arrow/table-test.cc b/cpp/src/arrow/table-test.cc
index e77d3aa8bbc..8a2288710c8 100644
--- a/cpp/src/arrow/table-test.cc
+++ b/cpp/src/arrow/table-test.cc
@@ -586,4 +586,37 @@ TEST_F(TestTableBatchReader, ReadNext) {
   ASSERT_EQ(nullptr, batch);
 }
 
+TEST_F(TestTableBatchReader, Chunksize) {
+  auto a1 = MakeRandomArray<Int32Array>(10);
+  auto a2 = MakeRandomArray<Int32Array>(20);
+  auto a3 = MakeRandomArray<Int32Array>(10);
+
+  auto sch1 = arrow::schema({field("f1", int32())});
+  auto t1 = Table::Make(sch1, {column(sch1->field(0), {a1, a2, a3})});
+
+  TableBatchReader i1(*t1);
+
+  i1.set_chunksize(15);
+
+  std::shared_ptr<RecordBatch> batch;
+  ASSERT_OK(i1.ReadNext(&batch));
+  ASSERT_OK(batch->Validate());
+  ASSERT_EQ(10, batch->num_rows());
+
+  ASSERT_OK(i1.ReadNext(&batch));
+  ASSERT_OK(batch->Validate());
+  ASSERT_EQ(15, batch->num_rows());
+
+  ASSERT_OK(i1.ReadNext(&batch));
+  ASSERT_OK(batch->Validate());
+  ASSERT_EQ(5, batch->num_rows());
+
+  ASSERT_OK(i1.ReadNext(&batch));
+  ASSERT_OK(batch->Validate());
+  ASSERT_EQ(10, batch->num_rows());
+
+  ASSERT_OK(i1.ReadNext(&batch));
+  ASSERT_EQ(nullptr, batch);
+}
+
 }  // namespace arrow
 
diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc
index 8f3f195765a..129524b7e43 100644
--- a/cpp/src/arrow/table.cc
+++ b/cpp/src/arrow/table.cc
@@ -19,6 +19,7 @@
 
 #include <algorithm>
 #include <cstdlib>
+#include <limits>
 #include <memory>
 #include <sstream>
 
@@ -403,7 +404,8 @@ class TableBatchReader::TableBatchReaderImpl {
         column_data_(table.num_columns()),
         chunk_numbers_(table.num_columns(), 0),
         chunk_offsets_(table.num_columns(), 0),
-        absolute_row_position_(0) {
+        absolute_row_position_(0),
+        max_chunksize_(std::numeric_limits<int64_t>::max()) {
     for (int i = 0; i < table.num_columns(); ++i) {
       column_data_[i] = table.column(i)->data().get();
     }
   }
 
@@ -416,7 +418,7 @@
     }
 
     // Determine the minimum contiguous slice across all columns
-    int64_t chunksize = table_.num_rows();
+    int64_t chunksize = std::min(table_.num_rows(), max_chunksize_);
     std::vector<const Array*> chunks(table_.num_columns());
     for (int i = 0; i < table_.num_columns(); ++i) {
       auto chunk = column_data_[i]->chunk(chunk_numbers_[i]).get();
@@ -430,8 +432,7 @@
     }
 
     // Slice chunks and advance chunk index as appropriate
-    std::vector<std::shared_ptr<ArrayData>> batch_data;
-    batch_data.reserve(table_.num_columns());
+    std::vector<std::shared_ptr<ArrayData>> batch_data(table_.num_columns());
 
     for (int i = 0; i < table_.num_columns(); ++i) {
       // Exhausted chunk
@@ -441,7 +442,7 @@
       if ((chunk->length() - offset) == chunksize) {
         ++chunk_numbers_[i];
         chunk_offsets_[i] = 0;
-        if (chunk_offsets_[i] > 0) {
+        if (offset > 0) {
           // Need to slice
           slice_data = chunk->Slice(offset, chunksize)->data();
         } else {
@@ -449,9 +450,10 @@
           slice_data = chunk->data();
         }
       } else {
+        chunk_offsets_[i] += chunksize;
         slice_data = chunk->Slice(offset, chunksize)->data();
       }
-      batch_data.emplace_back(std::move(slice_data));
+      batch_data[i] = std::move(slice_data);
     }
 
     absolute_row_position_ += chunksize;
@@ -462,12 +464,15 @@
 
   std::shared_ptr<Schema> schema() const { return table_.schema(); }
 
+  void set_chunksize(int64_t chunksize) { max_chunksize_ = chunksize; }
+
  private:
   const Table& table_;
   std::vector<ChunkedArray*> column_data_;
  std::vector<int> chunk_numbers_;
  std::vector<int64_t> chunk_offsets_;
   int64_t absolute_row_position_;
+  int64_t max_chunksize_;
 };
 
 TableBatchReader::TableBatchReader(const Table& table) {
@@ -478,6 +483,10 @@ TableBatchReader::~TableBatchReader() {}
 
 std::shared_ptr<Schema> TableBatchReader::schema() const { return impl_->schema(); }
 
+void TableBatchReader::set_chunksize(int64_t chunksize) {
+  impl_->set_chunksize(chunksize);
+}
+
 Status TableBatchReader::ReadNext(std::shared_ptr<RecordBatch>* out) {
   return impl_->ReadNext(out);
 }
 
diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h
index d0312d93cb9..c813b32ad36 100644
--- a/cpp/src/arrow/table.h
+++ b/cpp/src/arrow/table.h
@@ -197,6 +197,8 @@ class ARROW_EXPORT TableBatchReader : public RecordBatchReader {
 
   Status ReadNext(std::shared_ptr<RecordBatch>* out) override;
 
+  void set_chunksize(int64_t chunksize);
+
  private:
   class TableBatchReaderImpl;
   std::unique_ptr<TableBatchReaderImpl> impl_;
 
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 5d68607efa3..14211787c89 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -456,6 +456,13 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         shared_ptr[CTable] ReplaceSchemaMetadata(
             const shared_ptr[CKeyValueMetadata]& metadata)
 
+    cdef cppclass RecordBatchReader:
+        CStatus ReadNext(shared_ptr[CRecordBatch]* out)
+
+    cdef cppclass TableBatchReader(RecordBatchReader):
+        TableBatchReader(const CTable& table)
+        void set_chunksize(int64_t chunksize)
+
     cdef cppclass CTensor" arrow::Tensor":
         shared_ptr[CDataType] type()
         shared_ptr[CBuffer] data()
@@ -692,7 +699,7 @@ cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
         CStatus Close()
         CStatus WriteRecordBatch(const CRecordBatch& batch,
                                  c_bool allow_64bit)
-        CStatus WriteTable(const CTable& table)
+        CStatus WriteTable(const CTable& table, int64_t max_chunksize)
 
     cdef cppclass CRecordBatchReader" arrow::ipc::RecordBatchReader":
         shared_ptr[CSchema] schema()
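Before the Cython plumbing below, a short C++ sketch of the reader-side API this diff exposes: TableBatchReader::set_chunksize caps each emitted batch, and ReadNext signals completion with a null batch. This mirrors what Table.to_batches does in the Python layer further down; the helper name and the output vector are illustrative only, not part of the patch.

#include <memory>
#include <vector>

#include "arrow/api.h"

// Illustrative helper (not part of the patch): materialize `table` as record
// batches of at most `max_chunksize` rows. A batch can still be shorter when a
// column's underlying chunk boundary is reached first.
arrow::Status TableToBatches(const arrow::Table& table, int64_t max_chunksize,
                             std::vector<std::shared_ptr<arrow::RecordBatch>>* out) {
  arrow::TableBatchReader reader(table);
  reader.set_chunksize(max_chunksize);  // setter added by this patch

  std::shared_ptr<arrow::RecordBatch> batch;
  while (true) {
    RETURN_NOT_OK(reader.ReadNext(&batch));
    if (batch == nullptr) {
      break;  // table exhausted
    }
    out->push_back(batch);
  }
  return arrow::Status::OK();
}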
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index 27e91677509..e5639137dd3 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -202,7 +202,7 @@ cdef class _RecordBatchWriter:
             check_status(self.writer.get()
                          .WriteRecordBatch(deref(batch.batch), 1))
 
-    def write_table(self, Table table):
+    def write_table(self, Table table, chunksize=None):
         """
         Write RecordBatch to stream
 
@@ -210,8 +210,16 @@ cdef class _RecordBatchWriter:
         ----------
         batch : RecordBatch
         """
+        cdef:
+            # Chunksize must be > 0 to have any impact
+            int64_t c_chunksize = -1
+
+        if chunksize is not None:
+            c_chunksize = chunksize
+
         with nogil:
-            check_status(self.writer.get().WriteTable(table.table[0]))
+            check_status(self.writer.get().WriteTable(table.table[0],
+                                                      c_chunksize))
 
     def close(self):
         """
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 8c5b8bbc343..b03ee267022 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -971,6 +971,44 @@ cdef class Table:
 
         return pyarrow_wrap_table(c_table)
 
+    def to_batches(self, chunksize=None):
+        """
+        Convert Table to list of (contiguous) RecordBatch objects, with
+        optional maximum chunk size
+
+        Parameters
+        ----------
+        chunksize : int, default None
+            Maximum size for RecordBatch chunks. Individual chunks may be
+            smaller depending on the chunk layout of individual columns
+
+        Returns
+        -------
+        batches : list of RecordBatch
+        """
+        cdef:
+            unique_ptr[TableBatchReader] reader
+            int64_t c_chunksize
+            list result = []
+            shared_ptr[CRecordBatch] batch
+
+        reader.reset(new TableBatchReader(deref(self.table)))
+
+        if chunksize is not None:
+            c_chunksize = chunksize
+            reader.get().set_chunksize(c_chunksize)
+
+        while True:
+            with nogil:
+                check_status(reader.get().ReadNext(&batch))
+
+            if batch.get() == NULL:
+                break
+
+            result.append(pyarrow_wrap_batch(batch))
+
+        return result
+
     def to_pandas(self, nthreads=None, strings_to_categorical=False,
                   memory_pool=None, zero_copy_only=False):
         """
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 5033ea95783..9cd5f807662 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -168,6 +168,29 @@ def test_stream_write_dispatch(self):
         assert_frame_equal(table.to_pandas(),
                            pd.concat([df, df], ignore_index=True))
 
+    def test_stream_write_table_batches(self):
+        # ARROW-504
+        df = pd.DataFrame({
+            'one': np.random.randn(20),
+        })
+
+        b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False)
+        b2 = pa.RecordBatch.from_pandas(df, preserve_index=False)
+
+        table = pa.Table.from_batches([b1, b2, b1])
+
+        writer = self._get_writer(self.sink, table.schema)
+        writer.write_table(table, chunksize=15)
+        writer.close()
+
+        batches = list(pa.open_stream(pa.BufferReader(self._get_source())))
+
+        assert list(map(len, batches)) == [10, 15, 5, 10]
+        result_table = pa.Table.from_batches(batches)
+        assert_frame_equal(result_table.to_pandas(),
+                           pd.concat([df[:10], df, df[:10]],
+                                     ignore_index=True))
+
     def test_simple_roundtrip(self):
         _, batches = self.write_batches()
         file_contents = pa.BufferReader(self._get_source())
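A note on the [10, 15, 5, 10] batch lengths asserted above and in the table test below: the single column is backed by chunks of 10, 20, and 10 rows, and the reader emits the largest contiguous run that neither crosses a chunk boundary nor exceeds the requested chunksize of 15. That yields the whole first chunk (10 rows), then 15 rows from the 20-row chunk, its remaining 5 rows, and finally the whole last chunk (10 rows).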
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index cd05fb8e1fc..ab012340c0a 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -213,6 +213,31 @@ def test_recordbatchlist_schema_equals():
         pa.Table.from_batches([batch1, batch2])
 
 
+def test_table_to_batches():
+    df1 = pd.DataFrame({'a': list(range(10))})
+    df2 = pd.DataFrame({'a': list(range(10, 30))})
+
+    batch1 = pa.RecordBatch.from_pandas(df1, preserve_index=False)
+    batch2 = pa.RecordBatch.from_pandas(df2, preserve_index=False)
+
+    table = pa.Table.from_batches([batch1, batch2, batch1])
+
+    expected_df = pd.concat([df1, df2,
+                             df1], ignore_index=True)
+
+    batches = table.to_batches()
+    assert len(batches) == 3
+
+    assert_frame_equal(pa.Table.from_batches(batches).to_pandas(),
+                       expected_df)
+
+    batches = table.to_batches(chunksize=15)
+    assert list(map(len, batches)) == [10, 15, 5, 10]
+
+    assert_frame_equal(table.to_pandas(), expected_df)
+    assert_frame_equal(pa.Table.from_batches(batches).to_pandas(),
+                       expected_df)
+
+
 def test_table_basics():
     data = [
         pa.array(range(5)),