diff --git a/examples/scripts/1_Introduction/1.1_Lennard_Jones.py b/examples/scripts/1_Introduction/1.1_Lennard_Jones.py index 91d8d26f0..860834c63 100644 --- a/examples/scripts/1_Introduction/1.1_Lennard_Jones.py +++ b/examples/scripts/1_Introduction/1.1_Lennard_Jones.py @@ -10,6 +10,7 @@ import torch +from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel @@ -69,6 +70,8 @@ dtype=dtype, compute_forces=True, compute_stress=True, + per_atom_energies=True, + per_atom_stresses=True, ) # Print system information @@ -88,3 +91,37 @@ print(f"Energy: {results['energy']}") print(f"Forces: {results['forces']}") print(f"Stress: {results['stress']}") +print(f"Energies: {results['energies']}") +print(f"Stresses: {results['stresses']}") + +# Batched model +batched_model = LennardJonesModel( + use_neighbor_list=True, + cutoff=2.5 * 3.405, + sigma=3.405, + epsilon=0.0104, + device=device, + dtype=dtype, + compute_forces=True, + compute_stress=True, + per_atom_energies=True, + per_atom_stresses=True, +) + +# Batched state +state = dict( + positions=positions, + cell=cell.unsqueeze(0), + atomic_numbers=atomic_numbers, + pbc=True, +) + +# Run the simulation and get results +results = batched_model(state) + +# Print the results +print(f"Energy: {results['energy']}") +print(f"Forces: {results['forces']}") +print(f"Stress: {results['stress']}") +print(f"Energies: {results['energies']}") +print(f"Stresses: {results['stresses']}") diff --git a/tests/models/test_lennard_jones.py b/tests/models/test_lennard_jones.py index 2d6ec15c8..3a7fa1f59 100644 --- a/tests/models/test_lennard_jones.py +++ b/tests/models/test_lennard_jones.py @@ -155,6 +155,8 @@ def models( "dtype": torch.float64, "compute_forces": True, "compute_stress": True, + "per_atom_energies": True, + "per_atom_stresses": True, } cutoff = 2.5 * 3.405 # Standard LJ cutoff * sigma @@ -178,6 +180,14 @@ def test_energy_match( assert torch.allclose(results_nl["energy"], results_direct["energy"], rtol=1e-10) +def test_per_atom_energy_match( + models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], +) -> None: + """Test that per-atom energy matches between neighbor list and direct calculations.""" + results_nl, results_direct = models + assert torch.allclose(results_nl["energies"], results_direct["energies"], rtol=1e-10) + + def test_forces_match( models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], ) -> None: @@ -194,6 +204,15 @@ def test_stress_match( assert torch.allclose(results_nl["stress"], results_direct["stress"], rtol=1e-10) +def test_per_atom_stress_match( + models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], +) -> None: + """Test that per-atom stress tensors match between neighbor list + and direct calculations.""" + results_nl, results_direct = models + assert torch.allclose(results_nl["stresses"], results_direct["stresses"], rtol=1e-10) + + def test_force_conservation( models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], ) -> None: diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 85f691c11..4910bea47 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -293,7 +293,10 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: compute_forces=True) - "stress": Stress tensor with shape [n_batches, 3, 3] (if compute_stress=True) - - May include additional outputs based on configuration + - "energies": Per-atom energies with shape [n_atoms] (if + per_atom_energies=True) + - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if + per_atom_stresses=True) Raises: ValueError: If batch cannot be inferred for multi-cell systems. @@ -307,6 +310,8 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: energy = results["energy"] # Shape: [n_batches] forces = results["forces"] # Shape: [n_atoms, 3] stress = results["stress"] # Shape: [n_batches, 3, 3] + energies = results["energies"] # Shape: [n_atoms] + stresses = results["stresses"] # Shape: [n_atoms, 3, 3] """ if isinstance(state, dict): state = SimState(**state, masses=torch.ones_like(state["positions"])) @@ -324,7 +329,7 @@ def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs]) - for key in ("forces",): + for key in ("forces", "energies", "stresses"): if key in properties: results[key] = torch.cat([out[key] for out in outputs], dim=0)