diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index f5da1674..144b2c63 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -31,7 +31,7 @@ def random_state() -> MDState: energy=torch.tensor(1.0), forces=torch.randn(10, 3), masses=torch.ones(10), - cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), + cell=torch.tensor([[[4.0, -4.0, 0.0], [4.0, 4.0, 0.0], [0.0, 0.0, 8.0]]]), atomic_numbers=torch.ones(10, dtype=torch.int32), system_idx=torch.zeros(10, dtype=torch.int32), pbc=[True, True, False], @@ -453,7 +453,7 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non # Test basic properties assert len(atoms) == len(random_state.atomic_numbers) - np.testing.assert_allclose(atoms.get_cell(), random_state.cell.numpy()[0]) + np.testing.assert_allclose(atoms.get_cell(), random_state.row_vector_cell.numpy()[0]) np.testing.assert_allclose(atoms.get_positions(), random_state.positions.numpy()) np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() @@ -524,7 +524,9 @@ def test_write_ase_trajectory( for _, atoms in enumerate(ase_traj): # Check basic properties match assert len(atoms) == len(random_state.atomic_numbers) - np.testing.assert_allclose(atoms.get_cell(), random_state.cell.numpy()[0]) + np.testing.assert_allclose( + atoms.get_cell(), random_state.row_vector_cell.numpy()[0] + ) np.testing.assert_allclose(atoms.get_positions(), random_state.positions.numpy()) np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() @@ -826,7 +828,10 @@ def test_get_atoms_importerror(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) ) traj.write_state(state, steps=0) - with pytest.raises(ImportError, match="ASE is required to convert to ASE Atoms"): + with pytest.raises( + ImportError, + match="ASE is required for state_to_atoms conversion", + ): traj.get_atoms(0) traj.close() diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 60c3a2f2..338a827e 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -1047,24 +1047,11 @@ def get_structure(self, frame: int = -1) -> Any: Raises: ImportError: If pymatgen is not installed """ - from pymatgen.core import Structure + from torch_sim.io import state_to_structures - arrays = self._get_state_arrays(frame) - - # Create pymatgen Structure - # TODO: check if this is correct - lattice = arrays["cell"][0].T # pymatgen expects lattice matrix as rows - species = [str(num) for num in arrays["atomic_numbers"]] - - return Structure( - lattice=np.ascontiguousarray(lattice), - species=species, - coords=np.ascontiguousarray(arrays["positions"]), - coords_are_cartesian=True, - validate_proximity=False, - ) + return state_to_structures(self.get_state(frame, device=torch.device("cpu")))[0] - def get_atoms(self, frame: int = -1) -> "Atoms": + def get_atoms(self, frame: int = -1, **kwargs: Any) -> "Atoms": """Get an ASE Atoms object for a given frame. Converts the state at the specified frame to an ASE Atoms object @@ -1072,6 +1059,7 @@ def get_atoms(self, frame: int = -1) -> "Atoms": Args: frame (int): Frame index to retrieve (-1 for last frame) + **kwargs: Additional keyword arguments passed to `state_to_atoms`. Returns: Atoms: ASE Atoms object for the specified frame @@ -1079,21 +1067,11 @@ def get_atoms(self, frame: int = -1) -> "Atoms": Raises: ImportError: If ASE is not installed """ - try: - from ase import Atoms - except ImportError: - raise ImportError( - "ASE is required to convert to ASE Atoms. Run `pip install ase`" - ) from None + from torch_sim.io import state_to_atoms - arrays = self._get_state_arrays(frame) - - return Atoms( - numbers=np.ascontiguousarray(arrays["atomic_numbers"]), - positions=np.ascontiguousarray(arrays["positions"]), - cell=np.ascontiguousarray(arrays["cell"])[0], - pbc=np.ascontiguousarray(arrays["pbc"]), - ) + return state_to_atoms( + self.get_state(frame, device=torch.device("cpu")), **kwargs + )[0] def get_state( self,