From 42a48d650f2d6e86e965627c55d2ca38835eef91 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 2 May 2026 09:36:12 -0400 Subject: [PATCH 1/3] fix: traj atoms cells need to be row vector --- tests/test_trajectory.py | 8 +++++--- torch_sim/trajectory.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index f5da1674d..474d62d9d 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -31,7 +31,7 @@ def random_state() -> MDState: energy=torch.tensor(1.0), forces=torch.randn(10, 3), masses=torch.ones(10), - cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), + cell=torch.tensor([[[4.0, -4.0, 0.0], [4.0, 4.0, 0.0], [0.0, 0.0, 8.0]]]), atomic_numbers=torch.ones(10, dtype=torch.int32), system_idx=torch.zeros(10, dtype=torch.int32), pbc=[True, True, False], @@ -453,7 +453,7 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non # Test basic properties assert len(atoms) == len(random_state.atomic_numbers) - np.testing.assert_allclose(atoms.get_cell(), random_state.cell.numpy()[0]) + np.testing.assert_allclose(atoms.get_cell(), random_state.row_vector_cell.numpy()[0]) np.testing.assert_allclose(atoms.get_positions(), random_state.positions.numpy()) np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() @@ -524,7 +524,9 @@ def test_write_ase_trajectory( for _, atoms in enumerate(ase_traj): # Check basic properties match assert len(atoms) == len(random_state.atomic_numbers) - np.testing.assert_allclose(atoms.get_cell(), random_state.cell.numpy()[0]) + np.testing.assert_allclose( + atoms.get_cell(), random_state.row_vector_cell.numpy()[0] + ) np.testing.assert_allclose(atoms.get_positions(), random_state.positions.numpy()) np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 60c3a2f2e..9b51e9799 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -1091,7 +1091,7 @@ def get_atoms(self, frame: int = -1) -> "Atoms": return Atoms( numbers=np.ascontiguousarray(arrays["atomic_numbers"]), positions=np.ascontiguousarray(arrays["positions"]), - cell=np.ascontiguousarray(arrays["cell"])[0], + cell=np.ascontiguousarray(arrays["cell"])[0].T, pbc=np.ascontiguousarray(arrays["pbc"]), ) From b9d4ded7dc38447f0d7945b0b90c3040bcde7a87 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 2 May 2026 09:42:33 -0400 Subject: [PATCH 2/3] dedup code even though there may be some speed penalty --- tests/test_trajectory.py | 5 ++++- torch_sim/trajectory.py | 33 ++++----------------------------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 474d62d9d..144b2c63b 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -828,7 +828,10 @@ def test_get_atoms_importerror(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) ) traj.write_state(state, steps=0) - with pytest.raises(ImportError, match="ASE is required to convert to ASE Atoms"): + with pytest.raises( + ImportError, + match="ASE is required for state_to_atoms conversion", + ): traj.get_atoms(0) traj.close() diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 9b51e9799..d441a3f79 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -1047,22 +1047,9 @@ def get_structure(self, frame: int = -1) -> Any: Raises: ImportError: If pymatgen is not installed """ - from pymatgen.core import Structure + from torch_sim.io import state_to_structures - arrays = self._get_state_arrays(frame) - - # Create pymatgen Structure - # TODO: check if this is correct - lattice = arrays["cell"][0].T # pymatgen expects lattice matrix as rows - species = [str(num) for num in arrays["atomic_numbers"]] - - return Structure( - lattice=np.ascontiguousarray(lattice), - species=species, - coords=np.ascontiguousarray(arrays["positions"]), - coords_are_cartesian=True, - validate_proximity=False, - ) + return state_to_structures(self.get_state(frame, device=torch.device("cpu")))[0] def get_atoms(self, frame: int = -1) -> "Atoms": """Get an ASE Atoms object for a given frame. @@ -1079,21 +1066,9 @@ def get_atoms(self, frame: int = -1) -> "Atoms": Raises: ImportError: If ASE is not installed """ - try: - from ase import Atoms - except ImportError: - raise ImportError( - "ASE is required to convert to ASE Atoms. Run `pip install ase`" - ) from None + from torch_sim.io import state_to_atoms - arrays = self._get_state_arrays(frame) - - return Atoms( - numbers=np.ascontiguousarray(arrays["atomic_numbers"]), - positions=np.ascontiguousarray(arrays["positions"]), - cell=np.ascontiguousarray(arrays["cell"])[0].T, - pbc=np.ascontiguousarray(arrays["pbc"]), - ) + return state_to_atoms(self.get_state(frame, device=torch.device("cpu")))[0] def get_state( self, From f6fe3cfce79c7d7d9c37efbeb469a78de994d213 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 2 May 2026 09:57:49 -0400 Subject: [PATCH 3/3] fea: get_atoms kwargs --- torch_sim/trajectory.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index d441a3f79..338a827e3 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -1051,7 +1051,7 @@ def get_structure(self, frame: int = -1) -> Any: return state_to_structures(self.get_state(frame, device=torch.device("cpu")))[0] - def get_atoms(self, frame: int = -1) -> "Atoms": + def get_atoms(self, frame: int = -1, **kwargs: Any) -> "Atoms": """Get an ASE Atoms object for a given frame. Converts the state at the specified frame to an ASE Atoms object @@ -1059,6 +1059,7 @@ def get_atoms(self, frame: int = -1) -> "Atoms": Args: frame (int): Frame index to retrieve (-1 for last frame) + **kwargs: Additional keyword arguments passed to `state_to_atoms`. Returns: Atoms: ASE Atoms object for the specified frame @@ -1068,7 +1069,9 @@ def get_atoms(self, frame: int = -1) -> "Atoms": """ from torch_sim.io import state_to_atoms - return state_to_atoms(self.get_state(frame, device=torch.device("cpu")))[0] + return state_to_atoms( + self.get_state(frame, device=torch.device("cpu")), **kwargs + )[0] def get_state( self,