From 9cf61f04e67ded3411546cf343aa240498678be6 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sun, 1 Mar 2026 13:08:32 -0500 Subject: [PATCH 1/4] Add n_edges memory metric to auto-batching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a third memory scaling metric, "n_edges", that computes the actual neighbor list edge count via torchsim_nl. This is more accurate than n_atoms or n_atoms_x_density for molecular systems where neighbor counts vary significantly across structures. A cutoff parameter (default 7.0 Å) is threaded through both BinningAutoBatcher and InFlightAutoBatcher into calculate_memory_scalers. --- torch_sim/autobatching.py | 76 +++++++++++++++++++++++++++++++++------ torch_sim/typing.py | 2 +- 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 36902eb7f..4d646e861 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -324,15 +324,42 @@ def determine_max_batch_size( return sizes[-1] +def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]: + """Return per-system edge counts from the neighbor list as memory scalers.""" + from torch_sim.neighbors import torchsim_nl + + cutoff_tensor = torch.tensor(cutoff, dtype=state.dtype, device=state.device) + system_idx = state.system_idx + if system_idx is None: + system_idx = torch.zeros(state.n_atoms, dtype=torch.long, device=state.device) + pbc = state.pbc + if not torch.is_tensor(pbc): + pbc_list = [pbc] * 3 if isinstance(pbc, bool) else list(pbc) + pbc = torch.tensor(pbc_list, dtype=torch.bool, device=state.device) + if pbc.ndim == 1 and state.n_systems > 1: + pbc = pbc.unsqueeze(0).expand(state.n_systems, -1) + _, system_mapping, _ = torchsim_nl( + positions=state.positions, + cell=state.cell, + pbc=pbc, + cutoff=cutoff_tensor, + system_idx=system_idx, + ) + return system_mapping.bincount(minlength=state.n_systems).tolist() + + def calculate_memory_scalers( state: SimState, memory_scales_with: MemoryScaling = "n_atoms_x_density", + cutoff: float = 7.0, ) -> list[float]: """Calculate a metric that estimates memory requirements for each system in a state. Provides different scaling metrics that correlate with memory usage. Models with radial neighbor cutoffs generally scale with "n_atoms_x_density", while models with a fixed number of neighbors scale with "n_atoms". + For molecular systems, "n_edges" gives the most accurate estimate by computing + the actual neighbor list edge count using the provided cutoff. The choice of metric can significantly impact the accuracy of memory requirement estimations for different types of simulation systems. @@ -342,11 +369,15 @@ def calculate_memory_scalers( Args: state (SimState): State to calculate metric for, with shape information specific to the SimState instance. - memory_scales_with ("n_atoms_x_density" | "n_atoms"): Type of metric - to use. "n_atoms" uses only atom count and is suitable for models that - have a fixed number of neighbors. "n_atoms_x_density" uses atom count - multiplied by number density and is better for models with radial cutoffs - Defaults to "n_atoms_x_density". + memory_scales_with ("n_atoms_x_density" | "n_atoms" | "n_edges"): Type of + metric to use. "n_atoms" uses only atom count and is suitable for models + that have a fixed number of neighbors. "n_atoms_x_density" uses atom count + multiplied by number density and is better for models with radial cutoffs. + "n_edges" computes the actual neighbor list edge count, which is the most + accurate metric for molecular systems. Defaults to "n_atoms_x_density". + cutoff (float): Neighbor list cutoff distance in Angstroms. Only used when + memory_scales_with="n_edges". Should match the model's cutoff for best + accuracy. Defaults to 7.0. Returns: list[float]: Calculated metric value for each system. @@ -361,6 +392,11 @@ def calculate_memory_scalers( # Calculate memory scaling factor based on atom count and density metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms_x_density") + + # Calculate memory scaling factor based on actual neighbor list edge count + metrics = calculate_memory_scalers( + state, memory_scales_with="n_edges", cutoff=5.0 + ) """ if memory_scales_with == "n_atoms": return state.n_atoms_per_system.tolist() @@ -410,6 +446,8 @@ def calculate_memory_scalers( volume = bbox.prod() / 1000 scalers.append(system_state.n_atoms * (system_state.n_atoms / volume.item())) return scalers + if memory_scales_with == "n_edges": + return _n_edges_scalers(state, cutoff) raise ValueError( f"Invalid metric: {memory_scales_with}, must be one of {get_args(MemoryScaling)}" ) @@ -522,6 +560,7 @@ def __init__( model: ModelInterface, *, memory_scales_with: MemoryScaling = "n_atoms_x_density", + cutoff: float = 7.0, max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, @@ -533,11 +572,15 @@ def __init__( Args: model (ModelInterface): Model to batch for, used to estimate memory requirements. - memory_scales_with ("n_atoms" | "n_atoms_x_density"): Metric to use - for estimating memory requirements: + memory_scales_with ("n_atoms" | "n_atoms_x_density" | "n_edges"): Metric to + use for estimating memory requirements: - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density + - "n_edges": Uses actual neighbor list edge count (best for molecules) Defaults to "n_atoms_x_density". + cutoff (float): Neighbor list cutoff in Angstroms. Only used when + memory_scales_with="n_edges". Should match the model's cutoff. + Defaults to 7.0. max_memory_scaler (float | None): Maximum metric value allowed per system. If None, will be automatically estimated. Defaults to None. max_atoms_to_try (int): Maximum number of atoms to try when estimating @@ -555,6 +598,7 @@ def __init__( self.max_memory_scaler = max_memory_scaler self.max_atoms_to_try = max_atoms_to_try self.memory_scales_with = memory_scales_with + self.cutoff = cutoff self.model = model self.memory_scaling_factor = memory_scaling_factor self.max_memory_padding = max_memory_padding @@ -595,7 +639,9 @@ def load_states(self, states: T | Sequence[T]) -> float: batched = ( states if isinstance(states, SimState) else ts.concatenate_states(states) ) - self.memory_scalers = calculate_memory_scalers(batched, self.memory_scales_with) + self.memory_scalers = calculate_memory_scalers( + batched, self.memory_scales_with, self.cutoff + ) if not self.max_memory_scaler: self.max_memory_scaler = estimate_max_memory_scaler( batched, @@ -799,6 +845,7 @@ def __init__( model: ModelInterface, *, memory_scales_with: MemoryScaling = "n_atoms_x_density", + cutoff: float = 7.0, max_memory_scaler: float | None = None, max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, @@ -811,11 +858,15 @@ def __init__( Args: model (ModelInterface): Model to batch for, used to estimate memory requirements. - memory_scales_with ("n_atoms" | "n_atoms_x_density"): Metric to use - for estimating memory requirements: + memory_scales_with ("n_atoms" | "n_atoms_x_density" | "n_edges"): Metric to + use for estimating memory requirements: - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density + - "n_edges": Uses actual neighbor list edge count (best for molecules) Defaults to "n_atoms_x_density". + cutoff (float): Neighbor list cutoff in Angstroms. Only used when + memory_scales_with="n_edges". Should match the model's cutoff. + Defaults to 7.0. max_memory_scaler (float | None): Maximum metric value allowed per system. If None, will be automatically estimated. Defaults to None. max_atoms_to_try (int): Maximum number of atoms to try when estimating @@ -835,6 +886,7 @@ def __init__( """ self.model = model self.memory_scales_with = memory_scales_with + self.cutoff = cutoff self.max_memory_scaler = max_memory_scaler or None self.max_atoms_to_try = max_atoms_to_try self.memory_scaling_factor = memory_scaling_factor @@ -903,7 +955,9 @@ def _get_next_states(self) -> list[T]: new_idx: list[int] = [] new_states: list[T] = [] for state in self.states_iterator: - metric = calculate_memory_scalers(state, self.memory_scales_with)[0] + metric = calculate_memory_scalers( + state, self.memory_scales_with, self.cutoff + )[0] if metric > self.max_memory_scaler: # ty: ignore[unsupported-operator] raise ValueError( f"State {metric=} is greater than max_metric {self.max_memory_scaler}" diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 09b015361..11f84e738 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -14,7 +14,7 @@ from torch_sim.state import SimState -MemoryScaling = Literal["n_atoms_x_density", "n_atoms"] +MemoryScaling = Literal["n_atoms_x_density", "n_atoms", "n_edges"] StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "system_idx"] StateDict = dict[StateKey, torch.Tensor] From 3b8e118d7043fd699068db655168c82cace6ce79 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sun, 1 Mar 2026 13:19:27 -0500 Subject: [PATCH 2/4] Update n_edges metric: top-level import, float cast, tests - Move torchsim_nl import from inside _n_edges_scalers to module level - Cast bincount result to float so _n_edges_scalers returns list[float] consistently with the declared return type - Add tests for _n_edges_scalers covering periodic, non-periodic, and batched states --- tests/test_autobatching.py | 28 ++++++++++++++++++++++++++++ torch_sim/autobatching.py | 14 ++++++++------ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 09e648a3f..eca982a00 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -7,6 +7,7 @@ from torch_sim.autobatching import ( BinningAutoBatcher, InFlightAutoBatcher, + _n_edges_scalers, calculate_memory_scalers, determine_max_batch_size, to_constant_volume_bins, @@ -156,6 +157,33 @@ def test_calculate_scaling_metric_mixed_pbc_uses_per_system_path( assert metric_values == pytest.approx(expected_values, rel=1e-5) +def test_n_edges_scalers_periodic(si_sim_state: ts.SimState) -> None: + """n_edges scalers for a single periodic system have correct shape and type.""" + result = _n_edges_scalers(si_sim_state, cutoff=5.0) + assert isinstance(result, list) + assert len(result) == si_sim_state.n_systems + assert all(isinstance(v, float) for v in result) + assert all(v >= 0 for v in result) + + +def test_n_edges_scalers_non_periodic(benzene_sim_state: ts.SimState) -> None: + """n_edges scalers for a non-periodic (molecular) system have correct shape/type.""" + result = _n_edges_scalers(benzene_sim_state, cutoff=5.0) + assert isinstance(result, list) + assert len(result) == benzene_sim_state.n_systems + assert all(isinstance(v, float) for v in result) + assert all(v >= 0 for v in result) + + +def test_n_edges_scalers_batched(ar_double_sim_state: ts.SimState) -> None: + """n_edges scalers for a batched state return one value per system.""" + result = _n_edges_scalers(ar_double_sim_state, cutoff=5.0) + assert isinstance(result, list) + assert len(result) == ar_double_sim_state.n_systems + assert all(isinstance(v, float) for v in result) + assert all(v >= 0 for v in result) + + @pytest.mark.parametrize("items", [[], {}]) def test_to_constant_volume_bins_empty_input( items: list[Any] | dict[int, float], diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 4d646e861..04af5109d 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -28,6 +28,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import torchsim_nl from torch_sim.state import SimState from torch_sim.typing import MemoryScaling @@ -326,8 +327,6 @@ def determine_max_batch_size( def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]: """Return per-system edge counts from the neighbor list as memory scalers.""" - from torch_sim.neighbors import torchsim_nl - cutoff_tensor = torch.tensor(cutoff, dtype=state.dtype, device=state.device) system_idx = state.system_idx if system_idx is None: @@ -345,7 +344,7 @@ def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]: cutoff=cutoff_tensor, system_idx=system_idx, ) - return system_mapping.bincount(minlength=state.n_systems).tolist() + return system_mapping.bincount(minlength=state.n_systems).float().tolist() def calculate_memory_scalers( @@ -374,7 +373,8 @@ def calculate_memory_scalers( that have a fixed number of neighbors. "n_atoms_x_density" uses atom count multiplied by number density and is better for models with radial cutoffs. "n_edges" computes the actual neighbor list edge count, which is the most - accurate metric for molecular systems. Defaults to "n_atoms_x_density". + accurate metric overall but more expensive to compute than the alternatives; + strongly recommended for molecular systems. Defaults to "n_atoms_x_density". cutoff (float): Neighbor list cutoff distance in Angstroms. Only used when memory_scales_with="n_edges". Should match the model's cutoff for best accuracy. Defaults to 7.0. @@ -576,7 +576,8 @@ def __init__( use for estimating memory requirements: - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density - - "n_edges": Uses actual neighbor list edge count (best for molecules) + - "n_edges": Uses actual neighbor list edge count; most accurate overall + but more expensive; strongly recommended for molecular systems Defaults to "n_atoms_x_density". cutoff (float): Neighbor list cutoff in Angstroms. Only used when memory_scales_with="n_edges". Should match the model's cutoff. @@ -862,7 +863,8 @@ def __init__( use for estimating memory requirements: - "n_atoms": Uses only atom count - "n_atoms_x_density": Uses atom count multiplied by number density - - "n_edges": Uses actual neighbor list edge count (best for molecules) + - "n_edges": Uses actual neighbor list edge count; most accurate overall + but more expensive; strongly recommended for molecular systems Defaults to "n_atoms_x_density". cutoff (float): Neighbor list cutoff in Angstroms. Only used when memory_scales_with="n_edges". Should match the model's cutoff. From f39908212024f8b3ab990ae5b34ac3705902c520 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sun, 1 Mar 2026 13:24:31 -0500 Subject: [PATCH 3/4] Satisfy ty checker in _n_edges_scalers using pbc_to_tensor and require_system_idx --- torch_sim/autobatching.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 04af5109d..b77947f38 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -29,7 +29,7 @@ import torch_sim as ts from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl -from torch_sim.state import SimState +from torch_sim.state import SimState, pbc_to_tensor, require_system_idx from torch_sim.typing import MemoryScaling @@ -328,21 +328,12 @@ def determine_max_batch_size( def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]: """Return per-system edge counts from the neighbor list as memory scalers.""" cutoff_tensor = torch.tensor(cutoff, dtype=state.dtype, device=state.device) - system_idx = state.system_idx - if system_idx is None: - system_idx = torch.zeros(state.n_atoms, dtype=torch.long, device=state.device) - pbc = state.pbc - if not torch.is_tensor(pbc): - pbc_list = [pbc] * 3 if isinstance(pbc, bool) else list(pbc) - pbc = torch.tensor(pbc_list, dtype=torch.bool, device=state.device) - if pbc.ndim == 1 and state.n_systems > 1: - pbc = pbc.unsqueeze(0).expand(state.n_systems, -1) _, system_mapping, _ = torchsim_nl( positions=state.positions, cell=state.cell, - pbc=pbc, + pbc=pbc_to_tensor(state.pbc, state.device), cutoff=cutoff_tensor, - system_idx=system_idx, + system_idx=require_system_idx(state.system_idx), ) return system_mapping.bincount(minlength=state.n_systems).float().tolist() From d322cb0831421baae88e17b39d10db3af44e479e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sun, 1 Mar 2026 13:36:00 -0500 Subject: [PATCH 4/4] Add BinningAutoBatcher end-to-end test for n_edges metric --- tests/test_autobatching.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index eca982a00..a05e06089 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -254,6 +254,38 @@ def test_binning_auto_batcher( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) +def test_binning_auto_batcher_n_edges( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with n_edges memory metric.""" + states = [si_sim_state, fe_supercell_sim_state] + cutoff = 5.0 + + # Pre-compute scalers to set a meaningful max_memory_scaler + scalers = [_n_edges_scalers(s, cutoff)[0] for s in states] + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_edges", + cutoff=cutoff, + max_memory_scaler=sum(scalers) + 1, + ) + batcher.load_states(states) + + assert len(batcher.memory_scalers) == len(states) + assert all(isinstance(v, float) for v in batcher.memory_scalers) + assert batcher.memory_scalers == scalers + + batches = [batch for batch, _ in batcher] + restored_states = batcher.restore_original_order(batches) + + assert len(restored_states) == len(states) + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + def test_binning_auto_batcher_auto_metric( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState,