Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion rust/benches/my_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use criterion::{criterion_group, criterion_main, Criterion};
use glob::glob;
use sedpack_rs::batch_iteration::BatchIterator;
use sedpack_rs::example_iteration::{
get_shard_progress, CompressionType, ExampleIterator, ShardInfo,
};
Expand All @@ -31,6 +32,17 @@ pub fn get_shard_files() -> Vec<ShardInfo> {
shard_infos
}

/// Benchmark draining a full `BatchIterator` (12 threads, batch size 32,
/// two fixed-shape attributes) over all shard files found on disk.
pub fn batch_iterator_benchmark(c: &mut Criterion) {
    let shard_infos = get_shard_files();
    c.bench_function("BatchIterator", |b| {
        b.iter(|| {
            // The clone and iterator construction are part of the timed section.
            let batches = BatchIterator::new(shard_infos.clone(), 12, 32, vec![true, true]);
            for batch in batches {
                // Prevent the optimizer from discarding the consumed batch.
                std::hint::black_box(&batch);
            }
        })
    });
}

pub fn example_iterator_benchmark(c: &mut Criterion) {
let shard_infos = get_shard_files();
c.bench_function("ExampleIterator", |b| {
Expand All @@ -55,5 +67,10 @@ pub fn parallel_map_benchmark(c: &mut Criterion) {
});
}

criterion_group!(benches, example_iterator_benchmark, parallel_map_benchmark,);
criterion_group!(
benches,
batch_iterator_benchmark,
example_iterator_benchmark,
parallel_map_benchmark,
);
criterion_main!(benches);
101 changes: 101 additions & 0 deletions rust/src/batch_iteration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub use super::example_iteration::{
get_shard_progress, CompressionType, Example, ExampleIterator, ShardInfo, ShardProgress,
};
pub use super::parallel_map::parallel_map;
pub use super::shard_generated::sedpack::io::flatbuffer::shardfile::{root_as_shard, Shard};

/// Single attribute which has been batched.
pub enum BatchedAttribute {
    /// Row-major order batch of the attribute with static (fixed) size. That is in NumPy C-order
    /// we can index as data[batch_index][attribute_index] where batch_index in 0..batch_size and
    /// attribute_index in 0..len(attribute).
    Static { data: numpy::ndarray::Array<u8, numpy::Ix1> },
    /// Dynamic data where we do not know the shape up front (e.g., string, bytearray) is
    /// represented as a vector of per-example arrays with the same indexing semantic:
    /// data[batch_index][attribute_index].
    Dynamic { data: Vec<numpy::ndarray::Array<u8, numpy::Ix1>> },
}

/// A batch of examples: one `BatchedAttribute` per dataset attribute.
pub type Batch = Vec<BatchedAttribute>;

/// Groups examples pulled from `example_iterator` into batches of up to
/// `batch_size` examples. Private: constructed only by `BatchIterator::new`.
struct Batcher {
    // Source of examples; boxed so any `Send` iterator can back it.
    example_iterator: Box<dyn Iterator<Item = Example> + Send>,
    // Maximum number of examples per batch (the final batch may be smaller).
    batch_size: usize,
    // One flag per attribute: `true` batches into a single contiguous array
    // (`BatchedAttribute::Static`), `false` keeps one array per example
    // (`BatchedAttribute::Dynamic`).
    has_fixed_shape: Vec<bool>,
}

impl Iterator for Batcher {
    type Item = Batch;

    /// Collect up to `batch_size` examples and batch them attribute by
    /// attribute. Returns `None` once the underlying iterator is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // Collect examples. The last batch might contain fewer than
        // `batch_size` examples.
        let cache: Vec<Example> = self.example_iterator.by_ref().take(self.batch_size).collect();

        // Nothing left to batch.
        if cache.is_empty() {
            return None;
        }

        // Batch the examples; exactly one entry per attribute.
        let mut result = Batch::with_capacity(self.has_fixed_shape.len());
        for (attribute_index, &is_fixed) in self.has_fixed_shape.iter().enumerate() {
            // Collect the batched version of the current attribute across all
            // cached examples. `u8` is `Copy`, hence `.copied()`.
            let current_batched_attribute = if is_fixed {
                // Fixed shape: concatenate into a single row-major (C-order) array.
                BatchedAttribute::Static {
                    data: numpy::ndarray::Array::<u8, numpy::Ix1>::from_iter(
                        cache.iter().flat_map(|e| e[attribute_index].iter().copied()),
                    ),
                }
            } else {
                // Dynamic shape: keep one array per example.
                BatchedAttribute::Dynamic {
                    data: cache
                        .iter()
                        .map(|e| {
                            numpy::ndarray::Array::<u8, numpy::Ix1>::from_iter(
                                e[attribute_index].iter().copied(),
                            )
                        })
                        .collect(),
                }
            };

            // Save the batched attribute.
            result.push(current_batched_attribute);
        }
        Some(result)
    }
}

/// Iterator yielding batches of examples read from shard files.
pub struct BatchIterator {
    // Boxed to hide the concrete `Batcher` type from the public API.
    batch_iterator: Box<dyn Iterator<Item = Batch> + Send>,
}

