Add an "n_edges" memory-scaling metric to the torch_sim autobatching module.

Summary of what this patch changes:
  * torch_sim/typing.py: the MemoryScaling literal gains the value "n_edges".
  * torch_sim/autobatching.py: new private helper _n_edges_scalers(state, cutoff)
    builds a neighbor list with torchsim_nl and returns one edge count per
    system (bincount over the system mapping, converted to float).
  * calculate_memory_scalers() gains a trailing `cutoff` parameter
    (default 7.0) that is only consulted when memory_scales_with="n_edges".
  * BinningAutoBatcher.__init__ and a second batcher __init__ (the one whose
    class defines _get_next_states; presumably InFlightAutoBatcher, confirm
    against the full file) accept a keyword-only `cutoff`, store it as
    self.cutoff, and forward it to calculate_memory_scalers().
  * tests/test_autobatching.py: new tests for periodic, non-periodic and
    batched states, plus a BinningAutoBatcher run using the new metric.

NOTE(review): line breaks and indentation below were reconstructed so that
every hunk matches the line counts in its "@@" header. The exact indentation
of context lines could not be recovered from the collapsed text and should be
confirmed against the target files before applying this patch.

diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index 09e648a3f..a05e06089 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -7,6 +7,7 @@
 from torch_sim.autobatching import (
     BinningAutoBatcher,
     InFlightAutoBatcher,
+    _n_edges_scalers,
     calculate_memory_scalers,
     determine_max_batch_size,
     to_constant_volume_bins,
@@ -156,6 +157,33 @@ def test_calculate_scaling_metric_mixed_pbc_uses_per_system_path(
     assert metric_values == pytest.approx(expected_values, rel=1e-5)
 
 
+def test_n_edges_scalers_periodic(si_sim_state: ts.SimState) -> None:
+    """n_edges scalers for a single periodic system have correct shape and type."""
+    result = _n_edges_scalers(si_sim_state, cutoff=5.0)
+    assert isinstance(result, list)
+    assert len(result) == si_sim_state.n_systems
+    assert all(isinstance(v, float) for v in result)
+    assert all(v >= 0 for v in result)
+
+
+def test_n_edges_scalers_non_periodic(benzene_sim_state: ts.SimState) -> None:
+    """n_edges scalers for a non-periodic (molecular) system have correct shape/type."""
+    result = _n_edges_scalers(benzene_sim_state, cutoff=5.0)
+    assert isinstance(result, list)
+    assert len(result) == benzene_sim_state.n_systems
+    assert all(isinstance(v, float) for v in result)
+    assert all(v >= 0 for v in result)
+
+
+def test_n_edges_scalers_batched(ar_double_sim_state: ts.SimState) -> None:
+    """n_edges scalers for a batched state return one value per system."""
+    result = _n_edges_scalers(ar_double_sim_state, cutoff=5.0)
+    assert isinstance(result, list)
+    assert len(result) == ar_double_sim_state.n_systems
+    assert all(isinstance(v, float) for v in result)
+    assert all(v >= 0 for v in result)
+
+
 @pytest.mark.parametrize("items", [[], {}])
 def test_to_constant_volume_bins_empty_input(
     items: list[Any] | dict[int, float],
@@ -226,6 +254,38 @@ def test_binning_auto_batcher(
     assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)
 
 
+def test_binning_auto_batcher_n_edges(
+    si_sim_state: ts.SimState,
+    fe_supercell_sim_state: ts.SimState,
+    lj_model: LennardJonesModel,
+) -> None:
+    """Test BinningAutoBatcher with n_edges memory metric."""
+    states = [si_sim_state, fe_supercell_sim_state]
+    cutoff = 5.0
+
+    # Pre-compute scalers to set a meaningful max_memory_scaler
+    scalers = [_n_edges_scalers(s, cutoff)[0] for s in states]
+
+    batcher = BinningAutoBatcher(
+        model=lj_model,
+        memory_scales_with="n_edges",
+        cutoff=cutoff,
+        max_memory_scaler=sum(scalers) + 1,
+    )
+    batcher.load_states(states)
+
+    assert len(batcher.memory_scalers) == len(states)
+    assert all(isinstance(v, float) for v in batcher.memory_scalers)
+    assert batcher.memory_scalers == scalers
+
+    batches = [batch for batch, _ in batcher]
+    restored_states = batcher.restore_original_order(batches)
+
+    assert len(restored_states) == len(states)
+    assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers)
+    assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)
+
+
 def test_binning_auto_batcher_auto_metric(
     si_sim_state: ts.SimState,
     fe_supercell_sim_state: ts.SimState,
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index 36902eb7f..b77947f38 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -28,7 +28,8 @@
 
 import torch_sim as ts
 from torch_sim.models.interface import ModelInterface
-from torch_sim.state import SimState
+from torch_sim.neighbors import torchsim_nl
+from torch_sim.state import SimState, pbc_to_tensor, require_system_idx
 from torch_sim.typing import MemoryScaling
 
 
@@ -324,15 +325,31 @@ def determine_max_batch_size(
     return sizes[-1]
 
 
+def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]:
+    """Return per-system edge counts from the neighbor list as memory scalers."""
+    cutoff_tensor = torch.tensor(cutoff, dtype=state.dtype, device=state.device)
+    _, system_mapping, _ = torchsim_nl(
+        positions=state.positions,
+        cell=state.cell,
+        pbc=pbc_to_tensor(state.pbc, state.device),
+        cutoff=cutoff_tensor,
+        system_idx=require_system_idx(state.system_idx),
+    )
+    return system_mapping.bincount(minlength=state.n_systems).float().tolist()
+
+
 def calculate_memory_scalers(
     state: SimState,
     memory_scales_with: MemoryScaling = "n_atoms_x_density",
+    cutoff: float = 7.0,
 ) -> list[float]:
     """Calculate a metric that estimates memory requirements for each system in a state.
 
     Provides different scaling metrics that correlate with memory usage.
     Models with radial neighbor cutoffs generally scale with "n_atoms_x_density",
     while models with a fixed number of neighbors scale with "n_atoms".
+    For molecular systems, "n_edges" gives the most accurate estimate by computing
+    the actual neighbor list edge count using the provided cutoff.
     The choice of metric can significantly impact the accuracy of memory requirement
     estimations for different types of simulation systems.
 
@@ -342,11 +359,16 @@ def calculate_memory_scalers(
     Args:
         state (SimState): State to calculate metric for, with shape information
             specific to the SimState instance.
-        memory_scales_with ("n_atoms_x_density" | "n_atoms"): Type of metric
-            to use. "n_atoms" uses only atom count and is suitable for models that
-            have a fixed number of neighbors. "n_atoms_x_density" uses atom count
-            multiplied by number density and is better for models with radial cutoffs
-            Defaults to "n_atoms_x_density".
+        memory_scales_with ("n_atoms_x_density" | "n_atoms" | "n_edges"): Type of
+            metric to use. "n_atoms" uses only atom count and is suitable for models
+            that have a fixed number of neighbors. "n_atoms_x_density" uses atom count
+            multiplied by number density and is better for models with radial cutoffs.
+            "n_edges" computes the actual neighbor list edge count, which is the most
+            accurate metric overall but more expensive to compute than the alternatives;
+            strongly recommended for molecular systems. Defaults to "n_atoms_x_density".
+        cutoff (float): Neighbor list cutoff distance in Angstroms. Only used when
+            memory_scales_with="n_edges". Should match the model's cutoff for best
+            accuracy. Defaults to 7.0.
 
     Returns:
         list[float]: Calculated metric value for each system.
@@ -361,6 +383,11 @@ def calculate_memory_scalers(
 
         # Calculate memory scaling factor based on atom count and density
         metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms_x_density")
+
+        # Calculate memory scaling factor based on actual neighbor list edge count
+        metrics = calculate_memory_scalers(
+            state, memory_scales_with="n_edges", cutoff=5.0
+        )
     """
     if memory_scales_with == "n_atoms":
         return state.n_atoms_per_system.tolist()
@@ -410,6 +437,8 @@ def calculate_memory_scalers(
                 volume = bbox.prod() / 1000
             scalers.append(system_state.n_atoms * (system_state.n_atoms / volume.item()))
         return scalers
+    if memory_scales_with == "n_edges":
+        return _n_edges_scalers(state, cutoff)
     raise ValueError(
         f"Invalid metric: {memory_scales_with}, must be one of {get_args(MemoryScaling)}"
     )
@@ -522,6 +551,7 @@ def __init__(
         model: ModelInterface,
         *,
         memory_scales_with: MemoryScaling = "n_atoms_x_density",
+        cutoff: float = 7.0,
         max_memory_scaler: float | None = None,
         max_atoms_to_try: int = 500_000,
         memory_scaling_factor: float = 1.6,
@@ -533,11 +563,16 @@ def __init__(
         Args:
             model (ModelInterface): Model to batch for, used to estimate memory
                 requirements.
-            memory_scales_with ("n_atoms" | "n_atoms_x_density"): Metric to use
-                for estimating memory requirements:
+            memory_scales_with ("n_atoms" | "n_atoms_x_density" | "n_edges"): Metric to
+                use for estimating memory requirements:
                 - "n_atoms": Uses only atom count
                 - "n_atoms_x_density": Uses atom count multiplied by number density
+                - "n_edges": Uses actual neighbor list edge count; most accurate overall
+                  but more expensive; strongly recommended for molecular systems
                 Defaults to "n_atoms_x_density".
+            cutoff (float): Neighbor list cutoff in Angstroms. Only used when
+                memory_scales_with="n_edges". Should match the model's cutoff.
+                Defaults to 7.0.
             max_memory_scaler (float | None): Maximum metric value allowed per system. If
                 None, will be automatically estimated. Defaults to None.
             max_atoms_to_try (int): Maximum number of atoms to try when estimating
@@ -555,6 +590,7 @@ def __init__(
         self.max_memory_scaler = max_memory_scaler
         self.max_atoms_to_try = max_atoms_to_try
         self.memory_scales_with = memory_scales_with
+        self.cutoff = cutoff
         self.model = model
         self.memory_scaling_factor = memory_scaling_factor
         self.max_memory_padding = max_memory_padding
@@ -595,7 +631,9 @@ def load_states(self, states: T | Sequence[T]) -> float:
         batched = (
             states if isinstance(states, SimState) else ts.concatenate_states(states)
         )
-        self.memory_scalers = calculate_memory_scalers(batched, self.memory_scales_with)
+        self.memory_scalers = calculate_memory_scalers(
+            batched, self.memory_scales_with, self.cutoff
+        )
         if not self.max_memory_scaler:
             self.max_memory_scaler = estimate_max_memory_scaler(
                 batched,
@@ -799,6 +837,7 @@ def __init__(
         model: ModelInterface,
         *,
         memory_scales_with: MemoryScaling = "n_atoms_x_density",
+        cutoff: float = 7.0,
         max_memory_scaler: float | None = None,
         max_atoms_to_try: int = 500_000,
         memory_scaling_factor: float = 1.6,
@@ -811,11 +850,16 @@ def __init__(
         Args:
             model (ModelInterface): Model to batch for, used to estimate memory
                 requirements.
-            memory_scales_with ("n_atoms" | "n_atoms_x_density"): Metric to use
-                for estimating memory requirements:
+            memory_scales_with ("n_atoms" | "n_atoms_x_density" | "n_edges"): Metric to
+                use for estimating memory requirements:
                 - "n_atoms": Uses only atom count
                 - "n_atoms_x_density": Uses atom count multiplied by number density
+                - "n_edges": Uses actual neighbor list edge count; most accurate overall
+                  but more expensive; strongly recommended for molecular systems
                 Defaults to "n_atoms_x_density".
+            cutoff (float): Neighbor list cutoff in Angstroms. Only used when
+                memory_scales_with="n_edges". Should match the model's cutoff.
+                Defaults to 7.0.
             max_memory_scaler (float | None): Maximum metric value allowed per system. If
                 None, will be automatically estimated. Defaults to None.
             max_atoms_to_try (int): Maximum number of atoms to try when estimating
@@ -835,6 +879,7 @@ def __init__(
         """
         self.model = model
         self.memory_scales_with = memory_scales_with
+        self.cutoff = cutoff
         self.max_memory_scaler = max_memory_scaler or None
         self.max_atoms_to_try = max_atoms_to_try
         self.memory_scaling_factor = memory_scaling_factor
@@ -903,7 +948,9 @@ def _get_next_states(self) -> list[T]:
         new_idx: list[int] = []
         new_states: list[T] = []
         for state in self.states_iterator:
-            metric = calculate_memory_scalers(state, self.memory_scales_with)[0]
+            metric = calculate_memory_scalers(
+                state, self.memory_scales_with, self.cutoff
+            )[0]
             if metric > self.max_memory_scaler:  # ty: ignore[unsupported-operator]
                 raise ValueError(
                     f"State {metric=} is greater than max_metric {self.max_memory_scaler}"
diff --git a/torch_sim/typing.py b/torch_sim/typing.py
index 09b015361..11f84e738 100644
--- a/torch_sim/typing.py
+++ b/torch_sim/typing.py
@@ -14,7 +14,7 @@
     from torch_sim.state import SimState
 
 
-MemoryScaling = Literal["n_atoms_x_density", "n_atoms"]
+MemoryScaling = Literal["n_atoms_x_density", "n_atoms", "n_edges"]
 StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "system_idx"]
 StateDict = dict[StateKey, torch.Tensor]
 