Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4a4d43b
Implement lbfgs
abhijeetgangan Nov 26, 2025
793de83
fix example
abhijeetgangan Nov 26, 2025
b4c692f
Implement bfgs
abhijeetgangan Nov 26, 2025
49579d3
Merge branch 'main' into ag/opt_lbfgs_bfgs
abhijeetgangan Dec 3, 2025
a2e9b0d
Merge branch 'main' into ag/opt_lbfgs_bfgs
abhijeetgangan Dec 16, 2025
4e85aa0
Merge branch 'main' into ag/opt_lbfgs_bfgs
abhijeetgangan Dec 30, 2025
57458ad
Merge branch 'main' into ag/opt_lbfgs_bfgs
CompRhys Jan 8, 2026
b41a8ef
clean scripts per #385
CompRhys Jan 9, 2026
a5688c5
Merge branch 'main' into ag/opt_lbfgs_bfgs
abhijeetgangan Jan 20, 2026
a0ebb47
Merge branch 'main' into ag/opt_lbfgs_bfgs
abhijeetgangan Feb 3, 2026
cea301d
correct unit cell filter and add CellLBFGS and CellBFGS states
abhijeetgangan Feb 4, 2026
9478d85
global attr for LBFGS history
abhijeetgangan Feb 4, 2026
9c20599
Add methods to init
abhijeetgangan Feb 4, 2026
a9ede47
Add test comparing with ASE
abhijeetgangan Feb 4, 2026
802c966
Update batched bfgs
abhijeetgangan Feb 4, 2026
8983439
charge and spin
abhijeetgangan Feb 4, 2026
f51a0db
Update tests
abhijeetgangan Feb 4, 2026
2da4c44
BFGS with constraints
abhijeetgangan Feb 4, 2026
bd360b4
Add batched lbfgs
abhijeetgangan Feb 4, 2026
0cd42b3
bump max history and fix shape issue
abhijeetgangan Feb 4, 2026
69ce9f4
add optimizer test and remove redundant autobatcher test
abhijeetgangan Feb 4, 2026
790236f
why exact check for floating point?
abhijeetgangan Feb 4, 2026
7c20da1
keep example script same
abhijeetgangan Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions examples/scripts/2_structural_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,72 @@
print(f"Initial pressure: {initial_pressure} GPa")
print(f"Final pressure: {final_pressure} GPa")

# ============================================================================
# SECTION 7: Batched MACE L-BFGS
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 7: Batched MACE L-BFGS")
print("=" * 70)

# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)

cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)

fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)

atoms_list = [si_dc, cu_dc, fe_dc]

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
results = model(state)
state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0)

print("\nRunning L-BFGS:")
for step in range(N_steps):
if step % 20 == 0:
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
state = ts.lbfgs_step(state=state, model=model, max_history=100)

print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")


# ============================================================================
# SECTION 8: Batched MACE BFGS
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 8: Batched MACE BFGS")
print("=" * 70)

# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)

cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)

fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)

atoms_list = [si_dc, cu_dc, fe_dc]

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
results = model(state)
state = ts.bfgs_init(state=state, model=model, alpha=70.0)

print("\nRunning BFGS:")
for step in range(N_steps):
if step % 20 == 0:
print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
state = ts.bfgs_step(state=state, model=model)

print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")


print("\n" + "=" * 70)
print("Structural optimization examples completed!")
print("=" * 70)
176 changes: 176 additions & 0 deletions tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,3 +605,179 @@ def test_in_flight_max_iterations(
# Verify iteration_count tracking
for idx in range(len(states)):
assert batcher.iteration_count[idx] == max_iterations


@pytest.mark.parametrize(
"num_steps_per_batch",
[
5, # At 5 steps, not every state will converge before the next batch.
10, # At 10 steps, all states will converge before the next batch
],
)
def test_in_flight_with_bfgs(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
num_steps_per_batch: int,
) -> None:
"""Test InFlightAutoBatcher with BFGS optimizer."""
si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
fe_bfgs_state = ts.bfgs_init(
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
)

bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5
bfgs_states = [state.clone() for state in bfgs_states]
for state in bfgs_states:
state.positions += torch.randn_like(state.positions) * 0.01

batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=6000,
)
batcher.load_states(bfgs_states)

