3 changes: 2 additions & 1 deletion python/pyarrow/__init__.py
@@ -262,7 +262,8 @@ def print_entry(label, value):

from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
concat_arrays, concat_tables, TableGroupBy,
RecordBatchReader, concat_batches)
RecordBatchReader, concat_batches,
normalize_batches)

# Exceptions
from pyarrow.lib import (ArrowCancelled,
117 changes: 117 additions & 0 deletions python/pyarrow/ipc.pxi
@@ -1502,3 +1502,120 @@ def read_record_batch(obj, Schema schema,
CIpcReadOptions.Defaults()))

return pyarrow_wrap_batch(result)


def normalize_batches(batches, min_rows=0, max_rows=None, min_bytes=0,
max_bytes=None):
"""
Normalize a stream of record batches to fall within row and byte bounds.

Splits oversized batches and concatenates undersized ones, yielding
batches that fall within configurable row and byte limits.

Parameters
----------
batches : iterable of RecordBatch
Input record batches.
    min_rows : int, default 0
        Minimum number of rows per output batch. The final batch, and the
        remainder left over after splitting an oversized run, may have
        fewer rows.
    max_rows : int or None, default None
        Maximum number of rows per output batch.
    min_bytes : int, default 0
        Minimum byte size per output batch. The same remainders as for
        ``min_rows`` may be smaller.
    max_bytes : int or None, default None
        Maximum byte size per output batch. Enforced using a batch's
        average row size, so batches of variable-width data may slightly
        exceed it.

Yields
------
RecordBatch
Normalized record batches.

Examples
--------
>>> import pyarrow as pa
>>> batch = pa.record_batch([pa.array(range(100))], names=['x'])
>>> batches = list(pa.normalize_batches([batch], max_rows=30))
>>> [b.num_rows for b in batches]
[30, 30, 30, 10]
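
    Undersized batches are buffered and concatenated:

    >>> small = [pa.record_batch([pa.array(range(5))], names=['x'])
    ...          for _ in range(8)]
    >>> [b.num_rows for b in pa.normalize_batches(small, min_rows=20)]
    [20, 20]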
"""
if min_rows < 0:
raise ValueError(f"min_rows must be non-negative, got {min_rows}")
if min_bytes < 0:
raise ValueError(f"min_bytes must be non-negative, got {min_bytes}")
if max_rows is not None:
if max_rows <= 0:
raise ValueError(
f"max_rows must be positive, got {max_rows}")
if min_rows > max_rows:
raise ValueError(
f"min_rows ({min_rows}) must be <= max_rows ({max_rows})")
if max_bytes is not None:
if max_bytes <= 0:
raise ValueError(
f"max_bytes must be positive, got {max_bytes}")
if min_bytes > max_bytes:
raise ValueError(
f"min_bytes ({min_bytes}) must be <= max_bytes ({max_bytes})")

cdef:
list acc = []
int64_t acc_rows = 0
int64_t acc_bytes = 0
int64_t _min_rows = min_rows
int64_t _min_bytes = min_bytes
int64_t _max_rows = max_rows if max_rows is not None else -1
int64_t _max_bytes = max_bytes if max_bytes is not None else -1
int64_t chunk_rows, rows_for_bytes, batch_rows, batch_bytes
double bytes_per_row
bint has_max_rows = max_rows is not None
bint has_max_bytes = max_bytes is not None

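    # Accumulate input batches until the configured minimums are met (or a
    # maximum is already exceeded), then merge the run and split it against
    # the maximums below.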
for batch in batches:
batch_rows = batch.num_rows
if batch_rows == 0:
continue

acc.append(batch)
acc_rows += batch_rows
acc_bytes += batch.nbytes

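        # Flush once both minimums are satisfied, or as soon as either
        # maximum is reached.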
while acc_rows > 0 and (
(acc_rows >= _min_rows and acc_bytes >= _min_bytes) or
(has_max_rows and acc_rows >= _max_rows) or
(has_max_bytes and acc_bytes >= _max_bytes)
):
if len(acc) == 1:
merged = acc[0]
else:
merged = concat_batches(acc)
acc = []
acc_rows = 0
acc_bytes = 0

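            # Carve the merged run into chunks no larger than the maximums.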
while merged.num_rows > 0:
chunk_rows = merged.num_rows

if has_max_rows and chunk_rows > _max_rows:
chunk_rows = _max_rows

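                # Estimate how many rows fit in max_bytes from the batch's
                # average row size; for variable-width columns this is only
                # approximate, so a chunk may slightly exceed max_bytes.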
                if has_max_bytes:
bytes_per_row = merged.nbytes / <double>merged.num_rows
if bytes_per_row > 0:
rows_for_bytes = max(
1, <int64_t>(_max_bytes / bytes_per_row))
if chunk_rows > rows_for_bytes:
chunk_rows = rows_for_bytes

if chunk_rows >= merged.num_rows:
yield merged
break

yield merged.slice(0, chunk_rows)
merged = merged.slice(chunk_rows)

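    # Flush whatever remains; this final batch may fall below the minimums.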
if acc:
if len(acc) == 1:
yield acc[0]
else:
yield concat_batches(acc)
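
As a usage sketch (not part of this diff; the table contents and bounds below
are illustrative assumptions): normalize_batches accepts any iterator of
batches and yields lazily, so its output can be re-wrapped in a
RecordBatchReader to give downstream consumers uniformly sized batches.

import pyarrow as pa

# Ragged input: a table chunked into uneven 17-row batches.
table = pa.table({'x': range(1000)})
source = table.to_batches(max_chunksize=17)

# Normalization is lazy; nothing is materialized until the reader is read.
normalized = pa.normalize_batches(source, min_rows=100, max_rows=200)
reader = pa.RecordBatchReader.from_batches(table.schema, normalized)

for batch in reader:
    # Each batch holds 100-200 rows, except possibly the final remainder.
    print(batch.num_rows)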
129 changes: 129 additions & 0 deletions python/pyarrow/tests/test_ipc.py
@@ -1429,3 +1429,132 @@ def read_options_args(request):
def test_read_options_repr(read_options_args):
# https://github.com/apache/arrow/issues/47358
check_ipc_options_repr(pa.ipc.IpcReadOptions, read_options_args)


def _make_batch(num_rows):
return pa.record_batch([pa.array(range(num_rows))], names=['x'])


def test_normalize_batches_split_large():
batch = _make_batch(100)
result = list(pa.normalize_batches([batch], max_rows=30))
row_counts = [b.num_rows for b in result]
assert row_counts == [30, 30, 30, 10]


def test_normalize_batches_concat_small():
batches = [_make_batch(5) for _ in range(8)]
result = list(pa.normalize_batches(batches, min_rows=20))
total_rows = sum(b.num_rows for b in result)
assert total_rows == 40
assert all(b.num_rows >= 20 for b in result[:-1])


def test_normalize_batches_passthrough():
batches = [_make_batch(10), _make_batch(10)]
result = list(pa.normalize_batches(batches, min_rows=5, max_rows=15))
row_counts = [b.num_rows for b in result]
assert row_counts == [10, 10]


def test_normalize_batches_empty_input():
result = list(pa.normalize_batches([]))
assert result == []


def test_normalize_batches_empty_batches_skipped():
batches = [_make_batch(0), _make_batch(5), _make_batch(0)]
result = list(pa.normalize_batches(batches))
assert len(result) == 1
assert result[0].num_rows == 5


def test_normalize_batches_max_bytes():
batch = _make_batch(100)
bytes_per_row = batch.nbytes / batch.num_rows
max_bytes = int(bytes_per_row * 25)
result = list(pa.normalize_batches([batch], max_bytes=max_bytes))
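    # Allow one row of slack: the byte-based split estimates chunk size from
    # the average row size, so rounding can admit one extra row.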
assert all(b.num_rows <= 26 for b in result)
assert sum(b.num_rows for b in result) == 100


def test_normalize_batches_min_bytes():
batches = [_make_batch(5) for _ in range(10)]
single_bytes = batches[0].nbytes
min_bytes = single_bytes * 4
result = list(pa.normalize_batches(batches, min_bytes=min_bytes))
total_rows = sum(b.num_rows for b in result)
assert total_rows == 50
if len(result) > 1:
assert all(b.nbytes >= min_bytes for b in result[:-1])


def test_normalize_batches_combined_constraints():
batch = _make_batch(100)
bytes_per_row = batch.nbytes / batch.num_rows
result = list(pa.normalize_batches(
[batch], max_rows=40, max_bytes=int(bytes_per_row * 30)))
assert all(b.num_rows <= 40 for b in result)
assert sum(b.num_rows for b in result) == 100


def test_normalize_batches_validation():
with pytest.raises(ValueError, match="min_rows must be non-negative"):
list(pa.normalize_batches([], min_rows=-1))
with pytest.raises(ValueError, match="min_bytes must be non-negative"):
list(pa.normalize_batches([], min_bytes=-1))
with pytest.raises(ValueError, match="max_rows must be positive"):
list(pa.normalize_batches([], max_rows=0))
with pytest.raises(ValueError, match="max_bytes must be positive"):
list(pa.normalize_batches([], max_bytes=0))
with pytest.raises(ValueError, match="min_rows.*must be <= max_rows"):
list(pa.normalize_batches([], min_rows=10, max_rows=5))
with pytest.raises(ValueError, match="min_bytes.*must be <= max_bytes"):
list(pa.normalize_batches([], min_bytes=100, max_bytes=50))


def test_normalize_batches_last_batch_below_minimum():
batches = [_make_batch(5), _make_batch(3)]
result = list(pa.normalize_batches(batches, min_rows=100))
assert len(result) == 1
assert result[0].num_rows == 8


def test_normalize_batches_preserves_schema():
schema = pa.schema([('a', pa.int32()), ('b', pa.utf8())])
batch = pa.record_batch(
[pa.array([1, 2, 3], type=pa.int32()),
pa.array(['x', 'y', 'z'], type=pa.utf8())],
schema=schema)
result = list(pa.normalize_batches([batch, batch], max_rows=4))
for b in result:
assert b.schema.equals(schema)


def test_normalize_batches_from_table():
table = pa.table({'x': range(100), 'y': [str(i) for i in range(100)]})
batches = table.to_batches(max_chunksize=7)
result = list(pa.normalize_batches(batches, min_rows=20, max_rows=30))
total_rows = sum(b.num_rows for b in result)
assert total_rows == 100
for b in result[:-1]:
assert 20 <= b.num_rows <= 30
assert result[-1].num_rows <= 30
for b in result:
assert b.schema.equals(table.schema)


def test_normalize_batches_from_dataset(tmp_path):
ds = pytest.importorskip("pyarrow.dataset")
table = pa.table({'a': range(200), 'b': [float(i) for i in range(200)]})
ds.write_dataset(table, tmp_path, format="parquet",
max_rows_per_file=50,
max_rows_per_group=50)
dataset = ds.dataset(tmp_path, format="parquet")
batches = dataset.to_batches()
result = list(pa.normalize_batches(batches, max_rows=30))
total_rows = sum(b.num_rows for b in result)
assert total_rows == 200
assert all(b.num_rows <= 30 for b in result)
for b in result:
assert b.schema.equals(dataset.schema)