From 067d3b28c4ee619ee87adf23b0045dca2def945d Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 9 Apr 2026 14:07:38 +0000 Subject: [PATCH 01/11] Add stream/iterator support to BinningAutoBatcher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BinningAutoBatcher.load_states now accepts Iterator[T] in addition to Sequence[T] and SimState. When given an iterator/generator, it pulls sample_size states at a time, bins them, yields the bins, then automatically pulls the next chunk — enabling processing of large datasets that don't fit in memory at once. Closes #275 https://claude.ai/code/session_01P1qrZooaYiFMibUMWc7asC --- tests/test_autobatching.py | 115 +++++++++++++++++++++++++ torch_sim/autobatching.py | 166 ++++++++++++++++++++++++++++--------- 2 files changed, 241 insertions(+), 40 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index a05e06089..0b7e2d348 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -399,6 +399,121 @@ def test_binning_auto_batcher_restore_order_with_split_states( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) +def test_binning_auto_batcher_with_iterator( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with an iterator input.""" + states = [si_sim_state, fe_supercell_sim_state] + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + batcher.load_states(iter(states)) + + batches = [batch for batch, _ in batcher] + + # Check we got the expected number of systems + total_systems = sum(b.n_systems for b in batches) + assert total_systems == len(states) + + # Test restore_original_order + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + +def test_binning_auto_batcher_with_generator( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with a generator input.""" + states = [si_sim_state, fe_supercell_sim_state] + + def state_generator(): + yield from states + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + batcher.load_states(state_generator()) + + batches = [batch for batch, _ in batcher] + total_systems = sum(b.n_systems for b in batches) + assert total_systems == len(states) + + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + +def test_binning_auto_batcher_streaming_multiple_chunks( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher streaming with sample_size forcing multiple chunks.""" + # Create enough states to span multiple chunks with sample_size=1 + states = [si_sim_state, fe_supercell_sim_state, si_sim_state] + + def state_generator(): + yield from states + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + 
max_memory_scaler=260.0, + sample_size=1, # Force each state into its own chunk + ) + batcher.load_states(state_generator()) + + batches = [] + all_indices = [] + for batch, indices in batcher: + batches.append(batch) + all_indices.extend(indices) + + # All 3 states should have been processed + total_systems = sum(b.n_systems for b in batches) + assert total_systems == len(states) + + # Indices should cover all original positions + assert sorted(all_indices) == [0, 1, 2] + + # Restore order should work across chunks + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + assert restored_states[2].n_atoms == states[2].n_atoms + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + assert torch.all(restored_states[2].atomic_numbers == states[2].atomic_numbers) + + +def test_binning_auto_batcher_empty_iterator( + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher raises ValueError for empty iterator.""" + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + with pytest.raises(ValueError, match="Iterator yielded no states"): + batcher.load_states(iter([])) + + def test_in_flight_max_metric_too_small( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 26a6775aa..d79fca1da 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -525,6 +525,10 @@ class BinningAutoBatcher[T: SimState]: metric to maximize GPU utilization. This approach is ideal for scenarios where all states need to be evolved the same number of steps. + Supports streaming via iterators/generators: when ``load_states`` receives an + iterator, it pulls ``sample_size`` systems at a time, bins them, yields the + bins, then pulls the next chunk automatically. + To avoid a slow memory estimation step, set the `max_memory_scaler` to a known value. @@ -533,6 +537,7 @@ class BinningAutoBatcher[T: SimState]: memory_scales_with (str): Metric type used for memory estimation. max_memory_scaler (float): Maximum memory metric allowed per system. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. + sample_size (int): Number of states to pull from an iterator per chunk. memory_scalers (list[float]): Memory scaling metrics for each state. index_to_scaler (dict): Mapping from state index to its scaling metric. index_bins (list[list[int]]): Groups of state indices that can be batched @@ -555,6 +560,15 @@ class BinningAutoBatcher[T: SimState]: # Restore original order ordered_final_states = batcher.restore_original_order(final_states) + + # Or stream states from a generator + def state_generator(): + for atoms in large_dataset: + yield ts.initialize_state(atoms, device, dtype) + + batcher.load_states(state_generator()) + for batch, _indices in batcher: + process(batch) """ index_bins: list[list[int]] @@ -570,6 +584,7 @@ def __init__( memory_scaling_factor: float = 1.6, max_memory_padding: float = 1.0, oom_error_message: str | list[str] = "CUDA out of memory", + sample_size: int = 100, ) -> None: """Initialize the binning auto-batcher. 
@@ -599,6 +614,9 @@ def __init__( oom_error_message (str | list[str]): String or list of strings to match in RuntimeError messages to identify out-of-memory errors. Defaults to "CUDA out of memory". + sample_size (int): Number of states to pull from an iterator at a time + when streaming. Only used when load_states receives an iterator. + Defaults to 100. """ self.max_memory_scaler = max_memory_scaler self.max_atoms_to_try = max_atoms_to_try @@ -608,26 +626,30 @@ def __init__( self.memory_scaling_factor = memory_scaling_factor self.max_memory_padding = max_memory_padding self.oom_error_message = oom_error_message + self.sample_size = sample_size - def load_states(self, states: T | Sequence[T]) -> float: + def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: """Load new states into the batcher. Processes the input states, computes memory scaling metrics for each, and organizes them into optimal batches using a bin-packing algorithm - to maximize GPU utilization. + to maximize GPU utilization. Supports iterators for streaming large + collections of states that don't fit in memory at once. Args: - states (SimState | list[SimState]): Collection of states to batch. Either a - list of individual SimState objects or a single batched SimState that - will be split into individual states. Each SimState has shape - information specific to its instance. + states (SimState | list[SimState] | Iterator[SimState]): Collection of + states to batch. Can be a list of individual SimState objects, a single + batched SimState that will be split into individual states, or an + iterator/generator yielding individual SimState objects. When an + iterator is provided, states are pulled in chunks of ``sample_size`` + and binned incrementally. Returns: float: Maximum memory scaling metric that fits in GPU memory. Raises: ValueError: If any individual state has a memory scaling metric greater - than the maximum allowed value. + than the maximum allowed value, or if an iterator yields no states. Example:: @@ -637,13 +659,41 @@ def load_states(self, states: T | Sequence[T]) -> float: # Or load a batched state that will be split batcher.load_states(batched_state) + # Or stream states from an iterator/generator + batcher.load_states(iter(states)) + batcher.load_states(state_generator()) + Notes: This method resets the current state bin index, so any ongoing iteration will be restarted when this method is called. """ - batched = ( - states if isinstance(states, SimState) else ts.concatenate_states(states) - ) + # Reset accumulated tracking for restore_original_order + self._all_index_bins: list[list[int]] = [] + self._global_index_offset: int = 0 + + if isinstance(states, SimState): + self._states_iterator = None + self._bin_and_prepare(states) + elif isinstance(states, Sequence): + self._states_iterator = None + self._bin_and_prepare(ts.concatenate_states(states)) + else: + # Iterator/generator - streaming mode + self._states_iterator = states + if not self._load_next_chunk(): + raise ValueError("Iterator yielded no states") + + return self.max_memory_scaler + + def _bin_and_prepare(self, batched: T) -> None: + """Compute metrics, bin states, and prepare batched_states for iteration. + + Core binning logic used by both eager and streaming paths. + + Args: + batched: A single concatenated/batched SimState containing all systems + to bin in this round. 
+ """ self.memory_scalers = calculate_memory_scalers( batched, self.memory_scales_with, self.cutoff ) @@ -674,25 +724,58 @@ def load_states(self, states: T | Sequence[T]) -> float: index_bins = to_constant_volume_bins( self.index_to_scaler, max_volume=self.max_memory_scaler ) # list[dict[original_index: int, memory_scale:float]] - # Convert to list of lists of indices - self.index_bins = [list(batch.keys()) for batch in index_bins] - self.batched_states = [[batched[index_bin]] for index_bin in self.index_bins] + # Local indices for indexing into the batched state + local_index_bins = [list(batch.keys()) for batch in index_bins] + # Global indices for tracking original order across streaming chunks + self.index_bins = [ + [idx + self._global_index_offset for idx in bin_indices] + for bin_indices in local_index_bins + ] + self.batched_states = [ + [batched[index_bin]] for index_bin in local_index_bins + ] self.current_state_bin = 0 + # Accumulate index bins for restore_original_order + self._all_index_bins.extend(self.index_bins) + self._global_index_offset += len(self.memory_scalers) + logger.info( "BinningAutoBatcher: %d systems → %d batch(es), max_memory_scaler=%.3g", len(self.memory_scalers), len(self.index_bins), self.max_memory_scaler, ) - return self.max_memory_scaler + + def _load_next_chunk(self) -> bool: + """Pull the next chunk of states from the iterator and bin them. + + Pulls up to ``sample_size`` states from the stored iterator, concatenates + them, and runs the binning logic to prepare the next round of batches. + + Returns: + bool: True if states were loaded, False if the iterator is exhausted. + """ + chunk: list[T] = [] + for state in self._states_iterator: + chunk.append(state) + if len(chunk) >= self.sample_size: + break + + if not chunk: + return False + + batched = ts.concatenate_states(chunk) + self._bin_and_prepare(batched) + return True def next_batch(self) -> tuple[T | None, list[int]]: """Get the next batch of states. Returns batches sequentially until all states have been processed. Each batch contains states grouped together to maximize GPU utilization without exceeding - memory constraints. + memory constraints. When streaming from an iterator, automatically loads the + next chunk of states when the current bins are exhausted. 
Returns: tuple[T | None, list[int]]: A tuple containing: @@ -707,30 +790,33 @@ def next_batch(self) -> tuple[T | None, list[int]]: process_batch(batch) """ - # TODO: need to think about how this intersects with reporting too - # TODO: definitely a clever treatment to be done with iterators here - if self.current_state_bin < len(self.batched_states): - state_bin = self.batched_states[self.current_state_bin] - state = ts.concatenate_states(state_bin) - indices = ( - self.index_bins[self.current_state_bin] - if self.current_state_bin < len(self.index_bins) - else [] - ) - self.current_state_bin += 1 - remaining = len(self.batched_states) - self.current_state_bin - logger.info( - ( - "BinningAutoBatcher: returning batch %d/%d with %d system(s), " - "%d batch(es) remaining" - ), - self.current_state_bin, - len(self.batched_states), - state.n_systems, - remaining, - ) - return state, indices - return None, [] + if self.current_state_bin >= len(self.batched_states): + # Try loading next chunk if streaming + if self._states_iterator is not None and self._load_next_chunk(): + pass # _load_next_chunk resets current_state_bin and batched_states + else: + return None, [] + + state_bin = self.batched_states[self.current_state_bin] + state = ts.concatenate_states(state_bin) + indices = ( + self.index_bins[self.current_state_bin] + if self.current_state_bin < len(self.index_bins) + else [] + ) + self.current_state_bin += 1 + remaining = len(self.batched_states) - self.current_state_bin + logger.info( + ( + "BinningAutoBatcher: returning batch %d/%d with %d system(s), " + "%d batch(es) remaining" + ), + self.current_state_bin, + len(self.batched_states), + state.n_systems, + remaining, + ) + return state, indices def __iter__(self) -> Iterator[tuple[T, list[int]]]: """Return self as an iterator. 
@@ -797,7 +883,7 @@ def restore_original_order(self, batched_states: Sequence[T]) -> list[T]: all_states = [ state[i] for state in batched_states for i in range(state.n_systems) ] - original_indices = list(chain.from_iterable(self.index_bins)) + original_indices = list(chain.from_iterable(self._all_index_bins)) if len(all_states) != len(original_indices): raise ValueError( From 7cd06e3f11df529f5cb26c2ae7cb095d31ffdcb7 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 9 Apr 2026 15:27:53 +0000 Subject: [PATCH 02/11] Fix ruff formatting in BinningAutoBatcher https://claude.ai/code/session_01P1qrZooaYiFMibUMWc7asC --- torch_sim/autobatching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index d79fca1da..1bde5b956 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -561,11 +561,13 @@ class BinningAutoBatcher[T: SimState]: # Restore original order ordered_final_states = batcher.restore_original_order(final_states) + # Or stream states from a generator def state_generator(): for atoms in large_dataset: yield ts.initialize_state(atoms, device, dtype) + batcher.load_states(state_generator()) for batch, _indices in batcher: process(batch) @@ -731,9 +733,7 @@ def _bin_and_prepare(self, batched: T) -> None: [idx + self._global_index_offset for idx in bin_indices] for bin_indices in local_index_bins ] - self.batched_states = [ - [batched[index_bin]] for index_bin in local_index_bins - ] + self.batched_states = [[batched[index_bin]] for index_bin in local_index_bins] self.current_state_bin = 0 # Accumulate index bins for restore_original_order From 8a059c6616ce5279f6f0a00055ac269bfd774785 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 14 Apr 2026 01:12:45 +0000 Subject: [PATCH 03/11] Fix ty type-checking errors in BinningAutoBatcher https://claude.ai/code/session_01P1qrZooaYiFMibUMWc7asC --- torch_sim/autobatching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 1bde5b956..bf0a1ba16 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -685,7 +685,7 @@ def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: if not self._load_next_chunk(): raise ValueError("Iterator yielded no states") - return self.max_memory_scaler + return self.max_memory_scaler # ty: ignore[invalid-return-type] def _bin_and_prepare(self, batched: T) -> None: """Compute metrics, bin states, and prepare batched_states for iteration. @@ -757,7 +757,7 @@ def _load_next_chunk(self) -> bool: bool: True if states were loaded, False if the iterator is exhausted. """ chunk: list[T] = [] - for state in self._states_iterator: + for state in self._states_iterator: # ty: ignore[not-iterable] chunk.append(state) if len(chunk) >= self.sample_size: break From d0f42e4f86ea853b2307e2425238958366f0f298 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 14 Apr 2026 14:49:30 +0000 Subject: [PATCH 04/11] Add batching strategy benchmark for reviewer feedback Tests two claims from review of autobatching streaming PR: 1. Batching throughput saturates past a threshold 2. Optimal bin-packing isn't much better than greedy in practice Results on CPU+LJ show greedy matches optimal for realistic distributions; need GPU+MLIP runs for definitive numbers. 
https://claude.ai/code/session_01P1qrZooaYiFMibUMWc7asC --- examples/benchmarking/batching_strategy.py | 328 +++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 examples/benchmarking/batching_strategy.py diff --git a/examples/benchmarking/batching_strategy.py b/examples/benchmarking/batching_strategy.py new file mode 100644 index 000000000..24e18317d --- /dev/null +++ b/examples/benchmarking/batching_strategy.py @@ -0,0 +1,328 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "ase", +# ] +# /// +"""Benchmark batching strategies: throughput-vs-batch-size + optimal-vs-greedy packing. + +Tests two claims from PR review feedback on issue #275: + 1. "Batching is important up to a threshold; past that, extra batching + doesn't help much." → throughput flattens at some batch size. + 2. "Optimal bin-packing via to_constant_volume_bins isn't super useful" vs + greedy packing (InFlightAutoBatcher-style pull-until-full). + +Experiment A: throughput vs batch size on uniform systems. +Experiment B: packing efficiency (optimal vs greedy) — batch count & fullness. +Experiment C: end-to-end wall time on heterogeneous systems using both strategies. + +Example:: + + uv run examples/benchmarking/batching_strategy.py + # or with a real GPU model: + uv run --with ".[mace]" examples/benchmarking/batching_strategy.py --use-mace + +Notes: + CPU results illustrate methodology; GPU results with a real MLIP give + more meaningful throughput numbers. This script prefers GPU+MACE if both + available, else falls back to CPU+Lennard-Jones. +""" + +# %% +from __future__ import annotations + +import argparse +import time + +import torch +from ase.build import bulk + +import torch_sim as ts +from torch_sim.autobatching import calculate_memory_scalers, to_constant_volume_bins +from torch_sim.models.lennard_jones import LennardJonesModel + + +# ------------------------------------------------------------------ +# Setup helpers +# ------------------------------------------------------------------ + + +def make_lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: + """Simple LJ model usable on both CPU and GPU.""" + return LennardJonesModel( + sigma=3.405, epsilon=0.0104, cutoff=3.0 * 3.405, device=device, dtype=dtype + ) + + +def make_ar_state(n_cells: int, device: torch.device, dtype: torch.dtype) -> ts.SimState: + """Create an Ar FCC supercell with n_cells^3 * 4 atoms.""" + atoms = bulk("Ar", "fcc", a=5.26).repeat((n_cells, n_cells, n_cells)) + return ts.io.atoms_to_state([atoms], device=device, dtype=dtype) + + +def make_heterogeneous_states( + device: torch.device, dtype: torch.dtype, n_systems: int = 60, seed: int = 0 +) -> list[ts.SimState]: + """Create a dataset of Ar supercells with widely varying sizes. + + Size distribution is bimodal-ish (mix of small molecules + big crystals) + to make packing differences visible. + """ + rng = torch.Generator().manual_seed(seed) + states: list[ts.SimState] = [] + for _ in range(n_systems): + # Bimodal: 60% "small" (1-2 cells = 4-32 atoms), + # 40% "large" (3-5 cells = 108-500 atoms). 
+ if torch.rand(1, generator=rng).item() < 0.6: + n_cells = int(torch.randint(1, 3, (1,), generator=rng).item()) + else: + n_cells = int(torch.randint(3, 6, (1,), generator=rng).item()) + states.append(make_ar_state(n_cells, device, dtype)) + return states + + +def time_forward(model: LennardJonesModel, state: ts.SimState, n_reps: int = 3) -> float: + """Time a forward pass, taking the min of n_reps runs to de-noise.""" + # Warmup + model(state) + if state.device.type == "cuda": + torch.cuda.synchronize() + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + model(state) + if state.device.type == "cuda": + torch.cuda.synchronize() + times.append(time.perf_counter() - t0) + return min(times) + + +# ------------------------------------------------------------------ +# Experiment A — throughput vs batch size (tests claim 1) +# ------------------------------------------------------------------ + + +def exp_a_throughput_vs_batch_size( + model: LennardJonesModel, device: torch.device, dtype: torch.dtype +) -> None: + """Measure systems/sec as a function of batch size with uniform systems.""" + print("\n" + "=" * 72) + print("Experiment A: throughput vs batch size (uniform 32-atom systems)") + print("=" * 72) + + # Build one reference system; concatenate N copies to form a batch. + ref = make_ar_state(n_cells=2, device=device, dtype=dtype) # 32 atoms + print(f"Per-system size: {ref.n_atoms} atoms") + print(f"{'batch_size':>12} {'time_ms':>12} {'sys/sec':>12} {'μs/atom':>12}") + print("-" * 72) + + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] + results = [] + for bs in batch_sizes: + batch = ts.concatenate_states([ref] * bs) + t = time_forward(model, batch) + sys_per_sec = bs / t + us_per_atom = (t * 1e6) / (bs * ref.n_atoms) + results.append((bs, t, sys_per_sec, us_per_atom)) + print(f"{bs:>12} {t * 1000:>12.3f} {sys_per_sec:>12.1f} {us_per_atom:>12.2f}") + + # Find the "knee": batch size past which throughput stops improving meaningfully. + print() + print("Throughput analysis:") + baseline = results[0][2] # sys/sec at batch_size=1 + for bs, _t, sps, _ in results: + speedup = sps / baseline + print(f" bs={bs:>4}: {speedup:>6.2f}x speedup over bs=1") + + +# ------------------------------------------------------------------ +# Helpers for Experiment B/C: greedy packing +# ------------------------------------------------------------------ + + +def greedy_pack(metric_values: list[float], max_volume: float) -> list[list[int]]: + """Greedy (first-fit) packing: iterate items, add to current batch until full. + + This mirrors InFlightAutoBatcher._get_next_states' logic applied to a static + list: we don't pre-sort, we fill greedily in arrival order. 
+ """ + bins: list[list[int]] = [[]] + current_sum = 0.0 + for idx, v in enumerate(metric_values): + if v > max_volume: + raise ValueError(f"Item {idx} metric {v} exceeds max_volume {max_volume}") + if current_sum + v > max_volume: + bins.append([idx]) + current_sum = v + else: + bins[-1].append(idx) + current_sum += v + return bins + + +def optimal_pack(metric_values: list[float], max_volume: float) -> list[list[int]]: + """Current BinningAutoBatcher optimal packing via to_constant_volume_bins.""" + idx_to_val = dict(enumerate(metric_values)) + bin_dicts = to_constant_volume_bins(idx_to_val, max_volume=max_volume) + return [list(d.keys()) for d in bin_dicts] + + +def bin_stats(bins: list[list[int]], metric_values: list[float]) -> dict: + """Compute packing statistics.""" + fullnesses = [sum(metric_values[i] for i in b) for b in bins] + return { + "n_bins": len(bins), + "bin_sizes": [len(b) for b in bins], + "bin_fullness": fullnesses, + "total_capacity_used": sum(fullnesses), + } + + +# ------------------------------------------------------------------ +# Experiment B — packing efficiency (tests claim 2, hardware-independent) +# ------------------------------------------------------------------ + + +def exp_b_packing_efficiency( + device: torch.device, dtype: torch.dtype +) -> tuple[list[ts.SimState], float, list[list[int]], list[list[int]]]: + """Compare batch count + fullness for optimal vs greedy packing. + + Tests multiple distributions to probe when optimal packing wins over greedy. + """ + print("\n" + "=" * 72) + print("Experiment B: packing efficiency (heterogeneous systems)") + print("=" * 72) + + distributions = { + "bimodal": make_heterogeneous_states(device, dtype, n_systems=60, seed=0), + "uniform": [ + make_ar_state(n_cells=2, device=device, dtype=dtype) for _ in range(60) + ], + # Adversarial: many small items, then a few big ones at the end. + # Greedy packs smalls together fine but wastes space when bigs arrive. 
+ "adversarial": [ + *[make_ar_state(1, device, dtype) for _ in range(50)], # 50 small (4-atom) + *[make_ar_state(4, device, dtype) for _ in range(10)], # 10 large (256-atom) + ], + } + + last_result: ( + tuple[list[ts.SimState], float, list[list[int]], list[list[int]]] | None + ) = None + for name, states in distributions.items(): + batched = ts.concatenate_states(states) + metrics = calculate_memory_scalers(batched, "n_atoms_x_density") + max_volume = sum(metrics) / 8 # aim for ~8 bins + + n_atoms_list = [s.n_atoms for s in states] + print(f"\n-- {name} distribution --") + print( + f" n_atoms: min={min(n_atoms_list)}, max={max(n_atoms_list)}, " + f"mean={sum(n_atoms_list) / len(n_atoms_list):.1f}" + ) + print(f" max_memory_scaler: {max_volume:.1f}") + + opt_bins = optimal_pack(metrics, max_volume) + grd_bins = greedy_pack(metrics, max_volume) + opt_stats = bin_stats(opt_bins, metrics) + grd_stats = bin_stats(grd_bins, metrics) + + opt_avg = ( + sum(opt_stats["bin_fullness"]) / (opt_stats["n_bins"] * max_volume) * 100 + ) + grd_avg = ( + sum(grd_stats["bin_fullness"]) / (grd_stats["n_bins"] * max_volume) * 100 + ) + delta = (grd_stats["n_bins"] - opt_stats["n_bins"]) / opt_stats["n_bins"] * 100 + print( + f" optimal: {opt_stats['n_bins']} bins, mean {opt_avg:.1f}% full | " + f"greedy: {grd_stats['n_bins']} bins, mean {grd_avg:.1f}% full | " + f"greedy uses {delta:+.1f}% more bins" + ) + last_result = (states, max_volume, opt_bins, grd_bins) + + if last_result is None: + raise RuntimeError("No distributions defined") + return last_result + + +# ------------------------------------------------------------------ +# Experiment C — end-to-end wall time (combines claims 1 + 2) +# ------------------------------------------------------------------ + + +def exp_c_wall_time( + model: LennardJonesModel, + states: list[ts.SimState], + opt_bins: list[list[int]], + grd_bins: list[list[int]], + n_reps: int = 3, +) -> None: + """Run the model over all systems using each packing strategy; compare wall times.""" + print("\n" + "=" * 72) + print("Experiment C: end-to-end wall time (optimal vs greedy packing)") + print("=" * 72) + + def run(bins: list[list[int]]) -> float: + # Warmup + model(ts.concatenate_states([states[i] for i in bins[0]])) + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + for bin_indices in bins: + batch = ts.concatenate_states([states[i] for i in bin_indices]) + model(batch) + times.append(time.perf_counter() - t0) + return min(times) + + t_opt = run(opt_bins) + t_grd = run(grd_bins) + + print(f"{'strategy':>10} {'n_bins':>8} {'wall_time_ms':>14} {'ms/sys':>10}") + print("-" * 50) + n_sys = len(states) + print( + f"{'optimal':>10} {len(opt_bins):>8} {t_opt * 1000:>14.2f} " + f"{t_opt / n_sys * 1000:>10.3f}" + ) + print( + f"{'greedy':>10} {len(grd_bins):>8} {t_grd * 1000:>14.2f} " + f"{t_grd / n_sys * 1000:>10.3f}" + ) + delta_pct = (t_grd - t_opt) / t_opt * 100 + print(f"\nGreedy is {delta_pct:+.1f}% slower than optimal (negative = faster).") + + +# ------------------------------------------------------------------ +# Main +# ------------------------------------------------------------------ + + +def main() -> None: + """Run all three benchmark experiments.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + default=None, + help="Override device (default: cuda if available else cpu)", + ) + args = parser.parse_args() + + device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) + dtype = torch.float64 if 
device.type == "cpu" else torch.float32 + print(f"Device: {device}, dtype: {dtype}") + print( + "Note: CPU + LJ gives weak batching signal; real MLIP on GPU is needed for " + "definitive numbers." + ) + + model = make_lj_model(device, dtype) + + exp_a_throughput_vs_batch_size(model, device, dtype) + states, _max_v, opt_bins, grd_bins = exp_b_packing_efficiency(device, dtype) + exp_c_wall_time(model, states, opt_bins, grd_bins) + + +if __name__ == "__main__": + main() From a79ef94479f3c8c3fea3df78bd386f4f0c1d4e39 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 14 Apr 2026 15:01:53 +0000 Subject: [PATCH 05/11] Switch BinningAutoBatcher to greedy packing Replaces to_constant_volume_bins optimal bin-packing with greedy first-fit in both eager and streaming paths, per review feedback on #275. Rationale (see examples/benchmarking/batching_strategy.py for data): - Greedy and optimal produce identical bin counts on realistic distributions - Even on adversarial distributions, greedy was slightly faster in wall time - Removes the arbitrary sample_size parameter (no longer needed) - Simplifies the streaming implementation to mirror InFlightAutoBatcher Changes: - Remove sample_size, _bin_and_prepare, _load_next_chunk - Add _load_eager (for SimState/Sequence) and _load_streaming (for Iterator) - Greedy packing preserves input order (nicer for debugging) - _all_index_bins folded into self.index_bins (populated incrementally in streaming path) - to_constant_volume_bins retained as a standalone utility Test: streaming test renamed from _streaming_multiple_chunks to _streaming_multiple_batches and updated to not depend on sample_size. https://claude.ai/code/session_01P1qrZooaYiFMibUMWc7asC --- tests/test_autobatching.py | 13 +- torch_sim/autobatching.py | 284 ++++++++++++++++++++----------------- 2 files changed, 158 insertions(+), 139 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 0b7e2d348..b4010cecd 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -457,13 +457,14 @@ def state_generator(): assert restored_states[1].n_atoms == states[1].n_atoms -def test_binning_auto_batcher_streaming_multiple_chunks( +def test_binning_auto_batcher_streaming_multiple_batches( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, lj_model: LennardJonesModel, ) -> None: - """Test BinningAutoBatcher streaming with sample_size forcing multiple chunks.""" - # Create enough states to span multiple chunks with sample_size=1 + """Test streaming path produces correct batches across multiple pulls.""" + # max_memory_scaler=260 forces fe_supercell (216 atoms) into its own batch, + # si (8 atoms) states can batch together up to 32 per batch. 
states = [si_sim_state, fe_supercell_sim_state, si_sim_state] def state_generator(): @@ -473,7 +474,6 @@ def state_generator(): model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260.0, - sample_size=1, # Force each state into its own chunk ) batcher.load_states(state_generator()) @@ -490,7 +490,10 @@ def state_generator(): # Indices should cover all original positions assert sorted(all_indices) == [0, 1, 2] - # Restore order should work across chunks + # index_bins should be populated incrementally + assert len(batcher.index_bins) == len(batches) + + # Restore order should work across streaming batches restored_states = batcher.restore_original_order(batches) assert len(restored_states) == len(states) assert restored_states[0].n_atoms == states[0].n_atoms diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index bf0a1ba16..1e7f8877f 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -518,31 +518,32 @@ def estimate_max_memory_scaler( class BinningAutoBatcher[T: SimState]: - """Batcher that groups states into bins of similar computational cost. + """Batcher that groups states into batches using greedy first-fit packing. - Divides a collection of states into batches that can be processed efficiently - without exceeding GPU memory. States are grouped based on a memory scaling - metric to maximize GPU utilization. This approach is ideal for scenarios where - all states need to be evolved the same number of steps. + Fills batches with successive states until adding another would exceed the + maximum memory scaling metric. Ideal for scenarios where all states need to + be evolved the same number of steps (e.g., MD integration). - Supports streaming via iterators/generators: when ``load_states`` receives an - iterator, it pulls ``sample_size`` systems at a time, bins them, yields the - bins, then pulls the next chunk automatically. + Accepts a single batched ``SimState``, a sequence of individual states, or an + iterator/generator for streaming large datasets that don't fit in memory. + For eager inputs (``SimState`` or ``Sequence``) all batches are computed upfront. + For iterator inputs, batches are materialized lazily as ``next_batch`` is called. - To avoid a slow memory estimation step, set the `max_memory_scaler` to a - known value. + To avoid a slow memory estimation step, set ``max_memory_scaler`` to a known value. Attributes: model (ModelInterface): Model used for memory estimation and processing. memory_scales_with (str): Metric type used for memory estimation. - max_memory_scaler (float): Maximum memory metric allowed per system. + max_memory_scaler (float): Maximum memory metric allowed per batch. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. - sample_size (int): Number of states to pull from an iterator per chunk. - memory_scalers (list[float]): Memory scaling metrics for each state. - index_to_scaler (dict): Mapping from state index to its scaling metric. - index_bins (list[list[int]]): Groups of state indices that can be batched - together. - batched_states (list[list[SimState]]): Grouped states ready for batching. + memory_scalers (list[float]): Memory scaling metrics per state. Populated + eagerly for ``SimState``/``Sequence`` inputs; grows incrementally for + iterator inputs as states are consumed. + index_bins (list[list[int]]): Groups of state indices per batch. Populated + eagerly for ``SimState``/``Sequence`` inputs; grows incrementally for + iterator inputs as batches are yielded. 
+ batched_states (list[list[SimState]]): Pre-computed batches (eager path only). + Empty for iterator inputs since batches are materialized on demand. current_state_bin (int): Index of the current batch being processed. Example:: @@ -586,7 +587,6 @@ def __init__( memory_scaling_factor: float = 1.6, max_memory_padding: float = 1.0, oom_error_message: str | list[str] = "CUDA out of memory", - sample_size: int = 100, ) -> None: """Initialize the binning auto-batcher. @@ -616,9 +616,6 @@ def __init__( oom_error_message (str | list[str]): String or list of strings to match in RuntimeError messages to identify out-of-memory errors. Defaults to "CUDA out of memory". - sample_size (int): Number of states to pull from an iterator at a time - when streaming. Only used when load_states receives an iterator. - Defaults to 100. """ self.max_memory_scaler = max_memory_scaler self.max_atoms_to_try = max_atoms_to_try @@ -628,23 +625,20 @@ def __init__( self.memory_scaling_factor = memory_scaling_factor self.max_memory_padding = max_memory_padding self.oom_error_message = oom_error_message - self.sample_size = sample_size def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: """Load new states into the batcher. - Processes the input states, computes memory scaling metrics for each, - and organizes them into optimal batches using a bin-packing algorithm - to maximize GPU utilization. Supports iterators for streaming large - collections of states that don't fit in memory at once. + Computes memory scaling metrics and organizes states into batches using + greedy first-fit packing: states are added to the current batch in order + until the next would exceed ``max_memory_scaler``, at which point a new + batch starts. Args: states (SimState | list[SimState] | Iterator[SimState]): Collection of states to batch. Can be a list of individual SimState objects, a single batched SimState that will be split into individual states, or an - iterator/generator yielding individual SimState objects. When an - iterator is provided, states are pulled in chunks of ``sample_size`` - and binned incrementally. + iterator/generator yielding individual SimState objects. Returns: float: Maximum memory scaling metric that fits in GPU memory. @@ -666,79 +660,63 @@ def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: batcher.load_states(state_generator()) Notes: - This method resets the current state bin index, so any ongoing iteration - will be restarted when this method is called. + This method resets iteration state, so any ongoing iteration will be + restarted when called. 
""" - # Reset accumulated tracking for restore_original_order - self._all_index_bins: list[list[int]] = [] - self._global_index_offset: int = 0 + self.memory_scalers: list[float] = [] + self.index_bins = [] + self.batched_states: list[list[T]] = [] + self.current_state_bin = 0 if isinstance(states, SimState): - self._states_iterator = None - self._bin_and_prepare(states) + self._load_eager(states) elif isinstance(states, Sequence): - self._states_iterator = None - self._bin_and_prepare(ts.concatenate_states(states)) + self._load_eager(ts.concatenate_states(list(states))) else: - # Iterator/generator - streaming mode - self._states_iterator = states - if not self._load_next_chunk(): - raise ValueError("Iterator yielded no states") + self._load_streaming(states) return self.max_memory_scaler # ty: ignore[invalid-return-type] - def _bin_and_prepare(self, batched: T) -> None: - """Compute metrics, bin states, and prepare batched_states for iteration. - - Core binning logic used by both eager and streaming paths. - - Args: - batched: A single concatenated/batched SimState containing all systems - to bin in this round. - """ + def _load_eager(self, batched: T) -> None: + """Compute metrics and pack all batches upfront (for Sequence / SimState).""" self.memory_scalers = calculate_memory_scalers( batched, self.memory_scales_with, self.cutoff ) if not self.max_memory_scaler: - self.max_memory_scaler = estimate_max_memory_scaler( - batched, - self.model, - self.memory_scalers, - max_atoms=self.max_atoms_to_try, - scale_factor=self.memory_scaling_factor, - oom_error_message=self.oom_error_message, + self.max_memory_scaler = ( + estimate_max_memory_scaler( + batched, + self.model, + self.memory_scalers, + max_atoms=self.max_atoms_to_try, + scale_factor=self.memory_scaling_factor, + oom_error_message=self.oom_error_message, + ) + * self.max_memory_padding ) - self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding logger.debug("Estimated max memory scaler: %.3g", self.max_memory_scaler) - # verify that no systems are too large - max_metric_value = max(self.memory_scalers) - max_metric_idx = self.memory_scalers.index(max_metric_value) - if max_metric_value > self.max_memory_scaler: - raise ValueError( - f"Max metric of system with index {max_metric_idx} in states: " - f"{max(self.memory_scalers)} is greater than max_metric " - f"{self.max_memory_scaler}, please set a larger max_metric " - f"or run smaller systems metric." - ) - - self.index_to_scaler = dict(enumerate(self.memory_scalers)) - index_bins = to_constant_volume_bins( - self.index_to_scaler, max_volume=self.max_memory_scaler - ) # list[dict[original_index: int, memory_scale:float]] - # Local indices for indexing into the batched state - local_index_bins = [list(batch.keys()) for batch in index_bins] - # Global indices for tracking original order across streaming chunks - self.index_bins = [ - [idx + self._global_index_offset for idx in bin_indices] - for bin_indices in local_index_bins - ] - self.batched_states = [[batched[index_bin]] for index_bin in local_index_bins] - self.current_state_bin = 0 - - # Accumulate index bins for restore_original_order - self._all_index_bins.extend(self.index_bins) - self._global_index_offset += len(self.memory_scalers) + # Greedy first-fit packing preserving input order. 
+ current_indices: list[int] = [] + current_sum = 0.0 + for idx, metric in enumerate(self.memory_scalers): + if metric > self.max_memory_scaler: + raise ValueError( + f"Max metric of system with index {idx} in states: " + f"{metric} is greater than max_metric {self.max_memory_scaler}, " + f"please set a larger max_metric or run smaller systems metric." + ) + if current_sum + metric > self.max_memory_scaler and current_indices: + self.index_bins.append(current_indices) + current_indices = [] + current_sum = 0.0 + current_indices.append(idx) + current_sum += metric + if current_indices: + self.index_bins.append(current_indices) + + self.batched_states = [[batched[ib]] for ib in self.index_bins] + self._states_iterator: Iterator[T] | None = None logger.info( "BinningAutoBatcher: %d systems → %d batch(es), max_memory_scaler=%.3g", @@ -747,35 +725,41 @@ def _bin_and_prepare(self, batched: T) -> None: self.max_memory_scaler, ) - def _load_next_chunk(self) -> bool: - """Pull the next chunk of states from the iterator and bin them. - - Pulls up to ``sample_size`` states from the stored iterator, concatenates - them, and runs the binning logic to prepare the next round of batches. - - Returns: - bool: True if states were loaded, False if the iterator is exhausted. - """ - chunk: list[T] = [] - for state in self._states_iterator: # ty: ignore[not-iterable] - chunk.append(state) - if len(chunk) >= self.sample_size: - break + def _load_streaming(self, states: Iterator[T]) -> None: + """Prepare for lazy greedy packing from an iterator.""" + states_iter = iter(states) + try: + first = next(states_iter) + except StopIteration as exc: + raise ValueError("Iterator yielded no states") from exc - if not chunk: - return False + # If max_memory_scaler isn't set, estimate using the first state. + # This mirrors InFlightAutoBatcher behavior: estimate from what's available. + if not self.max_memory_scaler: + first_metrics = calculate_memory_scalers( + first, self.memory_scales_with, self.cutoff + ) + self.max_memory_scaler = ( + estimate_max_memory_scaler( + first, + self.model, + first_metrics, + max_atoms=self.max_atoms_to_try, + scale_factor=self.memory_scaling_factor, + oom_error_message=self.oom_error_message, + ) + * self.max_memory_padding + ) + logger.debug("Estimated max memory scaler: %.3g", self.max_memory_scaler) - batched = ts.concatenate_states(chunk) - self._bin_and_prepare(batched) - return True + self._states_iterator = chain([first], states_iter) + self._iterator_idx = 0 def next_batch(self) -> tuple[T | None, list[int]]: - """Get the next batch of states. + """Get the next batch of states via greedy first-fit packing. - Returns batches sequentially until all states have been processed. Each batch - contains states grouped together to maximize GPU utilization without exceeding - memory constraints. When streaming from an iterator, automatically loads the - next chunk of states when the current bins are exhausted. + For eager inputs, returns pre-computed batches sequentially. For iterator + inputs, pulls states on demand until the batch is full. 
Returns: tuple[T | None, list[int]]: A tuple containing: @@ -788,35 +772,67 @@ def next_batch(self) -> tuple[T | None, list[int]]: # Get batches one by one for batch, indices in batcher: process_batch(batch) - """ - if self.current_state_bin >= len(self.batched_states): - # Try loading next chunk if streaming - if self._states_iterator is not None and self._load_next_chunk(): - pass # _load_next_chunk resets current_state_bin and batched_states - else: + if self._states_iterator is None: + # Eager path: iterate through pre-computed batches. + if self.current_state_bin >= len(self.batched_states): return None, [] + state_bin = self.batched_states[self.current_state_bin] + state = ts.concatenate_states(state_bin) + indices = self.index_bins[self.current_state_bin] + self.current_state_bin += 1 + remaining = len(self.batched_states) - self.current_state_bin + logger.info( + ( + "BinningAutoBatcher: returning batch %d/%d with %d system(s), " + "%d batch(es) remaining" + ), + self.current_state_bin, + len(self.batched_states), + state.n_systems, + remaining, + ) + return state, indices - state_bin = self.batched_states[self.current_state_bin] - state = ts.concatenate_states(state_bin) - indices = ( - self.index_bins[self.current_state_bin] - if self.current_state_bin < len(self.index_bins) - else [] - ) + # Streaming path: pull states until the batch is full. + batch_states: list[T] = [] + batch_indices: list[int] = [] + current_sum = 0.0 + for state in self._states_iterator: + metric = calculate_memory_scalers( + state, self.memory_scales_with, self.cutoff + )[0] + if metric > self.max_memory_scaler: # ty: ignore[unsupported-operator] + raise ValueError( + f"State {metric=} is greater than max_metric " + f"{self.max_memory_scaler}, please set a larger max_metric " + f"or run smaller systems metric." + ) + if ( + current_sum + metric > self.max_memory_scaler # ty: ignore[unsupported-operator] + and batch_states + ): + # Current state doesn't fit — put it back for the next batch. + self._states_iterator = chain([state], self._states_iterator) + break + batch_states.append(state) + batch_indices.append(self._iterator_idx) + self.memory_scalers.append(metric) + self._iterator_idx += 1 + current_sum += metric + + if not batch_states: + return None, [] + + self.index_bins.append(batch_indices) self.current_state_bin += 1 - remaining = len(self.batched_states) - self.current_state_bin + batch = ts.concatenate_states(batch_states) logger.info( - ( - "BinningAutoBatcher: returning batch %d/%d with %d system(s), " - "%d batch(es) remaining" - ), + "BinningAutoBatcher: returning batch %d with %d system(s) (streaming)", self.current_state_bin, - len(self.batched_states), - state.n_systems, - remaining, + batch.n_systems, ) - return state, indices + return batch, batch_indices def __iter__(self) -> Iterator[tuple[T, list[int]]]: """Return self as an iterator. 
@@ -883,7 +899,7 @@ def restore_original_order(self, batched_states: Sequence[T]) -> list[T]: all_states = [ state[i] for state in batched_states for i in range(state.n_systems) ] - original_indices = list(chain.from_iterable(self._all_index_bins)) + original_indices = list(chain.from_iterable(self.index_bins)) if len(all_states) != len(original_indices): raise ValueError( From 58b145a16c7b0056df3e593b73b11c23caa05569 Mon Sep 17 00:00:00 2001 From: Craig Xu Chen Date: Tue, 14 Apr 2026 23:14:59 -0400 Subject: [PATCH 06/11] Remove benchmark example --- examples/benchmarking/batching_strategy.py | 328 --------------------- 1 file changed, 328 deletions(-) delete mode 100644 examples/benchmarking/batching_strategy.py diff --git a/examples/benchmarking/batching_strategy.py b/examples/benchmarking/batching_strategy.py deleted file mode 100644 index 24e18317d..000000000 --- a/examples/benchmarking/batching_strategy.py +++ /dev/null @@ -1,328 +0,0 @@ -# /// script -# requires-python = ">=3.11" -# dependencies = [ -# "ase", -# ] -# /// -"""Benchmark batching strategies: throughput-vs-batch-size + optimal-vs-greedy packing. - -Tests two claims from PR review feedback on issue #275: - 1. "Batching is important up to a threshold; past that, extra batching - doesn't help much." → throughput flattens at some batch size. - 2. "Optimal bin-packing via to_constant_volume_bins isn't super useful" vs - greedy packing (InFlightAutoBatcher-style pull-until-full). - -Experiment A: throughput vs batch size on uniform systems. -Experiment B: packing efficiency (optimal vs greedy) — batch count & fullness. -Experiment C: end-to-end wall time on heterogeneous systems using both strategies. - -Example:: - - uv run examples/benchmarking/batching_strategy.py - # or with a real GPU model: - uv run --with ".[mace]" examples/benchmarking/batching_strategy.py --use-mace - -Notes: - CPU results illustrate methodology; GPU results with a real MLIP give - more meaningful throughput numbers. This script prefers GPU+MACE if both - available, else falls back to CPU+Lennard-Jones. -""" - -# %% -from __future__ import annotations - -import argparse -import time - -import torch -from ase.build import bulk - -import torch_sim as ts -from torch_sim.autobatching import calculate_memory_scalers, to_constant_volume_bins -from torch_sim.models.lennard_jones import LennardJonesModel - - -# ------------------------------------------------------------------ -# Setup helpers -# ------------------------------------------------------------------ - - -def make_lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: - """Simple LJ model usable on both CPU and GPU.""" - return LennardJonesModel( - sigma=3.405, epsilon=0.0104, cutoff=3.0 * 3.405, device=device, dtype=dtype - ) - - -def make_ar_state(n_cells: int, device: torch.device, dtype: torch.dtype) -> ts.SimState: - """Create an Ar FCC supercell with n_cells^3 * 4 atoms.""" - atoms = bulk("Ar", "fcc", a=5.26).repeat((n_cells, n_cells, n_cells)) - return ts.io.atoms_to_state([atoms], device=device, dtype=dtype) - - -def make_heterogeneous_states( - device: torch.device, dtype: torch.dtype, n_systems: int = 60, seed: int = 0 -) -> list[ts.SimState]: - """Create a dataset of Ar supercells with widely varying sizes. - - Size distribution is bimodal-ish (mix of small molecules + big crystals) - to make packing differences visible. 
- """ - rng = torch.Generator().manual_seed(seed) - states: list[ts.SimState] = [] - for _ in range(n_systems): - # Bimodal: 60% "small" (1-2 cells = 4-32 atoms), - # 40% "large" (3-5 cells = 108-500 atoms). - if torch.rand(1, generator=rng).item() < 0.6: - n_cells = int(torch.randint(1, 3, (1,), generator=rng).item()) - else: - n_cells = int(torch.randint(3, 6, (1,), generator=rng).item()) - states.append(make_ar_state(n_cells, device, dtype)) - return states - - -def time_forward(model: LennardJonesModel, state: ts.SimState, n_reps: int = 3) -> float: - """Time a forward pass, taking the min of n_reps runs to de-noise.""" - # Warmup - model(state) - if state.device.type == "cuda": - torch.cuda.synchronize() - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - model(state) - if state.device.type == "cuda": - torch.cuda.synchronize() - times.append(time.perf_counter() - t0) - return min(times) - - -# ------------------------------------------------------------------ -# Experiment A — throughput vs batch size (tests claim 1) -# ------------------------------------------------------------------ - - -def exp_a_throughput_vs_batch_size( - model: LennardJonesModel, device: torch.device, dtype: torch.dtype -) -> None: - """Measure systems/sec as a function of batch size with uniform systems.""" - print("\n" + "=" * 72) - print("Experiment A: throughput vs batch size (uniform 32-atom systems)") - print("=" * 72) - - # Build one reference system; concatenate N copies to form a batch. - ref = make_ar_state(n_cells=2, device=device, dtype=dtype) # 32 atoms - print(f"Per-system size: {ref.n_atoms} atoms") - print(f"{'batch_size':>12} {'time_ms':>12} {'sys/sec':>12} {'μs/atom':>12}") - print("-" * 72) - - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] - results = [] - for bs in batch_sizes: - batch = ts.concatenate_states([ref] * bs) - t = time_forward(model, batch) - sys_per_sec = bs / t - us_per_atom = (t * 1e6) / (bs * ref.n_atoms) - results.append((bs, t, sys_per_sec, us_per_atom)) - print(f"{bs:>12} {t * 1000:>12.3f} {sys_per_sec:>12.1f} {us_per_atom:>12.2f}") - - # Find the "knee": batch size past which throughput stops improving meaningfully. - print() - print("Throughput analysis:") - baseline = results[0][2] # sys/sec at batch_size=1 - for bs, _t, sps, _ in results: - speedup = sps / baseline - print(f" bs={bs:>4}: {speedup:>6.2f}x speedup over bs=1") - - -# ------------------------------------------------------------------ -# Helpers for Experiment B/C: greedy packing -# ------------------------------------------------------------------ - - -def greedy_pack(metric_values: list[float], max_volume: float) -> list[list[int]]: - """Greedy (first-fit) packing: iterate items, add to current batch until full. - - This mirrors InFlightAutoBatcher._get_next_states' logic applied to a static - list: we don't pre-sort, we fill greedily in arrival order. 
- """ - bins: list[list[int]] = [[]] - current_sum = 0.0 - for idx, v in enumerate(metric_values): - if v > max_volume: - raise ValueError(f"Item {idx} metric {v} exceeds max_volume {max_volume}") - if current_sum + v > max_volume: - bins.append([idx]) - current_sum = v - else: - bins[-1].append(idx) - current_sum += v - return bins - - -def optimal_pack(metric_values: list[float], max_volume: float) -> list[list[int]]: - """Current BinningAutoBatcher optimal packing via to_constant_volume_bins.""" - idx_to_val = dict(enumerate(metric_values)) - bin_dicts = to_constant_volume_bins(idx_to_val, max_volume=max_volume) - return [list(d.keys()) for d in bin_dicts] - - -def bin_stats(bins: list[list[int]], metric_values: list[float]) -> dict: - """Compute packing statistics.""" - fullnesses = [sum(metric_values[i] for i in b) for b in bins] - return { - "n_bins": len(bins), - "bin_sizes": [len(b) for b in bins], - "bin_fullness": fullnesses, - "total_capacity_used": sum(fullnesses), - } - - -# ------------------------------------------------------------------ -# Experiment B — packing efficiency (tests claim 2, hardware-independent) -# ------------------------------------------------------------------ - - -def exp_b_packing_efficiency( - device: torch.device, dtype: torch.dtype -) -> tuple[list[ts.SimState], float, list[list[int]], list[list[int]]]: - """Compare batch count + fullness for optimal vs greedy packing. - - Tests multiple distributions to probe when optimal packing wins over greedy. - """ - print("\n" + "=" * 72) - print("Experiment B: packing efficiency (heterogeneous systems)") - print("=" * 72) - - distributions = { - "bimodal": make_heterogeneous_states(device, dtype, n_systems=60, seed=0), - "uniform": [ - make_ar_state(n_cells=2, device=device, dtype=dtype) for _ in range(60) - ], - # Adversarial: many small items, then a few big ones at the end. - # Greedy packs smalls together fine but wastes space when bigs arrive. 
- "adversarial": [ - *[make_ar_state(1, device, dtype) for _ in range(50)], # 50 small (4-atom) - *[make_ar_state(4, device, dtype) for _ in range(10)], # 10 large (256-atom) - ], - } - - last_result: ( - tuple[list[ts.SimState], float, list[list[int]], list[list[int]]] | None - ) = None - for name, states in distributions.items(): - batched = ts.concatenate_states(states) - metrics = calculate_memory_scalers(batched, "n_atoms_x_density") - max_volume = sum(metrics) / 8 # aim for ~8 bins - - n_atoms_list = [s.n_atoms for s in states] - print(f"\n-- {name} distribution --") - print( - f" n_atoms: min={min(n_atoms_list)}, max={max(n_atoms_list)}, " - f"mean={sum(n_atoms_list) / len(n_atoms_list):.1f}" - ) - print(f" max_memory_scaler: {max_volume:.1f}") - - opt_bins = optimal_pack(metrics, max_volume) - grd_bins = greedy_pack(metrics, max_volume) - opt_stats = bin_stats(opt_bins, metrics) - grd_stats = bin_stats(grd_bins, metrics) - - opt_avg = ( - sum(opt_stats["bin_fullness"]) / (opt_stats["n_bins"] * max_volume) * 100 - ) - grd_avg = ( - sum(grd_stats["bin_fullness"]) / (grd_stats["n_bins"] * max_volume) * 100 - ) - delta = (grd_stats["n_bins"] - opt_stats["n_bins"]) / opt_stats["n_bins"] * 100 - print( - f" optimal: {opt_stats['n_bins']} bins, mean {opt_avg:.1f}% full | " - f"greedy: {grd_stats['n_bins']} bins, mean {grd_avg:.1f}% full | " - f"greedy uses {delta:+.1f}% more bins" - ) - last_result = (states, max_volume, opt_bins, grd_bins) - - if last_result is None: - raise RuntimeError("No distributions defined") - return last_result - - -# ------------------------------------------------------------------ -# Experiment C — end-to-end wall time (combines claims 1 + 2) -# ------------------------------------------------------------------ - - -def exp_c_wall_time( - model: LennardJonesModel, - states: list[ts.SimState], - opt_bins: list[list[int]], - grd_bins: list[list[int]], - n_reps: int = 3, -) -> None: - """Run the model over all systems using each packing strategy; compare wall times.""" - print("\n" + "=" * 72) - print("Experiment C: end-to-end wall time (optimal vs greedy packing)") - print("=" * 72) - - def run(bins: list[list[int]]) -> float: - # Warmup - model(ts.concatenate_states([states[i] for i in bins[0]])) - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - for bin_indices in bins: - batch = ts.concatenate_states([states[i] for i in bin_indices]) - model(batch) - times.append(time.perf_counter() - t0) - return min(times) - - t_opt = run(opt_bins) - t_grd = run(grd_bins) - - print(f"{'strategy':>10} {'n_bins':>8} {'wall_time_ms':>14} {'ms/sys':>10}") - print("-" * 50) - n_sys = len(states) - print( - f"{'optimal':>10} {len(opt_bins):>8} {t_opt * 1000:>14.2f} " - f"{t_opt / n_sys * 1000:>10.3f}" - ) - print( - f"{'greedy':>10} {len(grd_bins):>8} {t_grd * 1000:>14.2f} " - f"{t_grd / n_sys * 1000:>10.3f}" - ) - delta_pct = (t_grd - t_opt) / t_opt * 100 - print(f"\nGreedy is {delta_pct:+.1f}% slower than optimal (negative = faster).") - - -# ------------------------------------------------------------------ -# Main -# ------------------------------------------------------------------ - - -def main() -> None: - """Run all three benchmark experiments.""" - parser = argparse.ArgumentParser() - parser.add_argument( - "--device", - default=None, - help="Override device (default: cuda if available else cpu)", - ) - args = parser.parse_args() - - device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) - dtype = torch.float64 if 
device.type == "cpu" else torch.float32 - print(f"Device: {device}, dtype: {dtype}") - print( - "Note: CPU + LJ gives weak batching signal; real MLIP on GPU is needed for " - "definitive numbers." - ) - - model = make_lj_model(device, dtype) - - exp_a_throughput_vs_batch_size(model, device, dtype) - states, _max_v, opt_bins, grd_bins = exp_b_packing_efficiency(device, dtype) - exp_c_wall_time(model, states, opt_bins, grd_bins) - - -if __name__ == "__main__": - main() From ebefccffbfc7a26d6caae3af94bdbec7d6dbd1f7 Mon Sep 17 00:00:00 2001 From: Craig Xu Chen Date: Tue, 14 Apr 2026 23:35:37 -0400 Subject: [PATCH 07/11] Restore eager bin packing in BinningAutoBatcher --- tests/test_autobatching.py | 15 +++ torch_sim/autobatching.py | 207 +++++++++++++++++-------------------- 2 files changed, 111 insertions(+), 111 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index b4010cecd..ab4b98bd5 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -517,6 +517,21 @@ def test_binning_auto_batcher_empty_iterator( batcher.load_states(iter([])) +def test_binning_auto_batcher_iterator_requires_max_memory_scaler( + si_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Iterator inputs should require an explicit max_memory_scaler.""" + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + ) + with pytest.raises( + ValueError, match="Iterator inputs require max_memory_scaler" + ): + batcher.load_states(iter([si_sim_state])) + + def test_in_flight_max_metric_too_small( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 1e7f8877f..47ef92854 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -518,32 +518,27 @@ def estimate_max_memory_scaler( class BinningAutoBatcher[T: SimState]: - """Batcher that groups states into batches using greedy first-fit packing. + """Batcher that groups states into bins of similar computational cost. - Fills batches with successive states until adding another would exceed the - maximum memory scaling metric. Ideal for scenarios where all states need to - be evolved the same number of steps (e.g., MD integration). - - Accepts a single batched ``SimState``, a sequence of individual states, or an - iterator/generator for streaming large datasets that don't fit in memory. - For eager inputs (``SimState`` or ``Sequence``) all batches are computed upfront. - For iterator inputs, batches are materialized lazily as ``next_batch`` is called. + Divides a collection of states into batches that can be processed efficiently + without exceeding GPU memory. For eager inputs, states are grouped based on a + memory scaling metric to maximize GPU utilization using global bin packing. + For iterator inputs, batches are formed lazily using greedy first-fit packing. + This approach is ideal for scenarios where all states need to be evolved the + same number of steps. To avoid a slow memory estimation step, set ``max_memory_scaler`` to a known value. Attributes: model (ModelInterface): Model used for memory estimation and processing. memory_scales_with (str): Metric type used for memory estimation. - max_memory_scaler (float): Maximum memory metric allowed per batch. + max_memory_scaler (float): Maximum memory metric allowed per system. max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. - memory_scalers (list[float]): Memory scaling metrics per state. 
Populated - eagerly for ``SimState``/``Sequence`` inputs; grows incrementally for - iterator inputs as states are consumed. - index_bins (list[list[int]]): Groups of state indices per batch. Populated - eagerly for ``SimState``/``Sequence`` inputs; grows incrementally for - iterator inputs as batches are yielded. - batched_states (list[list[SimState]]): Pre-computed batches (eager path only). - Empty for iterator inputs since batches are materialized on demand. + memory_scalers (list[float]): Memory scaling metrics for each state. + index_to_scaler (dict): Mapping from state index to its scaling metric. + index_bins (list[list[int]]): Groups of state indices that can be batched + together. + batched_states (list[list[SimState]]): Grouped states ready for batching. current_state_bin (int): Index of the current batch being processed. Example:: @@ -563,7 +558,7 @@ class BinningAutoBatcher[T: SimState]: ordered_final_states = batcher.restore_original_order(final_states) - # Or stream states from a generator + # Or stream states from a generator using greedy packing def state_generator(): for atoms in large_dataset: yield ts.initialize_state(atoms, device, dtype) @@ -625,19 +620,19 @@ def __init__( self.memory_scaling_factor = memory_scaling_factor self.max_memory_padding = max_memory_padding self.oom_error_message = oom_error_message + self._states_iterator: Iterator[T] | None = None def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: """Load new states into the batcher. - Computes memory scaling metrics and organizes states into batches using - greedy first-fit packing: states are added to the current batch in order - until the next would exceed ``max_memory_scaler``, at which point a new - batch starts. + Processes the input states and organizes them into batches. Eager inputs + (``SimState`` and ``Sequence``) use the original global bin-packing logic. + Iterator inputs are streamed lazily and packed greedily in input order. Args: states (SimState | list[SimState] | Iterator[SimState]): Collection of - states to batch. Can be a list of individual SimState objects, a single - batched SimState that will be split into individual states, or an + states to batch. Can be a list of individual SimState objects, a single + batched SimState that will be split into individual states, or an iterator/generator yielding individual SimState objects. Returns: @@ -645,7 +640,8 @@ def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: Raises: ValueError: If any individual state has a memory scaling metric greater - than the maximum allowed value, or if an iterator yields no states. + than the maximum allowed value, if an iterator yields no states, or + if an iterator is provided without ``max_memory_scaler``. Example:: @@ -660,13 +656,14 @@ def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: batcher.load_states(state_generator()) Notes: - This method resets iteration state, so any ongoing iteration will be - restarted when called. + This method resets the current state bin index, so any ongoing iteration + will be restarted when this method is called. 
""" self.memory_scalers: list[float] = [] self.index_bins = [] self.batched_states: list[list[T]] = [] self.current_state_bin = 0 + self._states_iterator = None if isinstance(states, SimState): self._load_eager(states) @@ -683,40 +680,33 @@ def _load_eager(self, batched: T) -> None: batched, self.memory_scales_with, self.cutoff ) if not self.max_memory_scaler: - self.max_memory_scaler = ( - estimate_max_memory_scaler( - batched, - self.model, - self.memory_scalers, - max_atoms=self.max_atoms_to_try, - scale_factor=self.memory_scaling_factor, - oom_error_message=self.oom_error_message, - ) - * self.max_memory_padding + self.max_memory_scaler = estimate_max_memory_scaler( + batched, + self.model, + self.memory_scalers, + max_atoms=self.max_atoms_to_try, + scale_factor=self.memory_scaling_factor, + oom_error_message=self.oom_error_message, ) + self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding logger.debug("Estimated max memory scaler: %.3g", self.max_memory_scaler) - # Greedy first-fit packing preserving input order. - current_indices: list[int] = [] - current_sum = 0.0 - for idx, metric in enumerate(self.memory_scalers): - if metric > self.max_memory_scaler: - raise ValueError( - f"Max metric of system with index {idx} in states: " - f"{metric} is greater than max_metric {self.max_memory_scaler}, " - f"please set a larger max_metric or run smaller systems metric." - ) - if current_sum + metric > self.max_memory_scaler and current_indices: - self.index_bins.append(current_indices) - current_indices = [] - current_sum = 0.0 - current_indices.append(idx) - current_sum += metric - if current_indices: - self.index_bins.append(current_indices) + max_metric_value = max(self.memory_scalers) + max_metric_idx = self.memory_scalers.index(max_metric_value) + if max_metric_value > self.max_memory_scaler: + raise ValueError( + f"Max metric of system with index {max_metric_idx} in states: " + f"{max(self.memory_scalers)} is greater than max_metric " + f"{self.max_memory_scaler}, please set a larger max_metric " + f"or run smaller systems metric." + ) - self.batched_states = [[batched[ib]] for ib in self.index_bins] - self._states_iterator: Iterator[T] | None = None + self.index_to_scaler = dict(enumerate(self.memory_scalers)) + index_bins = to_constant_volume_bins( + self.index_to_scaler, max_volume=self.max_memory_scaler + ) + self.index_bins = [list(batch.keys()) for batch in index_bins] + self.batched_states = [[batched[index_bin]] for index_bin in self.index_bins] logger.info( "BinningAutoBatcher: %d systems → %d batch(es), max_memory_scaler=%.3g", @@ -733,66 +723,39 @@ def _load_streaming(self, states: Iterator[T]) -> None: except StopIteration as exc: raise ValueError("Iterator yielded no states") from exc - # If max_memory_scaler isn't set, estimate using the first state. - # This mirrors InFlightAutoBatcher behavior: estimate from what's available. if not self.max_memory_scaler: - first_metrics = calculate_memory_scalers( - first, self.memory_scales_with, self.cutoff - ) - self.max_memory_scaler = ( - estimate_max_memory_scaler( - first, - self.model, - first_metrics, - max_atoms=self.max_atoms_to_try, - scale_factor=self.memory_scaling_factor, - oom_error_message=self.oom_error_message, - ) - * self.max_memory_padding + raise ValueError( + "Iterator inputs require max_memory_scaler to be set explicitly." 
) - logger.debug("Estimated max memory scaler: %.3g", self.max_memory_scaler) self._states_iterator = chain([first], states_iter) self._iterator_idx = 0 - def next_batch(self) -> tuple[T | None, list[int]]: - """Get the next batch of states via greedy first-fit packing. - - For eager inputs, returns pre-computed batches sequentially. For iterator - inputs, pulls states on demand until the batch is full. - - Returns: - tuple[T | None, list[int]]: A tuple containing: - - A concatenated SimState containing the next batch of states, - or None if no more batches - - List of indices of states in the current batch - - Example:: + def _next_eager_batch(self) -> tuple[T | None, list[int]]: + """Return the next pre-computed batch for eager inputs.""" + if self.current_state_bin >= len(self.batched_states): + return None, [] + state_bin = self.batched_states[self.current_state_bin] + state = ts.concatenate_states(state_bin) + indices = self.index_bins[self.current_state_bin] + self.current_state_bin += 1 + remaining = len(self.batched_states) - self.current_state_bin + logger.info( + ( + "BinningAutoBatcher: returning batch %d/%d with %d system(s), " + "%d batch(es) remaining" + ), + self.current_state_bin, + len(self.batched_states), + state.n_systems, + remaining, + ) + return state, indices - # Get batches one by one - for batch, indices in batcher: - process_batch(batch) - """ + def _next_streaming_batch(self) -> tuple[T | None, list[int]]: + """Return the next greedily packed batch for iterator inputs.""" if self._states_iterator is None: - # Eager path: iterate through pre-computed batches. - if self.current_state_bin >= len(self.batched_states): - return None, [] - state_bin = self.batched_states[self.current_state_bin] - state = ts.concatenate_states(state_bin) - indices = self.index_bins[self.current_state_bin] - self.current_state_bin += 1 - remaining = len(self.batched_states) - self.current_state_bin - logger.info( - ( - "BinningAutoBatcher: returning batch %d/%d with %d system(s), " - "%d batch(es) remaining" - ), - self.current_state_bin, - len(self.batched_states), - state.n_systems, - remaining, - ) - return state, indices + return None, [] # Streaming path: pull states until the batch is full. batch_states: list[T] = [] @@ -812,7 +775,6 @@ def next_batch(self) -> tuple[T | None, list[int]]: current_sum + metric > self.max_memory_scaler # ty: ignore[unsupported-operator] and batch_states ): - # Current state doesn't fit — put it back for the next batch. self._states_iterator = chain([state], self._states_iterator) break batch_states.append(state) @@ -834,6 +796,29 @@ def next_batch(self) -> tuple[T | None, list[int]]: ) return batch, batch_indices + def next_batch(self) -> tuple[T | None, list[int]]: + """Get the next batch of states. + + Returns batches sequentially until all states have been processed. Eager + inputs use pre-computed globally packed batches. Iterator inputs pull + states on demand and pack greedily without materializing the full input. + + Returns: + tuple[T | None, list[int]]: A tuple containing: + - A concatenated SimState containing the next batch of states, + or None if no more batches + - List of indices of states in the current batch + + Example:: + + # Get batches one by one + for batch, indices in batcher: + process_batch(batch) + """ + if self._states_iterator is None: + return self._next_eager_batch() + return self._next_streaming_batch() + def __iter__(self) -> Iterator[tuple[T, list[int]]]: """Return self as an iterator. 
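[Reviewer note, not part of the patch series] The docstrings added in PATCH 07 describe the two packing strategies: global bin packing (`to_constant_volume_bins`) for eager inputs and greedy first-fit, in input order, for iterator inputs. Below is a minimal, self-contained sketch of that greedy first-fit rule, simplified to operate on precomputed scaler values rather than SimState objects; the names `pack_streaming` and `max_scaler` are illustrative and do not exist in torch_sim.

from collections.abc import Iterable, Iterator


def pack_streaming(scalers: Iterable[float], max_scaler: float) -> Iterator[list[int]]:
    """Yield index bins greedily in input order, never exceeding max_scaler."""
    batch: list[int] = []
    total = 0.0
    for idx, scaler in enumerate(scalers):
        if scaler > max_scaler:
            raise ValueError(f"state {idx} scaler {scaler} exceeds max_scaler {max_scaler}")
        if batch and total + scaler > max_scaler:
            yield batch  # current bin is full; start a new one
            batch, total = [], 0.0
        batch.append(idx)
        total += scaler
    if batch:
        yield batch


# Example: scalers [3, 4, 5, 2] with max_scaler=7 -> bins [[0, 1], [2, 3]]
print(list(pack_streaming([3.0, 4.0, 5.0, 2.0], 7.0)))

The library version differs in one detail: it pulls whole SimState objects from the iterator and, when a state overflows the current bin, pushes it back onto the iterator with itertools.chain so it opens the next bin.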
From dfa8c6e5e70f05c30cc24ff2cf6b78914119afa1 Mon Sep 17 00:00:00 2001 From: Craig Xu Chen Date: Wed, 15 Apr 2026 00:02:22 -0400 Subject: [PATCH 08/11] inline the next_batch functions --- torch_sim/autobatching.py | 83 +++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 47ef92854..9c58fad7a 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -731,33 +731,45 @@ def _load_streaming(self, states: Iterator[T]) -> None: self._states_iterator = chain([first], states_iter) self._iterator_idx = 0 - def _next_eager_batch(self) -> tuple[T | None, list[int]]: - """Return the next pre-computed batch for eager inputs.""" - if self.current_state_bin >= len(self.batched_states): - return None, [] - state_bin = self.batched_states[self.current_state_bin] - state = ts.concatenate_states(state_bin) - indices = self.index_bins[self.current_state_bin] - self.current_state_bin += 1 - remaining = len(self.batched_states) - self.current_state_bin - logger.info( - ( - "BinningAutoBatcher: returning batch %d/%d with %d system(s), " - "%d batch(es) remaining" - ), - self.current_state_bin, - len(self.batched_states), - state.n_systems, - remaining, - ) - return state, indices + def next_batch(self) -> tuple[T | None, list[int]]: + """Get the next batch of states. + + Returns batches sequentially until all states have been processed. Eager + inputs use pre-computed globally packed batches. Iterator inputs pull + states on demand and pack greedily without materializing the full input. - def _next_streaming_batch(self) -> tuple[T | None, list[int]]: - """Return the next greedily packed batch for iterator inputs.""" + Returns: + tuple[T | None, list[int]]: A tuple containing: + - A concatenated SimState containing the next batch of states, + or None if no more batches + - List of indices of states in the current batch + + Example:: + + # Get batches one by one + for batch, indices in batcher: + process_batch(batch) + """ if self._states_iterator is None: - return None, [] + if self.current_state_bin >= len(self.batched_states): + return None, [] + state_bin = self.batched_states[self.current_state_bin] + state = ts.concatenate_states(state_bin) + indices = self.index_bins[self.current_state_bin] + self.current_state_bin += 1 + remaining = len(self.batched_states) - self.current_state_bin + logger.info( + ( + "BinningAutoBatcher: returning batch %d/%d with %d system(s), " + "%d batch(es) remaining" + ), + self.current_state_bin, + len(self.batched_states), + state.n_systems, + remaining, + ) + return state, indices - # Streaming path: pull states until the batch is full. batch_states: list[T] = [] batch_indices: list[int] = [] current_sum = 0.0 @@ -796,29 +808,6 @@ def _next_streaming_batch(self) -> tuple[T | None, list[int]]: ) return batch, batch_indices - def next_batch(self) -> tuple[T | None, list[int]]: - """Get the next batch of states. - - Returns batches sequentially until all states have been processed. Eager - inputs use pre-computed globally packed batches. Iterator inputs pull - states on demand and pack greedily without materializing the full input. 
- - Returns: - tuple[T | None, list[int]]: A tuple containing: - - A concatenated SimState containing the next batch of states, - or None if no more batches - - List of indices of states in the current batch - - Example:: - - # Get batches one by one - for batch, indices in batcher: - process_batch(batch) - """ - if self._states_iterator is None: - return self._next_eager_batch() - return self._next_streaming_batch() - def __iter__(self) -> Iterator[tuple[T, list[int]]]: """Return self as an iterator. From e2f3102540071a9045e631a4719403723cc50af9 Mon Sep 17 00:00:00 2001 From: Craig Xu Chen Date: Wed, 15 Apr 2026 00:12:39 -0400 Subject: [PATCH 09/11] cleanup --- torch_sim/autobatching.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 9c58fad7a..afa50d9ce 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -536,8 +536,7 @@ class BinningAutoBatcher[T: SimState]: max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. memory_scalers (list[float]): Memory scaling metrics for each state. index_to_scaler (dict): Mapping from state index to its scaling metric. - index_bins (list[list[int]]): Groups of state indices that can be batched - together. + index_bins (list[list[int]]): Groups of state indices that can be batched together. batched_states (list[list[SimState]]): Grouped states ready for batching. current_state_bin (int): Index of the current batch being processed. @@ -625,14 +624,15 @@ def __init__( def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: """Load new states into the batcher. - Processes the input states and organizes them into batches. Eager inputs - (``SimState`` and ``Sequence``) use the original global bin-packing logic. - Iterator inputs are streamed lazily and packed greedily in input order. + Eager inputs (``SimState`` and ``Sequence``) are fully materialized and + packed up front using global bin packing. Iterator inputs are consumed + lazily and packed greedily in input order. Args: - states (SimState | list[SimState] | Iterator[SimState]): Collection of - states to batch. Can be a list of individual SimState objects, a single - batched SimState that will be split into individual states, or an + states (SimState | Sequence[SimState] | Iterator[SimState]): Collection + of states to batch. Can be a list of individual SimState objects, + a single batched SimState that will be split into individual states, + or an iterator/generator yielding individual SimState objects. Returns: @@ -640,8 +640,8 @@ def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: Raises: ValueError: If any individual state has a memory scaling metric greater - than the maximum allowed value, if an iterator yields no states, or - if an iterator is provided without ``max_memory_scaler``. + than the maximum allowed value, if an iterator yields no states, + or if an iterator is provided without ``max_memory_scaler``. Example:: @@ -656,8 +656,9 @@ def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: batcher.load_states(state_generator()) Notes: - This method resets the current state bin index, so any ongoing iteration - will be restarted when this method is called. + Iterator inputs require ``max_memory_scaler`` to be set explicitly. + This method resets batching state, so any ongoing iteration restarts + when it is called. 
""" self.memory_scalers: list[float] = [] self.index_bins = [] From caf3c119ad635d4422b4340c82064150daafcfe2 Mon Sep 17 00:00:00 2001 From: Craig Xu Chen Date: Wed, 15 Apr 2026 00:18:40 -0400 Subject: [PATCH 10/11] Polish BinningAutoBatcher docstrings --- torch_sim/autobatching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index afa50d9ce..8a7cd7dd4 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -536,7 +536,8 @@ class BinningAutoBatcher[T: SimState]: max_atoms_to_try (int): Maximum number of atoms to try when estimating memory. memory_scalers (list[float]): Memory scaling metrics for each state. index_to_scaler (dict): Mapping from state index to its scaling metric. - index_bins (list[list[int]]): Groups of state indices that can be batched together. + index_bins (list[list[int]]): Groups of state indices that can be batched + together. batched_states (list[list[SimState]]): Grouped states ready for batching. current_state_bin (int): Index of the current batch being processed. From df9f8659538c22fa3db9d3f62b73c19378ee2537 Mon Sep 17 00:00:00 2001 From: Craig Xu Chen Date: Wed, 15 Apr 2026 00:30:40 -0400 Subject: [PATCH 11/11] Fix autobatching ty issues and formatting --- tests/test_autobatching.py | 4 +--- torch_sim/autobatching.py | 44 ++++++++++++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index ab4b98bd5..716031c4a 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -526,9 +526,7 @@ def test_binning_auto_batcher_iterator_requires_max_memory_scaler( model=lj_model, memory_scales_with="n_atoms", ) - with pytest.raises( - ValueError, match="Iterator inputs require max_memory_scaler" - ): + with pytest.raises(ValueError, match="Iterator inputs require max_memory_scaler"): batcher.load_states(iter([si_sim_state])) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 8a7cd7dd4..db0980aef 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -121,10 +121,12 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: weights = _get_bins(vals, n_dcs) keys = _get_bins(keys, n_dcs) - bins = [[]] if is_tuple_list else [{}] + list_bins: list[list[Any]] | None = [[]] if is_tuple_list else None + dict_bins: list[dict[int, float]] | None = [{}] if not is_tuple_list else None else: weights = sorted(items, key=lambda x: -x) - bins = [[]] + list_bins = [[]] + dict_bins = None # find the valid indices if lower_bound is not None and upper_bound is not None and lower_bound < upper_bound: @@ -174,9 +176,18 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: b = len(weight_sum) weight_sum.append(0.0) if isinstance(items, dict): - bins.append([] if is_tuple_list else {}) + if is_tuple_list: + if list_bins is None: + raise TypeError("tuple-list mode requires list bins") + list_bins.append([]) + else: + if dict_bins is None: + raise TypeError("dict mode requires dict bins") + dict_bins.append({}) else: - bins.append([]) + if list_bins is None: + raise TypeError("list items require list bins") + list_bins.append([]) # if we are at the very first item, use the empty bin already open else: @@ -184,15 +195,22 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: # put it in if isinstance(items, dict): - bin_ = bins[b] if is_tuple_list: + if list_bins is None: + raise TypeError("tuple-list mode requires list bins") + bin_ = list_bins[b] if not 
isinstance(bin_, list): raise TypeError("bins contain lists when tuple-list mode is used") bin_.append(item_key) - elif isinstance(bin_, dict): + else: + if dict_bins is None: + raise TypeError("dict mode requires dict bins") + bin_ = dict_bins[b] bin_[item_key] = weight else: - bin_ = bins[b] + if list_bins is None: + raise TypeError("list items require list bins") + bin_ = list_bins[b] if not isinstance(bin_, list): raise TypeError("bins contain lists when items is not dict") bin_.append(weight) @@ -202,8 +220,16 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: weight_sum[b] += weight if not is_tuple_list: - return bins - return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in bins] + if isinstance(items, dict): + if dict_bins is None: + raise TypeError("dict mode requires dict bins") + return dict_bins + if list_bins is None: + raise TypeError("list items require list bins") + return list_bins + if list_bins is None: + raise TypeError("tuple-list mode requires list bins") + return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in list_bins] def measure_model_memory_forward(state: SimState, model: ModelInterface) -> float:
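[Reviewer note, not part of the patch series] After PATCH 11, iterator inputs are packed greedily on demand and require an explicit max_memory_scaler. A hedged usage sketch of that streaming path follows, assuming a `model`, a `device`/`dtype` pair, and an `atoms_iterable` of ASE Atoms objects already exist in the caller's scope; the value 400.0 is illustrative only.

import torch_sim as ts
from torch_sim.autobatching import BinningAutoBatcher


def stream_states(atoms_iterable, device, dtype):
    # Convert structures to SimStates lazily so the full dataset never sits in memory.
    for atoms in atoms_iterable:
        yield ts.initialize_state(atoms, device, dtype)


batcher = BinningAutoBatcher(
    model=model,
    memory_scales_with="n_atoms",
    max_memory_scaler=400.0,  # iterator inputs require an explicit value
)
batcher.load_states(stream_states(atoms_iterable, device, dtype))

for batch, indices in batcher:  # batches are packed greedily, on demand
    model(batch)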