Merged
37 commits
b8d49fd  in progres batching size determination (orionarcher, Mar 7, 2025)
f63b8da  add autobatching logic (orionarcher, Mar 8, 2025)
8c90ac3  rename batching_utils -> autobatching (orionarcher, Mar 8, 2025)
748dfa1  fix logic for too big states and memory estimation (orionarcher, Mar 8, 2025)
7606bb8  add tests and correct chunked autobatcher (orionarcher, Mar 8, 2025)
cd0e638  lint tests (orionarcher, Mar 8, 2025)
fcc439d  add testing and make APIs more consistent (orionarcher, Mar 9, 2025)
6df12f9  small reorganization (orionarcher, Mar 9, 2025)
c32a023  Merge commit 'a5a4a9477a7e0c715094b2f379de3c0e7e811218' into batch_ma… (orionarcher, Mar 9, 2025)
29dea77  update state to work with hotswapping, split and concat logic needs r… (orionarcher, Mar 10, 2025)
78f7a25  add iterator logic to chunking and fix logic for hot swapping (orionarcher, Mar 10, 2025)
972266d  update convergence handling in runner (orionarcher, Mar 10, 2025)
aa31794  update convergence function in high level api (orionarcher, Mar 10, 2025)
7b84830  add an autobatching example script (orionarcher, Mar 10, 2025)
a04276c  add optimized pop state and split state utilities to states (orionarcher, Mar 11, 2025)
1100eb1  update hot swapping autobatcher to return full state and pop states m… (orionarcher, Mar 11, 2025)
4ddb5c6  add improved and more efficient pop_states and split states methods (orionarcher, Mar 11, 2025)
0e21fd4  finish hotswapping autobatching function (orionarcher, Mar 11, 2025)
6d8e68c  make pop_states take list of ints (orionarcher, Mar 11, 2025)
468e561  update autobatching and example (orionarcher, Mar 11, 2025)
69ffc28  add binpacking>=1.5.2 to pkg deps (janosh, Mar 11, 2025)
bac0861  lint (janosh, Mar 11, 2025)
fcc6da1  .gitignore logging and model checkpoint (janosh, Mar 11, 2025)
ebae9fb  fix testing and add return indices to chunking auto batcher (orionarcher, Mar 12, 2025)
e5a8f2d  tighten convergence on runners (orionarcher, Mar 12, 2025)
cc3f191  finish chunking tests (orionarcher, Mar 12, 2025)
3e54f4d  rename metric and associated vars (orionarcher, Mar 13, 2025)
ca84c0b  change names of utility functions (orionarcher, Mar 13, 2025)
012518d  lint (orionarcher, Mar 13, 2025)
8202ab5  fix testing (orionarcher, Mar 13, 2025)
e1f2140  final lint (orionarcher, Mar 13, 2025)
e694785  skip example if not cuda and correct case of hotswapping autobatcher (orionarcher, Mar 13, 2025)
1124eb6  clean script (orionarcher, Mar 13, 2025)
04bbeef  system exit in proper place and raise error if memory estimation is a… (orionarcher, Mar 13, 2025)
97fb53b  try enabling running on CPU (orionarcher, Mar 13, 2025)
264a660  lint (orionarcher, Mar 13, 2025)
07addcc  propagate Hotswapping->HotSwapping rename to tests (janosh, Mar 13, 2025)
7 changes: 7 additions & 0 deletions .gitignore
@@ -7,3 +7,10 @@ __pycache__
build/
dist/
*.egg-info

# logging
*.log
*log.txt

# model checkpoints
*.model
5 changes: 2 additions & 3 deletions examples/4_High_level_api/4.1_high_level_api.py
@@ -179,9 +179,8 @@
     system=systems,
     model=mace_model,
     optimizer=unit_cell_fire,
-    convergence_fn=lambda state, last_energy: torch.all(
-        last_energy - state.energy < 1e-6 * MetalUnits.energy
-    ),
+    convergence_fn=lambda state, last_energy: last_energy - state.energy
+    < 1e-6 * MetalUnits.energy,
     max_steps=10 if os.getenv("CI") else 1000,
 )

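The change above replaces the scalar torch.all(...) check with a per-batch comparison, so convergence_fn now yields one flag per batched system rather than collapsing the whole batch to a single bool, which is what the autobatchers introduced below need in order to retire systems individually. A minimal sketch of such a per-batch criterion (illustrative only, not part of the diff; it assumes state.energy and last_energy are per-batch tensors of shape (n_batches,), matching the example):

def per_batch_energy_converged(state, last_energy, tol=1e-6):
    # One boolean per batched system: True once its energy change drops below tol.
    # `tol` is a placeholder; the example uses 1e-6 * MetalUnits.energy.
    return last_energy - state.energy < tol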
137 changes: 137 additions & 0 deletions examples/4_High_level_api/4.2_auto_batching_api.py
@@ -0,0 +1,137 @@
"""Examples of using the auto-batching API."""

# /// script
# dependencies = [
# "mace-torch>=0.3.10",
# ]
# ///

"""Run as a interactive script."""
# ruff: noqa: E402


# %%
import os

import torch
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp

from torchsim.autobatching import (
    ChunkingAutoBatcher,
    HotSwappingAutoBatcher,
    calculate_memory_scaler,
    split_state,
)
from torchsim.integrators import nvt_langevin
from torchsim.models.mace import MaceModel
from torchsim.optimizers import unit_cell_fire
from torchsim.runners import atoms_to_state
from torchsim.state import BaseState
from torchsim.units import MetalUnits


if not torch.cuda.is_available():
    raise SystemExit(0)

si_atoms = bulk("Si", "fcc", a=5.43, cubic=True).repeat((3, 3, 3))
fe_atoms = bulk("Fe", "fcc", a=5.43, cubic=True).repeat((3, 3, 3))

device = torch.device("cuda")

mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(
    model=mace,
    device=device,
    periodic=True,
    dtype=torch.float64,
    compute_force=True,
)

si_state = atoms_to_state(si_atoms, device=device, dtype=torch.float64)
fe_state = atoms_to_state(fe_atoms, device=device, dtype=torch.float64)

fire_init, fire_update = unit_cell_fire(mace_model)

si_fire_state = fire_init(si_state)
fe_fire_state = fire_init(fe_state)

fire_states = [si_fire_state, fe_fire_state] * (2 if os.getenv("CI") else 20)
fire_states = [state.clone() for state in fire_states]
for state in fire_states:
    state.positions += torch.randn_like(state.positions) * 0.01

len(fire_states)


# %% run hot swapping autobatcher
def convergence_fn(state: BaseState) -> torch.Tensor:
    """Return a boolean tensor flagging which batched systems have converged."""
    batch_wise_max_force = torch.zeros(state.n_batches, device=state.device)
    max_forces = state.forces.norm(dim=1)
    # amax over each batch index keeps the largest per-atom force norm per system
    batch_wise_max_force = batch_wise_max_force.scatter_reduce(
        dim=0,
        index=state.batch,
        src=max_forces,
        reduce="amax",
    )
    return batch_wise_max_force < 1e-1


single_system_memory = calculate_memory_scaler(fire_states[0])
batcher = HotSwappingAutoBatcher(
    model=mace_model,
    states=fire_states,
    memory_scales_with="n_atoms_x_density",
    max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None,
)

all_completed_states, convergence_tensor = [], None
while True:
    # next_batch returns the updated batch plus any states swapped out after converging
    state, completed_states = batcher.next_batch(state, convergence_tensor)
    print("Number of completed states", len(completed_states))

    all_completed_states.extend(completed_states)
    if state is None:
        break
    print(f"Starting new batch of {state.n_batches} states.")

    # run 10 steps, arbitrary number
    for _step in range(10):
        state = fire_update(state)
    convergence_tensor = convergence_fn(state)


# %% run chunking autobatcher
nvt_init, nvt_update = nvt_langevin(
    model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature
)


si_state = atoms_to_state(si_atoms, device=device, dtype=torch.float64)
fe_state = atoms_to_state(fe_atoms, device=device, dtype=torch.float64)

si_nvt_state = nvt_init(si_state)
fe_nvt_state = nvt_init(fe_state)

nvt_states = [si_nvt_state, fe_nvt_state] * (2 if os.getenv("CI") else 20)
nvt_states = [state.clone() for state in nvt_states]
for state in nvt_states:
    state.positions += torch.randn_like(state.positions) * 0.01


single_system_memory = calculate_memory_scaler(nvt_states[0])
batcher = ChunkingAutoBatcher(
    model=mace_model,
    states=nvt_states,
    memory_scales_with="n_atoms_x_density",
    max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None,
)

finished_states = []
for batch in batcher:
    for _ in range(100):
        batch = nvt_update(batch)

    finished_states.extend(split_state(batch))
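Once the chunking loop finishes, finished_states holds one single-system state per input structure. A possible post-processing step (a sketch only; it assumes each split-out state exposes a one-element per-batch energy tensor, which the merged example does not rely on):

# Collect a quick summary of the finished systems; `energy` is an assumed attribute.
final_energies = [float(s.energy.squeeze()) for s in finished_states]
print(f"{len(finished_states)} systems finished; first few energies: {final_energies[:4]}")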
1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ classifiers = [
requires-python = ">=3.11"
dependencies = [
    "ase>=3.24",
    "binpacking>=1.5.2",
    "h5py>=3.12.1",
    "numpy>=1.26",
    "tables>=3.10.2",
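The new binpacking dependency supports grouping systems into batches that fit under a memory budget. A standalone sketch of how that library is commonly used (the weights and budget below are made-up numbers, and this is not necessarily how the autobatchers call it internally):

import binpacking

# Map system index -> memory scaler (made-up values), then pack indices into bins
# whose summed weight stays under the budget; oversized items get their own bin.
scalers = {0: 120.0, 1: 80.0, 2: 200.0, 3: 60.0}
batches = binpacking.to_constant_volume(scalers, 250.0)
print(batches)  # a list of dicts, one dict per prospective batch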