def convergence_fn(state: ts.BFGSState) -> torch.Tensor:
system_wise_max_force = torch.zeros(
state.n_systems, device=state.device, dtype=torch.float64
)
max_forces = state.forces.norm(dim=1)
system_wise_max_force = system_wise_max_force.scatter_reduce(
dim=0, index=state.system_idx, src=max_forces, reduce="amax"
)
return system_wise_max_force < 5e-1

all_completed_states, convergence_tensor = [], None
while True:
state, completed_states = batcher.next_batch(state, convergence_tensor)

all_completed_states.extend(completed_states)
if state is None:
break

for _ in range(num_steps_per_batch):
state = ts.bfgs_step(state=state, model=lj_model)
convergence_tensor = convergence_fn(state)

assert len(all_completed_states) == len(bfgs_states)


def test_binning_auto_batcher_with_bfgs(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
"""Test BinningAutoBatcher with BFGS optimizer."""
si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
fe_bfgs_state = ts.bfgs_init(
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
)

bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5
bfgs_states = [state.clone() for state in bfgs_states]
for state in bfgs_states:
state.positions += torch.randn_like(state.positions) * 0.01

batcher = BinningAutoBatcher(
model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
)
batcher.load_states(bfgs_states)

all_finished_states: list[ts.SimState] = []
total_batches = 0
for batch, _ in batcher:
total_batches += 1 # noqa: SIM113
for _ in range(5):
batch = ts.bfgs_step(state=batch, model=lj_model)
all_finished_states.extend(batch.split())

assert len(all_finished_states) == len(bfgs_states)


@pytest.mark.parametrize(
"num_steps_per_batch",
[
5, # At 5 steps, not every state will converge before the next batch.
10, # At 10 steps, all states will converge before the next batch
],
)
def test_in_flight_with_lbfgs(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
num_steps_per_batch: int,
) -> None:
"""Test InFlightAutoBatcher with L-BFGS optimizer."""
si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
fe_lbfgs_state = ts.lbfgs_init(
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
)

lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5
lbfgs_states = [state.clone() for state in lbfgs_states]
for state in lbfgs_states:
state.positions += torch.randn_like(state.positions) * 0.01

batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=6000,
)
batcher.load_states(lbfgs_states)

def convergence_fn(state: ts.LBFGSState) -> torch.Tensor:
system_wise_max_force = torch.zeros(
state.n_systems, device=state.device, dtype=torch.float64
)
max_forces = state.forces.norm(dim=1)
system_wise_max_force = system_wise_max_force.scatter_reduce(
dim=0, index=state.system_idx, src=max_forces, reduce="amax"
)
return system_wise_max_force < 5e-1

all_completed_states, convergence_tensor = [], None
while True:
state, completed_states = batcher.next_batch(state, convergence_tensor)

all_completed_states.extend(completed_states)
if state is None:
break

for _ in range(num_steps_per_batch):
state = ts.lbfgs_step(state=state, model=lj_model)
convergence_tensor = convergence_fn(state)

assert len(all_completed_states) == len(lbfgs_states)


def test_binning_auto_batcher_with_lbfgs(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
"""Test BinningAutoBatcher with L-BFGS optimizer."""
si_lbfgs_state = ts.lbfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
fe_lbfgs_state = ts.lbfgs_init(
fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
)

lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5
lbfgs_states = [state.clone() for state in lbfgs_states]
for state in lbfgs_states:
state.positions += torch.randn_like(state.positions) * 0.01

batcher = BinningAutoBatcher(
model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
)
batcher.load_states(lbfgs_states)

all_finished_states: list[ts.SimState] = []
total_batches = 0
for batch, _ in batcher:
total_batches += 1 # noqa: SIM113
for _ in range(5):
batch = ts.lbfgs_step(state=batch, model=lj_model)
all_finished_states.extend(batch.split())

assert len(all_finished_states) == len(lbfgs_states)
Loading