diff --git a/rust/benches/my_benchmark.rs b/rust/benches/my_benchmark.rs index 997f0578..d2882ddd 100644 --- a/rust/benches/my_benchmark.rs +++ b/rust/benches/my_benchmark.rs @@ -14,6 +14,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use glob::glob; +use sedpack_rs::batch_iteration::BatchIterator; use sedpack_rs::example_iteration::{ get_shard_progress, CompressionType, ExampleIterator, ShardInfo, }; @@ -31,6 +32,17 @@ pub fn get_shard_files() -> Vec { shard_infos } +pub fn batch_iterator_benchmark(c: &mut Criterion) { + let shard_infos = get_shard_files(); + c.bench_function("BatchIterator", |b| { + b.iter(|| { + for batch in BatchIterator::new(shard_infos.clone(), 12, 32, vec![true, true]) { + let _ = std::hint::black_box(batch); + } + }) + }); +} + pub fn example_iterator_benchmark(c: &mut Criterion) { let shard_infos = get_shard_files(); c.bench_function("ExampleIterator", |b| { @@ -55,5 +67,10 @@ pub fn parallel_map_benchmark(c: &mut Criterion) { }); } -criterion_group!(benches, example_iterator_benchmark, parallel_map_benchmark,); +criterion_group!( + benches, + batch_iterator_benchmark, + example_iterator_benchmark, + parallel_map_benchmark, +); criterion_main!(benches); diff --git a/rust/src/batch_iteration.rs b/rust/src/batch_iteration.rs new file mode 100644 index 00000000..29432fd9 --- /dev/null +++ b/rust/src/batch_iteration.rs @@ -0,0 +1,101 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +pub use super::example_iteration::{ + get_shard_progress, CompressionType, Example, ExampleIterator, ShardInfo, ShardProgress, +}; +pub use super::parallel_map::parallel_map; +pub use super::shard_generated::sedpack::io::flatbuffer::shardfile::{root_as_shard, Shard}; + +/// Single attribute which has been batched. +pub enum BatchedAttribute { + /// Row-major order batch of the attribute with static (fixed) size. That is in NumPy C-order + /// we can index as data[batch_index][attribute_index] where batch_index in 0..batch_size and + /// attribute_index in 0..len(attribute). + Static { data: numpy::ndarray::Array }, + /// Dynamic data where we do not know shape up front (e.g., string, bytearray) is represented + /// as a vector with the same indexing semantic. + Dynamic { data: Vec> }, +} + +pub type Batch = Vec; + +struct Batcher { + example_iterator: Box + Send>, + batch_size: usize, + has_fixed_shape: Vec, +} + +impl Iterator for Batcher { + type Item = Batch; + + fn next(&mut self) -> Option { + // Collect examples. + let cache: Vec = self.example_iterator.by_ref().take(self.batch_size).collect(); + + // Decide if we have enough (the last batch might not have batch_size examples). + if cache.is_empty() { + return None; + } + + // Batch the examples. + let mut result = Batch::new(); + for (attribute_index, is_fixed) in self.has_fixed_shape.iter().enumerate() { + // Collect batched version of current attribute across all cached examples. + let current_batched_attribute = match is_fixed { + true => BatchedAttribute::Static { + data: numpy::ndarray::Array::::from_iter( + cache.iter().flat_map(|e| e[attribute_index].iter().cloned()), + ), + }, + false => BatchedAttribute::Dynamic { + data: cache + .iter() + .map(|e| { + numpy::ndarray::Array::::from_iter( + e[attribute_index].iter().cloned(), + ) + }) + .collect(), + }, + }; + + // Save the batched attribute. 
+ result.push(current_batched_attribute); + } + Some(result) + } +} + +pub struct BatchIterator { + batch_iterator: Box + Send>, +} + +impl BatchIterator { + pub fn new( + files: Vec, threads: usize, batch_size: usize, has_fixed_shape: Vec, + ) -> Self { + let example_iterator = Box::new(ExampleIterator::new(files, threads)); + let batch_iterator = Box::new(Batcher { example_iterator, batch_size, has_fixed_shape }); + BatchIterator { batch_iterator } + } +} + +impl Iterator for BatchIterator { + type Item = Batch; + + fn next(&mut self) -> Option { + self.batch_iterator.next() + } +} diff --git a/rust/src/example_iteration.rs b/rust/src/example_iteration.rs index 3fa88634..215f8be8 100644 --- a/rust/src/example_iteration.rs +++ b/rust/src/example_iteration.rs @@ -100,7 +100,12 @@ impl ExampleIterator { /// `files: impl Iterator`. pub fn new(files: Vec, threads: usize) -> Self { let example_iterator = Box::new( - parallel_map(|x| get_shard_progress(&x), files.into_iter(), threads).flatten(), + parallel_map( + |x| get_shard_progress(&x).collect::>(), + files.into_iter(), + threads, + ) + .flatten(), ); ExampleIterator { example_iterator } } @@ -142,7 +147,7 @@ fn read_to_end(mut reader: impl std::io::Read) -> Vec { } /// Get ShardProgress. -pub fn get_shard_progress(shard_info: &ShardInfo) -> Vec { +pub fn get_shard_progress(shard_info: &ShardInfo) -> ShardProgress { let file_bytes = get_file_bytes(shard_info); // A shard is a vector of examples (positive number -- invariant kept by Python code). @@ -156,7 +161,7 @@ pub fn get_shard_progress(shard_info: &ShardInfo) -> Vec { // Number of examples might be different in different shards. let total_examples = shard.get().examples().unwrap().len(); - ShardProgress { total_examples, used_examples: 0, shard }.collect() + ShardProgress { total_examples, used_examples: 0, shard } } /// Get single example out of a ShardProgress. 
diff --git a/rust/src/lib.rs b/rust/src/lib.rs index b5fa9cca..8813b8e5 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2024-2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ pub use shard_generated::sedpack::io::flatbuffer::shardfile::{ root_as_shard, root_as_shard_unchecked, Attribute, Example, Shard, }; +pub mod batch_iteration; pub mod example_iteration; pub mod parallel_map; // Import the autogenerated code for parsing a shard represented as a FlatBuffer. @@ -150,9 +151,152 @@ mod static_iter { } } +/// Python wrappers around `example_iteration`. +mod static_batched_iter { + use std::collections::HashMap; + use std::str::FromStr; + + use numpy::IntoPyArray; + use pyo3::prelude::*; + use pyo3::{pyclass, pymethods, PyRefMut}; + + use super::batch_iteration::{BatchIterator, BatchedAttribute}; + use super::example_iteration::{CompressionType, ShardInfo}; + + /// Implementation details: The goal is to own the BatchIterator in Rust and only send + /// examples to Python. This helps with concurrent reading and parsing of shard files. + /// Moreover Python code cannot compromise integrity of the data structures. + /// + /// - We need support for multiple BatchIterator's at the same time since during training the + /// train and validation split are being read in an interleaved manner. To support this each + /// RustIter instance keeps a `static_index` determining which `BatchIterator` it is using + /// (dispatch done using a HashMap). + /// - Since a `HashMap` cannot be instantiated static we use an LazyLock>>. + /// - Using a mutex to avoid the need to use unsafe for a static mutable variable. The overhead + /// should be negligible since only a single thread is expected to access this. + /// - Python does not guarantee that __del__ is called right away (or at all). 
Thus RustIter + /// also implements a context manager which is guaranteed to call __exit__ and drop memory + /// owned by the corresponding BatchIterator. + static STATIC_ITERATORS: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| std::sync::Mutex::new(HashMap::new())); + + #[pyclass] + pub struct BatchedRustIter { + /// Which BatchIterator are we interacting with (unique id). Experimental API expect + /// breaking changes. + static_index: i32, + /// Read only value. For iteration we use this object as a context manager which allows us + /// to free resources in STATIC_ITERATORS on the call of __exit__. + /// + /// Alternatives considered: + /// - __del__ is not yet supported by pyo3 and also not guaranteed to be called by Python. + #[pyo3(get)] + can_iterate: bool, + } + + impl Iterator for BatchedRustIter { + type Item = ::Item; + + fn next(&mut self) -> Option { + // TODO move println to logging. + if !self.can_iterate { + println!( + "Use the context manager to enable iteration and guaranteed memory \ + deallocation" + ); + return None; + } + let mut hash_map = STATIC_ITERATORS.lock().unwrap(); + let iter = hash_map + .get_mut(&self.static_index) + .expect("The static_index was not found among the STATIC_ITERATORS."); + iter.next() + } + } + + #[pymethods] + impl BatchedRustIter { + #[new] + fn new( + files: Vec, threads: usize, compression: String, batch_size: usize, + has_fixed_shape: Vec, + ) -> Self { + let static_index = rand::random(); + let mut hash_map = STATIC_ITERATORS.lock().unwrap(); + let compression_type = CompressionType::from_str(&compression).unwrap(); + let shard_infos = files + .into_iter() + .map(|file_path| ShardInfo { file_path, compression_type }) + .collect(); + hash_map.insert( + static_index, + BatchIterator::new(shard_infos, threads, batch_size, has_fixed_shape), + ); + + BatchedRustIter { static_index, can_iterate: false } + } + + #[staticmethod] + fn supported_compressions() -> Vec { + CompressionType::supported_compressions() 
+ } + + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + /// Yields another batch of examples. Attributes are batched in the following way: + /// + /// - static (fixed shape) is row-major order (C-order) numpy array batch of attribute + /// values. + /// + /// - dynamic (unknown shape, e.g., str, bytes) list of numpy arrays. + fn __next__<'py>(mut slf: PyRefMut<'py, Self>) -> Option> { + match slf.next() { + None => None, + Some(result) => { + let elements: Vec> = result + .into_iter() + .map(|batched_attribute| match batched_attribute { + BatchedAttribute::Static { data } => { + data.into_pyarray(slf.py()).into_any() + } + BatchedAttribute::Dynamic { data } => pyo3::types::PyList::new( + slf.py(), + data.into_iter().map(|e| e.into_pyarray(slf.py())), + ) + .unwrap() + .into_any(), + }) + .collect(); + Some(pyo3::types::PyList::new(slf.py(), elements).unwrap()) + } + } + } + + /// The implementation is reentrant. If changing also change + /// `sedpack.io.dataset_iteration.RustGenerator`. + fn __enter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> { + slf.can_iterate = true; + slf + } + + fn __exit__( + mut slf: PyRefMut<'_, Self>, _exc_type: &Bound<'_, PyAny>, _exc_val: &Bound<'_, PyAny>, + _exc_tb: &Bound<'_, PyAny>, + ) { + slf.can_iterate = false; + // Drop from STATIC_ITERATORS. + let mut hash_map = STATIC_ITERATORS.lock().unwrap(); + drop(hash_map.remove(&slf.static_index)); + } + } +} + /// A Python module implemented in Rust. #[pymodule] fn _sedpack_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/sedpack/io/iteration/__init__.py b/src/sedpack/io/iteration/__init__.py index c15aa47d..9392da7c 100644 --- a/src/sedpack/io/iteration/__init__.py +++ b/src/sedpack/io/iteration/__init__.py @@ -13,8 +13,10 @@ # limitations under the License. 
"""Rust batched generator object wrapping the rust object to behave nicely with
TensorFlow."""
import itertools
import os
from pathlib import Path
from types import TracebackType
from typing import (
    Callable,
    Iterable,
    Iterator,
    Type,
)
from typing_extensions import Self