impl BatchIterator {
    /// Create a new `BatchIterator`.
    ///
    /// * `files` — shard files to iterate (with their compression type).
    /// * `threads` — number of threads passed to `ExampleIterator`.
    /// * `batch_size` — examples per batch (the final batch may be smaller).
    /// * `has_fixed_shape` — one flag per attribute; `true` selects
    ///   `BatchedAttribute::Static` batching, `false` selects
    ///   `BatchedAttribute::Dynamic`.
    pub fn new(
        files: Vec<ShardInfo>, threads: usize, batch_size: usize, has_fixed_shape: Vec<bool>,
    ) -> Self {
        let example_iterator = Box::new(ExampleIterator::new(files, threads));
        let batch_iterator = Box::new(Batcher { example_iterator, batch_size, has_fixed_shape });
        BatchIterator { batch_iterator }
    }
}

impl Iterator for BatchIterator {
    type Item = Batch;

    // Plain delegation to the inner `Batcher`.
    fn next(&mut self) -> Option<Self::Item> {
        self.batch_iterator.next()
    }
}
11 changes: 8 additions & 3 deletions rust/src/example_iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ impl ExampleIterator {
/// `files: impl Iterator<Item = &str>`.
pub fn new(files: Vec<ShardInfo>, threads: usize) -> Self {
    // Each closure invocation parses a whole shard and materializes it into a
    // Vec<Example> (the `collect` drives the ShardProgress iterator to the
    // end), presumably so the parsing work happens inside parallel_map's
    // workers — TODO confirm against parallel_map's threading model.
    // `flatten` then yields the examples one by one.
    let example_iterator = Box::new(
        parallel_map(
            |x| get_shard_progress(&x).collect::<Vec<Example>>(),
            files.into_iter(),
            threads,
        )
        .flatten(),
    );
    ExampleIterator { example_iterator }
}
Expand Down Expand Up @@ -142,7 +147,7 @@ fn read_to_end(mut reader: impl std::io::Read) -> Vec<u8> {
}

/// Get ShardProgress.
pub fn get_shard_progress(shard_info: &ShardInfo) -> Vec<Example> {
pub fn get_shard_progress(shard_info: &ShardInfo) -> ShardProgress {
let file_bytes = get_file_bytes(shard_info);

// A shard is a vector of examples (positive number -- invariant kept by Python code).
Expand All @@ -156,7 +161,7 @@ pub fn get_shard_progress(shard_info: &ShardInfo) -> Vec<Example> {
// Number of examples might be different in different shards.
let total_examples = shard.get().examples().unwrap().len();

ShardProgress { total_examples, used_examples: 0, shard }.collect()
ShardProgress { total_examples, used_examples: 0, shard }
}

/// Get single example out of a ShardProgress.
Expand Down
146 changes: 145 additions & 1 deletion rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Google LLC
// Copyright 2024-2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@ pub use shard_generated::sedpack::io::flatbuffer::shardfile::{
root_as_shard, root_as_shard_unchecked, Attribute, Example, Shard,
};

pub mod batch_iteration;
pub mod example_iteration;
pub mod parallel_map;
// Import the autogenerated code for parsing a shard represented as a FlatBuffer.
Expand Down Expand Up @@ -150,9 +151,152 @@ mod static_iter {
}
}

/// Python wrappers around `batch_iteration`.
mod static_batched_iter {
    use std::collections::HashMap;
    use std::str::FromStr;

    use numpy::IntoPyArray;
    use pyo3::prelude::*;
    use pyo3::{pyclass, pymethods, PyRefMut};

    use super::batch_iteration::{BatchIterator, BatchedAttribute};
    use super::example_iteration::{CompressionType, ShardInfo};

    /// Implementation details: The goal is to own the BatchIterator in Rust and only send
    /// examples to Python. This helps with concurrent reading and parsing of shard files.
    /// Moreover Python code cannot compromise integrity of the data structures.
    ///
    /// - We need support for multiple BatchIterator's at the same time since during training the
    ///   train and validation split are being read in an interleaved manner. To support this each
    ///   BatchedRustIter instance keeps a `static_index` determining which `BatchIterator` it is
    ///   using (dispatch done using a HashMap).
    /// - Since a `HashMap` cannot be instantiated in a static we use a
    ///   `LazyLock<Mutex<HashMap>>`.
    /// - Using a mutex to avoid the need to use unsafe for a static mutable variable. The overhead
    ///   should be negligible since only a single thread is expected to access this.
    /// - Python does not guarantee that __del__ is called right away (or at all). Thus
    ///   BatchedRustIter also implements a context manager which is guaranteed to call __exit__
    ///   and drop memory owned by the corresponding BatchIterator.
    static STATIC_ITERATORS: std::sync::LazyLock<std::sync::Mutex<HashMap<i32, BatchIterator>>> =
        std::sync::LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));

    #[pyclass]
    pub struct BatchedRustIter {
        /// Which BatchIterator are we interacting with (unique id). Experimental API; expect
        /// breaking changes.
        static_index: i32,
        /// Read only value. For iteration we use this object as a context manager which allows us
        /// to free resources in STATIC_ITERATORS on the call of __exit__.
        ///
        /// Alternatives considered:
        /// - __del__ is not yet supported by pyo3 and also not guaranteed to be called by Python.
        #[pyo3(get)]
        can_iterate: bool,
    }

    impl Iterator for BatchedRustIter {
        type Item = <BatchIterator as Iterator>::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Iteration is gated on the context manager (__enter__ sets
            // can_iterate) so that memory deallocation in __exit__ is
            // guaranteed to run.
            // TODO move println to logging.
            if !self.can_iterate {
                println!(
                    "Use the context manager to enable iteration and guaranteed memory \
                     deallocation"
                );
                return None;
            }
            // Dispatch to the BatchIterator owned by this instance and pull one batch.
            let mut hash_map = STATIC_ITERATORS.lock().unwrap();
            let iter = hash_map
                .get_mut(&self.static_index)
                .expect("The static_index was not found among the STATIC_ITERATORS.");
            iter.next()
        }
    }

    #[pymethods]
    impl BatchedRustIter {
        /// Register a new `BatchIterator` under a fresh random id and return the Python-side
        /// handle. Iteration stays disabled until `__enter__` is called.
        #[new]
        fn new(
            files: Vec<String>, threads: usize, compression: String, batch_size: usize,
            has_fixed_shape: Vec<bool>,
        ) -> Self {
            // NOTE(review): rand::random() does not guarantee uniqueness; a
            // collision between two live iterators would silently overwrite an
            // entry in STATIC_ITERATORS below — confirm this risk is accepted.
            let static_index = rand::random();
            let mut hash_map = STATIC_ITERATORS.lock().unwrap();
            // Panics on an unsupported compression name (see supported_compressions).
            let compression_type = CompressionType::from_str(&compression).unwrap();
            let shard_infos = files
                .into_iter()
                .map(|file_path| ShardInfo { file_path, compression_type })
                .collect();
            hash_map.insert(
                static_index,
                BatchIterator::new(shard_infos, threads, batch_size, has_fixed_shape),
            );

            BatchedRustIter { static_index, can_iterate: false }
        }

        /// Names of the supported shard compression types.
        #[staticmethod]
        fn supported_compressions() -> Vec<String> {
            CompressionType::supported_compressions()
        }

        fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
            slf
        }

        /// Yields another batch of examples. Attributes are batched in the following way:
        ///
        /// - static (fixed shape) is row-major order (C-order) numpy array batch of attribute
        ///   values.
        ///
        /// - dynamic (unknown shape, e.g., str, bytes) list of numpy arrays.
        fn __next__<'py>(mut slf: PyRefMut<'py, Self>) -> Option<Bound<'py, pyo3::types::PyList>> {
            match slf.next() {
                None => None,
                Some(result) => {
                    // Convert each batched attribute to a Python object: a
                    // numpy array for Static, a list of numpy arrays for
                    // Dynamic.
                    let elements: Vec<Bound<'py, PyAny>> = result
                        .into_iter()
                        .map(|batched_attribute| match batched_attribute {
                            BatchedAttribute::Static { data } => {
                                data.into_pyarray(slf.py()).into_any()
                            }
                            BatchedAttribute::Dynamic { data } => pyo3::types::PyList::new(
                                slf.py(),
                                data.into_iter().map(|e| e.into_pyarray(slf.py())),
                            )
                            .unwrap()
                            .into_any(),
                        })
                        .collect();
                    Some(pyo3::types::PyList::new(slf.py(), elements).unwrap())
                }
            }
        }

        /// The implementation is reentrant. If changing also change
        /// `sedpack.io.dataset_iteration.RustGenerator`.
        fn __enter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
            slf.can_iterate = true;
            slf
        }

        /// Disable iteration and free the Rust-side `BatchIterator` owned by this instance.
        fn __exit__(
            mut slf: PyRefMut<'_, Self>, _exc_type: &Bound<'_, PyAny>, _exc_val: &Bound<'_, PyAny>,
            _exc_tb: &Bound<'_, PyAny>,
        ) {
            slf.can_iterate = false;
            // Drop from STATIC_ITERATORS.
            let mut hash_map = STATIC_ITERATORS.lock().unwrap();
            drop(hash_map.remove(&slf.static_index));
        }
    }
}

/// A Python module implemented in Rust.
#[pymodule]
fn _sedpack_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Example-level iteration.
    m.add_class::<static_iter::RustIter>()?;
    // Batch-level iteration.
    m.add_class::<static_batched_iter::BatchedRustIter>()?;
    Ok(())
}
2 changes: 2 additions & 0 deletions src/sedpack/io/iteration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
"""Dataset iteration."""

from sedpack.io.iteration.rust_batched_generator import RustBatchedGenerator
from sedpack.io.iteration.rust_generator import RustGenerator

# Public API of this package: the generators re-exported above.
__all__ = [
    "RustBatchedGenerator",
    "RustGenerator",
]
Loading
Loading