Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_sim.autobatching import (
BinningAutoBatcher,
InFlightAutoBatcher,
_n_edges_scalers,
calculate_memory_scalers,
determine_max_batch_size,
to_constant_volume_bins,
Expand Down Expand Up @@ -156,6 +157,33 @@ def test_calculate_scaling_metric_mixed_pbc_uses_per_system_path(
assert metric_values == pytest.approx(expected_values, rel=1e-5)


def test_n_edges_scalers_periodic(si_sim_state: ts.SimState) -> None:
    """Edge-count scalers for one periodic system: a list of non-negative floats."""
    scalers = _n_edges_scalers(si_sim_state, cutoff=5.0)
    assert isinstance(scalers, list)
    assert len(scalers) == si_sim_state.n_systems
    for value in scalers:
        assert isinstance(value, float)
        assert value >= 0


def test_n_edges_scalers_non_periodic(benzene_sim_state: ts.SimState) -> None:
    """Edge-count scalers for a non-periodic molecule: per-system non-negative floats."""
    scalers = _n_edges_scalers(benzene_sim_state, cutoff=5.0)
    assert isinstance(scalers, list)
    assert len(scalers) == benzene_sim_state.n_systems
    for value in scalers:
        assert isinstance(value, float)
        assert value >= 0


def test_n_edges_scalers_batched(ar_double_sim_state: ts.SimState) -> None:
    """Edge-count scalers for a batched state: exactly one float per system."""
    scalers = _n_edges_scalers(ar_double_sim_state, cutoff=5.0)
    assert isinstance(scalers, list)
    assert len(scalers) == ar_double_sim_state.n_systems
    for value in scalers:
        assert isinstance(value, float)
        assert value >= 0


@pytest.mark.parametrize("items", [[], {}])
def test_to_constant_volume_bins_empty_input(
items: list[Any] | dict[int, float],
Expand Down Expand Up @@ -226,6 +254,38 @@ def test_binning_auto_batcher(
assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)


def test_binning_auto_batcher_n_edges(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """BinningAutoBatcher driven by the "n_edges" memory metric round-trips states."""
    input_states = [si_sim_state, fe_supercell_sim_state]
    cutoff = 5.0

    # Compute the per-state edge counts up front so max_memory_scaler is
    # guaranteed large enough to admit every state.
    expected_scalers = [_n_edges_scalers(state, cutoff)[0] for state in input_states]

    batcher = BinningAutoBatcher(
        model=lj_model,
        memory_scales_with="n_edges",
        cutoff=cutoff,
        max_memory_scaler=sum(expected_scalers) + 1,
    )
    batcher.load_states(input_states)

    # Loading must produce one float scaler per state, matching the
    # values computed directly above.
    assert len(batcher.memory_scalers) == len(input_states)
    for scaler in batcher.memory_scalers:
        assert isinstance(scaler, float)
    assert batcher.memory_scalers == expected_scalers

    collected = [batch for batch, _ in batcher]
    restored = batcher.restore_original_order(collected)

    # Restoration must return the states in their original order.
    assert len(restored) == len(input_states)
    for restored_state, original_state in zip(restored, input_states):
        assert torch.all(restored_state.atomic_numbers == original_state.atomic_numbers)


def test_binning_auto_batcher_auto_metric(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
Expand Down
71 changes: 59 additions & 12 deletions torch_sim/autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState
from torch_sim.neighbors import torchsim_nl
from torch_sim.state import SimState, pbc_to_tensor, require_system_idx
from torch_sim.typing import MemoryScaling


Expand Down Expand Up @@ -324,15 +325,31 @@ def determine_max_batch_size(
return sizes[-1]


def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]:
    """Compute one memory scaler per system: its neighbor-list edge count.

    Args:
        state (SimState): State (possibly batched) to build the neighbor list for.
        cutoff (float): Neighbor-list cutoff distance used to build the list.

    Returns:
        list[float]: Edge count for each system in ``state``, in system order.
    """
    cutoff_t = torch.tensor(cutoff, dtype=state.dtype, device=state.device)
    # Only the per-edge system mapping is needed; edge indices and shifts
    # returned by the neighbor list are discarded.
    _, edge_system_idx, _ = torchsim_nl(
        positions=state.positions,
        cell=state.cell,
        pbc=pbc_to_tensor(state.pbc, state.device),
        cutoff=cutoff_t,
        system_idx=require_system_idx(state.system_idx),
    )
    # minlength guarantees a count (possibly 0) for every system, even
    # those that produced no edges.
    counts = edge_system_idx.bincount(minlength=state.n_systems)
    return counts.float().tolist()


def calculate_memory_scalers(
state: SimState,
memory_scales_with: MemoryScaling = "n_atoms_x_density",
cutoff: float = 7.0,
) -> list[float]:
"""Calculate a metric that estimates memory requirements for each system in a state.

Provides different scaling metrics that correlate with memory usage.
Models with radial neighbor cutoffs generally scale with "n_atoms_x_density",
while models with a fixed number of neighbors scale with "n_atoms".
For molecular systems, "n_edges" gives the most accurate estimate by computing
the actual neighbor list edge count using the provided cutoff.
The choice of metric can significantly impact the accuracy of memory requirement
estimations for different types of simulation systems.

Expand All @@ -342,11 +359,16 @@ def calculate_memory_scalers(
Args:
state (SimState): State to calculate metric for, with shape information
specific to the SimState instance.
memory_scales_with ("n_atoms_x_density" | "n_atoms"): Type of metric
to use. "n_atoms" uses only atom count and is suitable for models that
have a fixed number of neighbors. "n_atoms_x_density" uses atom count
multiplied by number density and is better for models with radial cutoffs
Defaults to "n_atoms_x_density".
memory_scales_with ("n_atoms_x_density" | "n_atoms" | "n_edges"): Type of
metric to use. "n_atoms" uses only atom count and is suitable for models
that have a fixed number of neighbors. "n_atoms_x_density" uses atom count
multiplied by number density and is better for models with radial cutoffs.
"n_edges" computes the actual neighbor list edge count, which is the most
accurate metric overall but more expensive to compute than the alternatives;
strongly recommended for molecular systems. Defaults to "n_atoms_x_density".
cutoff (float): Neighbor list cutoff distance in Angstroms. Only used when
memory_scales_with="n_edges". Should match the model's cutoff for best
accuracy. Defaults to 7.0.

Returns:
list[float]: Calculated metric value for each system.
Expand All @@ -361,6 +383,11 @@ def calculate_memory_scalers(

# Calculate memory scaling factor based on atom count and density
metrics = calculate_memory_scalers(state, memory_scales_with="n_atoms_x_density")

# Calculate memory scaling factor based on actual neighbor list edge count
metrics = calculate_memory_scalers(
state, memory_scales_with="n_edges", cutoff=5.0
)
"""
if memory_scales_with == "n_atoms":
return state.n_atoms_per_system.tolist()
Expand Down Expand Up @@ -410,6 +437,8 @@ def calculate_memory_scalers(
volume = bbox.prod() / 1000
scalers.append(system_state.n_atoms * (system_state.n_atoms / volume.item()))
return scalers
if memory_scales_with == "n_edges":
return _n_edges_scalers(state, cutoff)
raise ValueError(
f"Invalid metric: {memory_scales_with}, must be one of {get_args(MemoryScaling)}"
)
Expand Down Expand Up @@ -522,6 +551,7 @@ def __init__(
model: ModelInterface,
*,
memory_scales_with: MemoryScaling = "n_atoms_x_density",
cutoff: float = 7.0,
max_memory_scaler: float | None = None,
max_atoms_to_try: int = 500_000,
memory_scaling_factor: float = 1.6,
Expand All @@ -533,11 +563,16 @@ def __init__(
Args:
model (ModelInterface): Model to batch for, used to estimate memory
requirements.
memory_scales_with ("n_atoms" | "n_atoms_x_density"): Metric to use
for estimating memory requirements:
memory_scales_with ("n_atoms" | "n_atoms_x_density" | "n_edges"): Metric to
use for estimating memory requirements:
- "n_atoms": Uses only atom count
- "n_atoms_x_density": Uses atom count multiplied by number density
- "n_edges": Uses actual neighbor list edge count; most accurate overall
but more expensive; strongly recommended for molecular systems
Defaults to "n_atoms_x_density".
cutoff (float): Neighbor list cutoff in Angstroms. Only used when
memory_scales_with="n_edges". Should match the model's cutoff.
Defaults to 7.0.
max_memory_scaler (float | None): Maximum metric value allowed per system. If
None, will be automatically estimated. Defaults to None.
max_atoms_to_try (int): Maximum number of atoms to try when estimating
Expand All @@ -555,6 +590,7 @@ def __init__(
self.max_memory_scaler = max_memory_scaler
self.max_atoms_to_try = max_atoms_to_try
self.memory_scales_with = memory_scales_with
self.cutoff = cutoff
self.model = model
self.memory_scaling_factor = memory_scaling_factor
self.max_memory_padding = max_memory_padding
Expand Down Expand Up @@ -595,7 +631,9 @@ def load_states(self, states: T | Sequence[T]) -> float:
batched = (
states if isinstance(states, SimState) else ts.concatenate_states(states)
)
self.memory_scalers = calculate_memory_scalers(batched, self.memory_scales_with)
self.memory_scalers = calculate_memory_scalers(
batched, self.memory_scales_with, self.cutoff
)
if not self.max_memory_scaler:
self.max_memory_scaler = estimate_max_memory_scaler(
batched,
Expand Down Expand Up @@ -799,6 +837,7 @@ def __init__(
model: ModelInterface,
*,
memory_scales_with: MemoryScaling = "n_atoms_x_density",
cutoff: float = 7.0,
max_memory_scaler: float | None = None,
max_atoms_to_try: int = 500_000,
memory_scaling_factor: float = 1.6,
Expand All @@ -811,11 +850,16 @@ def __init__(
Args:
model (ModelInterface): Model to batch for, used to estimate memory
requirements.
memory_scales_with ("n_atoms" | "n_atoms_x_density"): Metric to use
for estimating memory requirements:
memory_scales_with ("n_atoms" | "n_atoms_x_density" | "n_edges"): Metric to
use for estimating memory requirements:
- "n_atoms": Uses only atom count
- "n_atoms_x_density": Uses atom count multiplied by number density
- "n_edges": Uses actual neighbor list edge count; most accurate overall
but more expensive; strongly recommended for molecular systems
Defaults to "n_atoms_x_density".
cutoff (float): Neighbor list cutoff in Angstroms. Only used when
memory_scales_with="n_edges". Should match the model's cutoff.
Defaults to 7.0.
max_memory_scaler (float | None): Maximum metric value allowed per system.
If None, will be automatically estimated. Defaults to None.
max_atoms_to_try (int): Maximum number of atoms to try when estimating
Expand All @@ -835,6 +879,7 @@ def __init__(
"""
self.model = model
self.memory_scales_with = memory_scales_with
self.cutoff = cutoff
self.max_memory_scaler = max_memory_scaler or None
self.max_atoms_to_try = max_atoms_to_try
self.memory_scaling_factor = memory_scaling_factor
Expand Down Expand Up @@ -903,7 +948,9 @@ def _get_next_states(self) -> list[T]:
new_idx: list[int] = []
new_states: list[T] = []
for state in self.states_iterator:
metric = calculate_memory_scalers(state, self.memory_scales_with)[0]
metric = calculate_memory_scalers(
state, self.memory_scales_with, self.cutoff
)[0]
if metric > self.max_memory_scaler: # ty: ignore[unsupported-operator]
raise ValueError(
f"State {metric=} is greater than max_metric {self.max_memory_scaler}"
Expand Down
2 changes: 1 addition & 1 deletion torch_sim/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch_sim.state import SimState


MemoryScaling = Literal["n_atoms_x_density", "n_atoms"]
# Names of the metrics used to estimate per-system memory cost during
# autobatching; "n_edges" counts actual neighbor-list edges.
MemoryScaling = Literal["n_atoms_x_density", "n_atoms", "n_edges"]
StateKey = Literal["positions", "masses", "cell", "pbc", "atomic_numbers", "system_idx"]
StateDict = dict[StateKey, torch.Tensor]

Expand Down
Loading