import numpy as np

from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.metadata import DatasetStructure
from sedpack.io.shard.iterate_shard_base import T
from sedpack.io.shard_file_metadata import ShardInfo
from sedpack.io.types import ExampleT

from sedpack._sedpack_rs import BatchedRustIter


class RustBatchedGenerator:
    """Similar to sedpack.io.iteration.RustGenerator with batching.
    Experimental API, expect breaking changes.
    """

    def __init__(
        self,
        *,
        dataset_path: Path,
        dataset_structure: DatasetStructure,
        shard_iterator: Iterable[ShardInfo],
        batch_size: int,
        process_record: Callable[[ExampleT], T] | None = None,
        file_parallelism: int = os.cpu_count() or 1,
    ) -> None:
        """A reentrant generator.

        Args:

          dataset_path (Path): The root path of the dataset.

          dataset_structure (DatasetStructure): The structure of the dataset.

          shard_iterator (Iterable[ShardInfo]): How the shards should be
          iterated.

          batch_size (int): Size of the batches.

          process_record (Callable[[ExampleT], T] | None): Optional
          transformation of each example.

          file_parallelism (int): How many files to read in parallel.

        Raises:

          ValueError: If file_parallelism is not positive, the shard file type
          is not FlatBuffers, or the compression is not supported by Rust.
        """
        self._iter: BatchedRustIter | None  # type: ignore[no-any-unimported]
        self._iter = None
        self._stopped: bool = False

        # Workaround until BatchedRustIter supports an Iterable[ShardInfo].
        # Take _shard_chunk_size shard paths at once.
        self._shard_chunk_size: int = 1_000_000

        # Check file_parallelism is positive.
        if file_parallelism <= 0:
            raise ValueError("The argument file_parallelism should be "
                             f"positive but is {file_parallelism}")

        self._dataset_path: Path = dataset_path
        self._dataset_structure: DatasetStructure = dataset_structure
        # Make sure that any iteration on shard_iterator advances instead of
        # starting again.
        self._shard_iterator: Iterator[ShardInfo] = iter(shard_iterator)
        self._process_record: Callable[[ExampleT], T] | None = process_record
        self._batch_size: int = batch_size
        self._file_parallelism: int = file_parallelism

        # Which attributes have fixed shapes and which do not.
        self._has_fixed_shape: tuple[bool, ...] = tuple(
            not attribute.has_variable_size()
            for attribute in dataset_structure.saved_data_description)

        # Only FlatBuffers are supported.
        if dataset_structure.shard_file_type != "fb":
            raise ValueError(
                "RustBatchedGenerator is implemented only for FlatBuffers.")

        # Check if the compression type is supported by Rust.
        supported_compressions = BatchedRustIter.supported_compressions()
        if dataset_structure.compression not in supported_compressions:
            # Bug fix: the second literal was missing the f prefix, so the
            # placeholder was printed verbatim instead of the actual list.
            raise ValueError(
                f"The compression {dataset_structure.compression} is not "
                f"among the supported compressions: {supported_compressions}")

        def to_dict(example: list[np.typing.NDArray[np.uint8]]) -> ExampleT:
            # Decode a raw batched example (one uint8 array per attribute)
            # into a name -> decoded-array dictionary.
            result: ExampleT = {}
            for np_bytes, attribute in zip(
                    example, dataset_structure.saved_data_description):
                result[attribute.name] = IterateShardFlatBuffer.decode_array(
                    np_bytes=np_bytes,
                    attribute=attribute,
                    batch_size=-1,
                )
            return result

        self._to_dict = to_dict

    def __enter__(self) -> Self:
        """Enter the context manager (takes care of freeing memory held by
        Rust).
        """
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Drop the rust data structure holding content of open files and
        future examples.
        """
        if self._iter is not None:
            self._iter.__exit__(exc_type, exc_value, exc_tb)
            # Bug fix: clear the handle. The Rust side just dropped this
            # iterator from STATIC_ITERATORS, so re-entering the stale handle
            # in _single_iter would panic on the missing static_index.
            self._iter = None

    def __call__(self) -> Iterable[ExampleT] | Iterable[T]:
        """Return an iterable.
        """
        while not self._stopped:
            yield from self._single_iter()

    def _single_iter(self) -> Iterable[ExampleT] | Iterable[T]:
        """Iterate over a single chunk of shards.
        """
        if self._iter is None:
            shard_paths: list[str] = [
                str(self._dataset_path / s.file_infos[0].file_path)
                for s in itertools.islice(
                    self._shard_iterator,
                    self._shard_chunk_size,
                )
            ]

            if not shard_paths:
                # No shards to iterate.
                self._stopped = True
                return

            self._iter = BatchedRustIter(
                files=shard_paths,
                threads=self._file_parallelism,
                compression=self._dataset_structure.compression,
                batch_size=self._batch_size,
                has_fixed_shape=self._has_fixed_shape,
            )
            # Manually calling __enter__ and __exit__ -- see class docstring.
            self._iter.__enter__()  # pylint: disable=unnecessary-dunder-call
        elif not self._iter.can_iterate:
            self._iter.__enter__()  # pylint: disable=unnecessary-dunder-call

        example_iterator = map(self._to_dict, iter(self._iter))
        if self._process_record:
            yield from map(self._process_record, example_iterator)
        else:
            yield from example_iterator

        self._iter.__exit__(None, None, None)
        self._iter = None
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
import pytest
import random
from typing import Iterator, Union
import uuid

import numpy as np

import sedpack
from sedpack.io import Dataset
from sedpack.io.iteration import RustBatchedGenerator
from sedpack.io.metadata import DatasetStructure, Metadata
from sedpack.io.shard_info_iterator import CachedShardInfoIterator
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT


@pytest.fixture(scope="module")
def dataset_and_values(tmpdir_factory) -> Iterator[tuple]:
    """Create a small FlatBuffer dataset once per module and yield it together
    with the raw values written into it."""
    data_points: int = 1_024
    dtype: str = "float32"

    # Values saved in the dataset.
    values = {
        "fixed": np.random.random((data_points, 138)).astype(dtype),
        "fixed_2d": np.random.random((data_points, 3, 5)).astype(dtype),
        # TODO(reintroduce) when https://github.com/google/sedpack/pull/227 is
        # merged
        #"dynamic_shape": [
        #    uuid.uuid4().hex[:random.randint(15, 25)]
        #    for _ in range(data_points)
        #],
    }
    tmpdir = tmpdir_factory.mktemp("end_2_end_data")

    tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"

    # Create a dataset
    dataset_metadata = Metadata(description="Test of the lib")

    example_attributes = [
        sedpack.io.metadata.Attribute(
            name="fixed",
            dtype=str(dtype),
            shape=values["fixed"][0].shape,
        ),
        sedpack.io.metadata.Attribute(
            name="fixed_2d",
            dtype=str(dtype),
            shape=values["fixed_2d"][0].shape,
        ),
        #sedpack.io.metadata.Attribute(
        #    name="dynamic_shape",
        #    dtype="str",
        #    shape=(),
        #),
    ]

    dataset_structure = sedpack.io.metadata.DatasetStructure(
        saved_data_description=example_attributes,
        compression="LZ4",
        examples_per_shard=24,
        shard_file_type="fb",
    )

    # Test attribute_by_name
    for attribute in example_attributes:
        assert dataset_structure.attribute_by_name(
            attribute_name=attribute.name) == attribute

    dataset = Dataset.create(
        path=tiny_experiment_path,
        metadata=dataset_metadata,
        dataset_structure=dataset_structure,
    )

    # Fill data in the dataset
    with dataset.filler() as filler:
        for i in range(data_points):
            filler.write_example(
                values={name: val[i] for name, val in values.items()},
                split=TRAIN_SPLIT,
            )

    # Check the data is correct: reopen the dataset.
    dataset = Dataset(tiny_experiment_path)
    dataset.check()

    yield (dataset, values)

    # Teardown


def test_wrong_file_parallelism() -> None:
    """A non-positive file_parallelism is rejected."""
    with pytest.raises(
            ValueError,
            match="The argument file_parallelism should be positive.*",
    ):
        RustBatchedGenerator(
            dataset_path=Path(),
            dataset_structure=DatasetStructure(),
            shard_iterator=[],
            process_record=None,
            file_parallelism=0,
            batch_size=1,
        )


def test_wrong_shard_type() -> None:
    """Only FlatBuffer shard files are supported."""
    with pytest.raises(
            ValueError,
            match="RustBatchedGenerator is implemented only for FlatBuffers.",
    ):
        RustBatchedGenerator(
            dataset_path=Path(),
            dataset_structure=DatasetStructure(shard_file_type="tfrec"),
            shard_iterator=[],
            process_record=None,
            file_parallelism=1,
            batch_size=1,
        )


def test_wrong_compression() -> None:
    """Compressions unknown to the Rust implementation are rejected."""
    with pytest.raises(
            ValueError,
            match="The compression .* is not among the supported "
            "compressions: .*",
    ):
        RustBatchedGenerator(
            dataset_path=Path(),
            dataset_structure=DatasetStructure(
                shard_file_type="fb",
                compression="ZIP",
            ),
            shard_iterator=[],
            process_record=None,
            file_parallelism=1,
            batch_size=1,
        )


@pytest.mark.parametrize("batch_size", [1, 2, 7])
def test_end_to_end_rust_batched(
    batch_size,
    dataset_and_values,
):
    """Iterate the whole train split in batches and compare every attribute
    value against what was written."""
    dataset, values = dataset_and_values

    with RustBatchedGenerator(
            dataset_path=dataset.path,
            dataset_structure=dataset.dataset_structure,
            shard_iterator=CachedShardInfoIterator(
                dataset_path=dataset.path,
                dataset_info=dataset.dataset_info,
                split="train",
                repeat=False,
                shards=None,
                custom_metadata_type_limit=None,
                shard_filter=None,
                shuffle=0,
            ),
            batch_size=batch_size,
            process_record=None,
            file_parallelism=8,
    ) as g:
        index: int = 0
        for batch in g():
            current_batch_size: int = -1

            for name, attribute_values in batch.items():
                # All attributes in one batch must agree on the batch size.
                if current_batch_size < 0:
                    current_batch_size = len(attribute_values)
                else:
                    assert len(attribute_values) == current_batch_size

                for i in range(current_batch_size):
                    if name == "dynamic_shape":
                        assert values[name][index + i] == attribute_values[i]
                    else:
                        assert (values[name][index +
                                             i] == attribute_values[i]).all()

            index += current_batch_size