diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 1d8184838..2ca229268 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import os +from typing import TYPE_CHECKING, Generator import numpy as np @@ -94,7 +95,7 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: "forces": np.array([forces]), } try: - stress = atoms.get_stress(False) + stress = atoms.get_stress(voigt=False) except PropertyNotImplementedError: pass else: @@ -110,7 +111,7 @@ def from_multi_systems( step: int | None = None, ase_fmt: str | None = None, **kwargs, - ) -> ase.Atoms: + ) -> Generator[ase.Atoms, None, None]: """Convert a ASE supported file to ASE Atoms. It will finally be converted to MultiSystems. @@ -140,7 +141,7 @@ def from_multi_systems( frames = ase.io.read(file_name, format=ase_fmt, index=slice(begin, end, step)) yield from frames - def to_system(self, data, **kwargs): + def to_system(self, data, **kwargs) -> list[ase.Atoms]: """Convert System to ASE Atom obj.""" from ase import Atoms @@ -158,7 +159,7 @@ def to_system(self, data, **kwargs): return structures - def to_labeled_system(self, data, *args, **kwargs): + def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]: """Convert System to ASE Atoms object.""" from ase import Atoms from ase.calculators.singlepoint import SinglePointCalculator @@ -300,6 +301,46 @@ def from_labeled_system( return dict_frames + def to_system(self, data, file_name: str = "confs.traj", **kwargs) -> None: + """Convert System to ASE Atoms object. + + Parameters + ---------- + file_name : str + path to file + """ + from ase.io import Trajectory + + if os.path.isfile(file_name): + os.remove(file_name) + + list_atoms = ASEStructureFormat().to_system(data, **kwargs) + traj = Trajectory(file_name, "a") + _ = [traj.write(atom) for atom in list_atoms] + traj.close() + return + + def to_labeled_system( + self, data, file_name: str = "labeled_confs.traj", *args, **kwargs + ) -> None: + """Convert System to ASE Atoms object. + + Parameters + ---------- + file_name : str + path to file + """ + from ase.io import Trajectory + + if os.path.isfile(file_name): + os.remove(file_name) + + list_atoms = ASEStructureFormat().to_labeled_system(data, *args, **kwargs) + traj = Trajectory(file_name, "a") + _ = [traj.write(atom) for atom in list_atoms] + traj.close() + return + @Driver.register("ase") class ASEDriver(Driver): diff --git a/tests/test_ase_traj.py b/tests/test_ase_traj.py index 8e4a6e12f..98bc8bc04 100644 --- a/tests/test_ase_traj.py +++ b/tests/test_ase_traj.py @@ -67,5 +67,29 @@ def setUp(self): self.v_places = 4 +@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix") +class TestASEtraj4(unittest.TestCase, CompSys, IsPBC): + def setUp(self): + self.system_1 = dpdata.System("ase_traj/MoS2", fmt="deepmd") + self.system_1.to(file_name="ase_traj/tmp.traj", fmt="ase/traj") + self.system_2 = dpdata.System("ase_traj/tmp.traj", fmt="ase/traj") + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 4 + + +@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix") +class TestASEtraj4Labeled(unittest.TestCase, CompLabeledSys, IsPBC): + def setUp(self): + self.system_1 = dpdata.LabeledSystem("ase_traj/MoS2", fmt="deepmd") + self.system_1.to(file_name="ase_traj/tmp1.traj", fmt="ase/traj") + self.system_2 = dpdata.LabeledSystem("ase_traj/tmp1.traj", fmt="ase/traj") + self.places = 6 + self.e_places = 6 + self.f_places = 6 + self.v_places = 4 + + if __name__ == "__main__": unittest.main()