Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions tests/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def random_state() -> MDState:
energy=torch.tensor(1.0),
forces=torch.randn(10, 3),
masses=torch.ones(10),
cell=torch.unsqueeze(torch.eye(3) * 10.0, 0),
cell=torch.tensor([[[4.0, -4.0, 0.0], [4.0, 4.0, 0.0], [0.0, 0.0, 8.0]]]),
atomic_numbers=torch.ones(10, dtype=torch.int32),
system_idx=torch.zeros(10, dtype=torch.int32),
pbc=[True, True, False],
Expand Down Expand Up @@ -453,7 +453,7 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non

# Test basic properties
assert len(atoms) == len(random_state.atomic_numbers)
np.testing.assert_allclose(atoms.get_cell(), random_state.cell.numpy()[0])
np.testing.assert_allclose(atoms.get_cell(), random_state.row_vector_cell.numpy()[0])
np.testing.assert_allclose(atoms.get_positions(), random_state.positions.numpy())
np.testing.assert_allclose(
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
Expand Down Expand Up @@ -524,7 +524,9 @@ def test_write_ase_trajectory(
for _, atoms in enumerate(ase_traj):
# Check basic properties match
assert len(atoms) == len(random_state.atomic_numbers)
np.testing.assert_allclose(atoms.get_cell(), random_state.cell.numpy()[0])
np.testing.assert_allclose(
atoms.get_cell(), random_state.row_vector_cell.numpy()[0]
)
np.testing.assert_allclose(atoms.get_positions(), random_state.positions.numpy())
np.testing.assert_allclose(
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
Expand Down Expand Up @@ -826,7 +828,10 @@ def test_get_atoms_importerror(monkeypatch: pytest.MonkeyPatch, tmp_path: Path)
)
traj.write_state(state, steps=0)

with pytest.raises(ImportError, match="ASE is required to convert to ASE Atoms"):
with pytest.raises(
ImportError,
match="ASE is required for state_to_atoms conversion",
):
traj.get_atoms(0)
traj.close()

Expand Down
38 changes: 8 additions & 30 deletions torch_sim/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,53 +1047,31 @@ def get_structure(self, frame: int = -1) -> Any:
Raises:
ImportError: If pymatgen is not installed
"""
from pymatgen.core import Structure
from torch_sim.io import state_to_structures

arrays = self._get_state_arrays(frame)

# Create pymatgen Structure
# TODO: check if this is correct
lattice = arrays["cell"][0].T # pymatgen expects lattice matrix as rows
species = [str(num) for num in arrays["atomic_numbers"]]

return Structure(
lattice=np.ascontiguousarray(lattice),
species=species,
coords=np.ascontiguousarray(arrays["positions"]),
coords_are_cartesian=True,
validate_proximity=False,
)
return state_to_structures(self.get_state(frame, device=torch.device("cpu")))[0]

def get_atoms(self, frame: int = -1) -> "Atoms":
def get_atoms(self, frame: int = -1, **kwargs: Any) -> "Atoms":
"""Get an ASE Atoms object for a given frame.

Converts the state at the specified frame to an ASE Atoms object
for analysis and visualization.

Args:
frame (int): Frame index to retrieve (-1 for last frame)
**kwargs: Additional keyword arguments passed to `state_to_atoms`.

Returns:
Atoms: ASE Atoms object for the specified frame

Raises:
ImportError: If ASE is not installed
"""
try:
from ase import Atoms
except ImportError:
raise ImportError(
"ASE is required to convert to ASE Atoms. Run `pip install ase`"
) from None
from torch_sim.io import state_to_atoms

arrays = self._get_state_arrays(frame)

return Atoms(
numbers=np.ascontiguousarray(arrays["atomic_numbers"]),
positions=np.ascontiguousarray(arrays["positions"]),
cell=np.ascontiguousarray(arrays["cell"])[0],
pbc=np.ascontiguousarray(arrays["pbc"]),
)
return state_to_atoms(
self.get_state(frame, device=torch.device("cpu")), **kwargs
)[0]

def get_state(
self,
Expand Down
Loading