From ee815c986e7d57f7ca8bc29ae361ee98e9d12aa2 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Mon, 9 Feb 2026 13:08:50 +0100 Subject: [PATCH 1/3] example change --- python/pyarrow/__init__.py | 3 +- python/pyarrow/ipc.pxi | 114 +++++++++++++++++++++++++++++++ python/pyarrow/tests/test_ipc.py | 100 +++++++++++++++++++++++++++ 3 files changed, 216 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 18a40d877c3..0511d90bd03 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -262,7 +262,8 @@ def print_entry(label, value): from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table, concat_arrays, concat_tables, TableGroupBy, - RecordBatchReader, concat_batches) + RecordBatchReader, concat_batches, + normalize_batches) # Exceptions from pyarrow.lib import (ArrowCancelled, diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index 6477579af21..0baf6fb7791 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -1502,3 +1502,117 @@ def read_record_batch(obj, Schema schema, CIpcReadOptions.Defaults())) return pyarrow_wrap_batch(result) + + +def normalize_batches(batches, min_rows=0, max_rows=None, min_bytes=0, + max_bytes=None): + """ + Normalize a stream of record batches to fall within row and byte bounds. + + Splits oversized batches and concatenates undersized ones, yielding + batches that fall within configurable row and byte limits. + + Parameters + ---------- + batches : iterable of RecordBatch + Input record batches. + min_rows : int, default 0 + Minimum number of rows per output batch. The last batch may have + fewer rows. + max_rows : int or None, default None + Maximum number of rows per output batch. + min_bytes : int, default 0 + Minimum byte size per output batch. The last batch may be smaller. + max_bytes : int or None, default None + Maximum byte size per output batch. + + Yields + ------ + RecordBatch + Normalized record batches. 
+ + Examples + -------- + >>> import pyarrow as pa + >>> batch = pa.record_batch([pa.array(range(100))], names=['x']) + >>> batches = list(pa.normalize_batches([batch], max_rows=30)) + >>> [b.num_rows for b in batches] + [30, 30, 30, 10] + """ + if min_rows < 0: + raise ValueError(f"min_rows must be non-negative, got {min_rows}") + if min_bytes < 0: + raise ValueError(f"min_bytes must be non-negative, got {min_bytes}") + if max_rows is not None: + if max_rows <= 0: + raise ValueError( + f"max_rows must be positive, got {max_rows}") + if min_rows > max_rows: + raise ValueError( + f"min_rows ({min_rows}) must be <= max_rows ({max_rows})") + if max_bytes is not None: + if max_bytes <= 0: + raise ValueError( + f"max_bytes must be positive, got {max_bytes}") + if min_bytes > max_bytes: + raise ValueError( + f"min_bytes ({min_bytes}) must be <= max_bytes ({max_bytes})") + + acc = [] + acc_rows = 0 + acc_bytes = 0 + + def _should_flush(): + mins_met = acc_rows >= min_rows and acc_bytes >= min_bytes + max_rows_exceeded = max_rows is not None and acc_rows >= max_rows + max_bytes_exceeded = max_bytes is not None and acc_bytes >= max_bytes + return mins_met or max_rows_exceeded or max_bytes_exceeded + + def _flush_batches(): + nonlocal acc, acc_rows, acc_bytes + if len(acc) == 1: + merged = acc[0] + else: + merged = concat_batches(acc) + acc = [] + acc_rows = 0 + acc_bytes = 0 + + while merged.num_rows > 0: + chunk_rows = merged.num_rows + + if max_rows is not None: + chunk_rows = min(chunk_rows, max_rows) + + if max_bytes is not None and merged.num_rows > 0: + bytes_per_row = merged.nbytes / merged.num_rows + if bytes_per_row > 0: + rows_for_bytes = max(1, int(max_bytes / bytes_per_row)) + chunk_rows = min(chunk_rows, rows_for_bytes) + + chunk = merged.slice(0, chunk_rows) + remainder = merged.slice(chunk_rows) + + if remainder.num_rows > 0: + yield chunk + merged = remainder + else: + yield chunk + break + + for batch in batches: + if batch.num_rows == 0: + continue + + acc.append(batch) + acc_rows += batch.num_rows + acc_bytes += batch.nbytes + + while _should_flush(): + yield from _flush_batches() + + if acc: + if len(acc) == 1: + yield acc[0] + else: + yield concat_batches(acc) diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index 6813ed77723..c462a5da6f9 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -1429,3 +1429,103 @@ def read_options_args(request): def test_read_options_repr(read_options_args): # https://github.com/apache/arrow/issues/47358 check_ipc_options_repr(pa.ipc.IpcReadOptions, read_options_args) + + +def _make_batch(num_rows): + return pa.record_batch([pa.array(range(num_rows))], names=['x']) + + +def test_normalize_batches_split_large(): + batch = _make_batch(100) + result = list(pa.normalize_batches([batch], max_rows=30)) + row_counts = [b.num_rows for b in result] + assert row_counts == [30, 30, 30, 10] + + +def test_normalize_batches_concat_small(): + batches = [_make_batch(5) for _ in range(8)] + result = list(pa.normalize_batches(batches, min_rows=20)) + total_rows = sum(b.num_rows for b in result) + assert total_rows == 40 + assert all(b.num_rows >= 20 for b in result[:-1]) + + +def test_normalize_batches_passthrough(): + batches = [_make_batch(10), _make_batch(10)] + result = list(pa.normalize_batches(batches, min_rows=5, max_rows=15)) + row_counts = [b.num_rows for b in result] + assert row_counts == [10, 10] + + +def test_normalize_batches_empty_input(): + result = list(pa.normalize_batches([])) + 
assert result == [] + + +def test_normalize_batches_empty_batches_skipped(): + batches = [_make_batch(0), _make_batch(5), _make_batch(0)] + result = list(pa.normalize_batches(batches)) + assert len(result) == 1 + assert result[0].num_rows == 5 + + +def test_normalize_batches_max_bytes(): + batch = _make_batch(100) + bytes_per_row = batch.nbytes / batch.num_rows + max_bytes = int(bytes_per_row * 25) + result = list(pa.normalize_batches([batch], max_bytes=max_bytes)) + assert all(b.num_rows <= 26 for b in result) + assert sum(b.num_rows for b in result) == 100 + + +def test_normalize_batches_min_bytes(): + batches = [_make_batch(5) for _ in range(10)] + single_bytes = batches[0].nbytes + min_bytes = single_bytes * 4 + result = list(pa.normalize_batches(batches, min_bytes=min_bytes)) + total_rows = sum(b.num_rows for b in result) + assert total_rows == 50 + if len(result) > 1: + assert all(b.nbytes >= min_bytes for b in result[:-1]) + + +def test_normalize_batches_combined_constraints(): + batch = _make_batch(100) + bytes_per_row = batch.nbytes / batch.num_rows + result = list(pa.normalize_batches( + [batch], max_rows=40, max_bytes=int(bytes_per_row * 30))) + assert all(b.num_rows <= 40 for b in result) + assert sum(b.num_rows for b in result) == 100 + + +def test_normalize_batches_validation(): + with pytest.raises(ValueError, match="min_rows must be non-negative"): + list(pa.normalize_batches([], min_rows=-1)) + with pytest.raises(ValueError, match="min_bytes must be non-negative"): + list(pa.normalize_batches([], min_bytes=-1)) + with pytest.raises(ValueError, match="max_rows must be positive"): + list(pa.normalize_batches([], max_rows=0)) + with pytest.raises(ValueError, match="max_bytes must be positive"): + list(pa.normalize_batches([], max_bytes=0)) + with pytest.raises(ValueError, match="min_rows.*must be <= max_rows"): + list(pa.normalize_batches([], min_rows=10, max_rows=5)) + with pytest.raises(ValueError, match="min_bytes.*must be <= max_bytes"): + list(pa.normalize_batches([], min_bytes=100, max_bytes=50)) + + +def test_normalize_batches_last_batch_below_minimum(): + batches = [_make_batch(5), _make_batch(3)] + result = list(pa.normalize_batches(batches, min_rows=100)) + assert len(result) == 1 + assert result[0].num_rows == 8 + + +def test_normalize_batches_preserves_schema(): + schema = pa.schema([('a', pa.int32()), ('b', pa.utf8())]) + batch = pa.record_batch( + [pa.array([1, 2, 3], type=pa.int32()), + pa.array(['x', 'y', 'z'], type=pa.utf8())], + schema=schema) + result = list(pa.normalize_batches([batch, batch], max_rows=4)) + for b in result: + assert b.schema.equals(schema) From fcb171123ef99c6f599c0134c6f1ce296feec58b Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Mon, 9 Feb 2026 13:21:26 +0100 Subject: [PATCH 2/3] add more tests --- python/pyarrow/tests/test_ipc.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index c462a5da6f9..e2bb5ffa779 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -1529,3 +1529,31 @@ def test_normalize_batches_preserves_schema(): result = list(pa.normalize_batches([batch, batch], max_rows=4)) for b in result: assert b.schema.equals(schema) + + +def test_normalize_batches_from_table(): + table = pa.table({'x': range(100), 'y': [str(i) for i in range(100)]}) + batches = table.to_batches(max_chunksize=7) + result = list(pa.normalize_batches(batches, min_rows=20, max_rows=30)) + total_rows = sum(b.num_rows for 
b in result) + assert total_rows == 100 + for b in result[:-1]: + assert 20 <= b.num_rows <= 30 + assert result[-1].num_rows <= 30 + for b in result: + assert b.schema.equals(table.schema) + + +def test_normalize_batches_from_dataset(tmp_path): + ds = pytest.importorskip("pyarrow.dataset") + table = pa.table({'a': range(200), 'b': [float(i) for i in range(200)]}) + ds.write_dataset(table, tmp_path, format="parquet", + max_rows_per_file=50) + dataset = ds.dataset(tmp_path, format="parquet") + batches = dataset.to_batches() + result = list(pa.normalize_batches(batches, max_rows=30)) + total_rows = sum(b.num_rows for b in result) + assert total_rows == 200 + assert all(b.num_rows <= 30 for b in result) + for b in result: + assert b.schema.equals(dataset.schema) From 5a9431b020df889892a429edbd83a55fb8069a62 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Mon, 9 Feb 2026 13:37:14 +0100 Subject: [PATCH 3/3] fix --- python/pyarrow/ipc.pxi | 93 ++++++++++++++++---------------- python/pyarrow/tests/test_ipc.py | 3 +- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index 0baf6fb7791..5d6aab97353 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -1558,58 +1558,61 @@ def normalize_batches(batches, min_rows=0, max_rows=None, min_bytes=0, raise ValueError( f"min_bytes ({min_bytes}) must be <= max_bytes ({max_bytes})") - acc = [] - acc_rows = 0 - acc_bytes = 0 - - def _should_flush(): - mins_met = acc_rows >= min_rows and acc_bytes >= min_bytes - max_rows_exceeded = max_rows is not None and acc_rows >= max_rows - max_bytes_exceeded = max_bytes is not None and acc_bytes >= max_bytes - return mins_met or max_rows_exceeded or max_bytes_exceeded - - def _flush_batches(): - nonlocal acc, acc_rows, acc_bytes - if len(acc) == 1: - merged = acc[0] - else: - merged = concat_batches(acc) - acc = [] - acc_rows = 0 - acc_bytes = 0 - - while merged.num_rows > 0: - chunk_rows = merged.num_rows - - if max_rows is not None: - chunk_rows = min(chunk_rows, max_rows) - - if max_bytes is not None and merged.num_rows > 0: - bytes_per_row = merged.nbytes / merged.num_rows - if bytes_per_row > 0: - rows_for_bytes = max(1, int(max_bytes / bytes_per_row)) - chunk_rows = min(chunk_rows, rows_for_bytes) - - chunk = merged.slice(0, chunk_rows) - remainder = merged.slice(chunk_rows) - - if remainder.num_rows > 0: - yield chunk - merged = remainder - else: - yield chunk - break + cdef: + list acc = [] + int64_t acc_rows = 0 + int64_t acc_bytes = 0 + int64_t _min_rows = min_rows + int64_t _min_bytes = min_bytes + int64_t _max_rows = max_rows if max_rows is not None else -1 + int64_t _max_bytes = max_bytes if max_bytes is not None else -1 + int64_t chunk_rows, rows_for_bytes, batch_rows, batch_bytes + double bytes_per_row + bint has_max_rows = max_rows is not None + bint has_max_bytes = max_bytes is not None for batch in batches: - if batch.num_rows == 0: + batch_rows = batch.num_rows + if batch_rows == 0: continue acc.append(batch) - acc_rows += batch.num_rows + acc_rows += batch_rows acc_bytes += batch.nbytes - while _should_flush(): - yield from _flush_batches() + while acc_rows > 0 and ( + (acc_rows >= _min_rows and acc_bytes >= _min_bytes) or + (has_max_rows and acc_rows >= _max_rows) or + (has_max_bytes and acc_bytes >= _max_bytes) + ): + if len(acc) == 1: + merged = acc[0] + else: + merged = concat_batches(acc) + acc = [] + acc_rows = 0 + acc_bytes = 0 + + while merged.num_rows > 0: + chunk_rows = merged.num_rows + + if has_max_rows and chunk_rows 
> _max_rows: + chunk_rows = _max_rows + + if has_max_bytes and merged.num_rows > 0: + bytes_per_row = merged.nbytes / merged.num_rows + if bytes_per_row > 0: + rows_for_bytes = max( + 1, (_max_bytes / bytes_per_row)) + if chunk_rows > rows_for_bytes: + chunk_rows = rows_for_bytes + + if chunk_rows >= merged.num_rows: + yield merged + break + + yield merged.slice(0, chunk_rows) + merged = merged.slice(chunk_rows) if acc: if len(acc) == 1: diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py index e2bb5ffa779..a6c0948d535 100644 --- a/python/pyarrow/tests/test_ipc.py +++ b/python/pyarrow/tests/test_ipc.py @@ -1548,7 +1548,8 @@ def test_normalize_batches_from_dataset(tmp_path): ds = pytest.importorskip("pyarrow.dataset") table = pa.table({'a': range(200), 'b': [float(i) for i in range(200)]}) ds.write_dataset(table, tmp_path, format="parquet", - max_rows_per_file=50) + max_rows_per_file=50, + max_rows_per_group=50) dataset = ds.dataset(tmp_path, format="parquet") batches = dataset.to_batches() result = list(pa.normalize_batches(batches, max_rows=30))
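
Usage sketch (illustrative, assuming the three patches above are applied): normalize_batches() accepts any iterable of RecordBatch, so it can sit between an IPC stream reader and writer to re-emit a stream with evenly sized batches. Everything below other than pa.normalize_batches() is existing pyarrow API; the toy schema, the input batch sizes and the 1000/4000 row limits are arbitrary values chosen only for the example.

    import pyarrow as pa

    # Write a throwaway IPC stream with unevenly sized batches.
    schema = pa.schema([('x', pa.int64())])
    source = pa.BufferOutputStream()
    with pa.ipc.new_stream(source, schema) as writer:
        for n in (100, 7000, 250, 3200):
            writer.write_batch(
                pa.record_batch([pa.array(range(n))], schema=schema))

    # Re-emit the stream, merging small batches and splitting large ones.
    reader = pa.ipc.open_stream(source.getvalue())
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, schema) as writer:
        for batch in pa.normalize_batches(reader, min_rows=1000, max_rows=4000):
            writer.write_batch(batch)

    rewritten = pa.ipc.open_stream(sink.getvalue()).read_all()
    print([len(chunk) for chunk in rewritten.column('x').chunks])
    # -> [4000, 3100, 3450]: the 100- and 7000-row batches are merged and
    #    split at the 4000-row maximum; the 250- and 3200-row batches are
    #    merged once the 1000-row minimum is reached.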