diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py
index 66bf4ee66..9d63e1dd4 100644
--- a/dpdata/deepmd/comp.py
+++ b/dpdata/deepmd/comp.py
@@ -1,9 +1,12 @@
 import glob
 import os
 import shutil
+import warnings
 
 import numpy as np
 
+import dpdata
+
 from .raw import load_type
 
 
@@ -60,6 +63,40 @@ def to_system_data(folder, type_map=None, labels=True):
         data["forces"] = np.concatenate(all_forces, axis=0)
     if len(all_virs) > 0:
         data["virials"] = np.concatenate(all_virs, axis=0)
+    # allow custom dtypes
+    if labels:
+        for dtype in dpdata.system.LabeledSystem.DTYPES:
+            if dtype.name in (
+                "atom_numbs",
+                "atom_names",
+                "atom_types",
+                "orig",
+                "cells",
+                "coords",
+                "real_atom_types",
+                "real_atom_names",
+                "nopbc",
+                "energies",
+                "forces",
+                "virials",
+            ):
+                # skip: these keys have their own dedicated rules
+                continue
+            if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
+                warnings.warn(
+                    f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not be converted from deepmd/npy format."
+                )
+                continue
+            shape = [
+                -1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]
+            ]
+            all_data = []
+            for ii in sets:
+                tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy"))
+                if tmp is not None:
+                    all_data.append(np.reshape(tmp, [tmp.shape[0], *shape]))
+            if len(all_data) > 0:
+                data[dtype.name] = np.concatenate(all_data, axis=0)
     return data
 
 
@@ -131,3 +168,34 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
     if data.get("nopbc", False):
         with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
             pass
+    # allow custom dtypes
+    for dtype in dpdata.system.LabeledSystem.DTYPES:
+        if dtype.name in (
+            "atom_numbs",
+            "atom_names",
+            "atom_types",
+            "orig",
+            "cells",
+            "coords",
+            "real_atom_types",
+            "real_atom_names",
+            "nopbc",
+            "energies",
+            "forces",
+            "virials",
+        ):
+            # skip: these keys have their own dedicated rules
+            continue
+        if dtype.name not in data:
+            continue
+        if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
+            warnings.warn(
+                f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not be converted to deepmd/npy format."
+            )
+            continue
+        ddata = np.reshape(data[dtype.name], [nframes, -1]).astype(comp_prec)
+        for ii in range(nsets):
+            set_stt = ii * set_size
+            set_end = (ii + 1) * set_size
+            set_folder = os.path.join(folder, "set.%03d" % ii)
+            np.save(os.path.join(set_folder, dtype.name), ddata[set_stt:set_end])
diff --git a/dpdata/deepmd/hdf5.py b/dpdata/deepmd/hdf5.py
index 1afb64894..b4ae1a3c6 100644
--- a/dpdata/deepmd/hdf5.py
+++ b/dpdata/deepmd/hdf5.py
@@ -1,10 +1,13 @@
 """Utils for deepmd/hdf5 format."""
+import warnings
 from typing import Optional, Union
 
 import h5py
 import numpy as np
 from wcmatch.glob import globfilter
 
+import dpdata
+
 __all__ = ["to_system_data", "dump"]
 
 
@@ -92,6 +95,42 @@ def to_system_data(
             "required": False,
         },
     }
+    # allow custom dtypes
+    for dtype in dpdata.system.LabeledSystem.DTYPES:
+        if dtype.name in (
+            "atom_numbs",
+            "atom_names",
+            "atom_types",
+            "orig",
+            "cells",
+            "coords",
+            "real_atom_types",
+            "real_atom_names",
+            "nopbc",
+            "energies",
+            "forces",
+            "virials",
+        ):
+            # skip: these keys have their own dedicated rules
+            continue
+        if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
+            warnings.warn(
+                f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not be converted from deepmd/hdf5 format."
+            )
+            continue
+
+        # Translate the symbolic natoms axis to -1, as the deepmd/npy and
+        # deepmd/raw loaders do, instead of storing the Axis member itself.
+        # NOTE(review): presumably this shape feeds a reshape below - confirm.
+        shape = [
+            -1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]
+        ]
+        data_types[dtype.name] = {
+            "fn": dtype.name,
+            "labeled": True,
+            "shape": shape,
+            "required": False,
+        }
 
     for dt, prop in data_types.items():
         all_data = []
@@ -167,6 +206,37 @@ def dump(
         "forces": {"fn": "force", "shape": (nframes, -1), "dump": True},
         "virials": {"fn": "virial", "shape": (nframes, 9), "dump": True},
     }
+
+    # allow custom dtypes
+    for dtype in dpdata.system.LabeledSystem.DTYPES:
+        if dtype.name in (
+            "atom_numbs",
+            "atom_names",
+            "atom_types",
+            "orig",
+            "cells",
+            "coords",
+            "real_atom_types",
+            "real_atom_names",
+            "nopbc",
+            "energies",
+            "forces",
+            "virials",
+        ):
+            # skip: these keys have their own dedicated rules
+            continue
+        if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
+            warnings.warn(
+                f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not be converted to deepmd/hdf5 format."
+            )
+            continue
+
+        data_types[dtype.name] = {
+            "fn": dtype.name,
+            "shape": (nframes, -1),
+            "dump": True,
+        }
+
     for dt, prop in data_types.items():
         if dt in data:
             if prop["dump"]:
diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py
index 2f2021d44..fdb2fc649 100644
--- a/dpdata/deepmd/raw.py
+++ b/dpdata/deepmd/raw.py
@@ -1,7 +1,10 @@
 import os
+import warnings
 
 import numpy as np
 
+import dpdata
+
 
 def load_type(folder, type_map=None):
     data = {}
@@ -57,6 +60,41 @@ def to_system_data(folder, type_map=None, labels=True):
                 data["virials"] = np.reshape(data["virials"], [nframes, 3, 3])
         if os.path.isfile(os.path.join(folder, "nopbc")):
             data["nopbc"] = True
+        # allow custom dtypes
+        if labels:
+            for dtype in dpdata.system.LabeledSystem.DTYPES:
+                if dtype.name in (
+                    "atom_numbs",
+                    "atom_names",
+                    "atom_types",
+                    "orig",
+                    "cells",
+                    "coords",
+                    "real_atom_types",
+                    "real_atom_names",
+                    "nopbc",
+                    "energies",
+                    "forces",
+                    "virials",
+                ):
+                    # skip: these keys have their own dedicated rules
+                    continue
+                if not (
+                    len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES
+                ):
+                    warnings.warn(
+                        f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not be converted from deepmd/raw format."
+                    )
+                    continue
+                shape = [
+                    -1 if xx == dpdata.system.Axis.NATOMS else xx
+                    for xx in dtype.shape[1:]
+                ]
+                if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")):
+                    data[dtype.name] = np.reshape(
+                        np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")),
+                        [nframes, *shape],
+                    )
         return data
     else:
         raise RuntimeError("not dir " + folder)
@@ -102,3 +140,30 @@ def dump(folder, data):
     if data.get("nopbc", False):
         with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc:
             pass
+    # allow custom dtypes
+    for dtype in dpdata.system.LabeledSystem.DTYPES:
+        if dtype.name in (
+            "atom_numbs",
+            "atom_names",
+            "atom_types",
+            "orig",
+            "cells",
+            "coords",
+            "real_atom_types",
+            "real_atom_names",
+            "nopbc",
+            "energies",
+            "forces",
+            "virials",
+        ):
+            # skip: these keys have their own dedicated rules
+            continue
+        if dtype.name not in data:
+            continue
+        if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES):
+            warnings.warn(
+                f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not be converted to deepmd/raw format."
+            )
+            continue
+        ddata = np.reshape(data[dtype.name], [nframes, -1])
+        np.savetxt(os.path.join(folder, f"{dtype.name}.raw"), ddata)
diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py
new file mode 100644
index 000000000..58b9f5e49
--- /dev/null
+++ b/tests/test_custom_data_type.py
@@ -0,0 +1,61 @@
+import os
+import shutil
+import unittest
+
+import h5py
+import numpy as np
+
+import dpdata
+from dpdata.system import Axis, DataType
+
+
+class TestDeepmdLoadDumpComp(unittest.TestCase):
+    """Round-trip a custom ``foo`` DataType through deepmd raw/npy/hdf5."""
+
+    def setUp(self):
+        self.backup = dpdata.system.LabeledSystem.DTYPES
+        dpdata.system.LabeledSystem.DTYPES = dpdata.system.LabeledSystem.DTYPES + (
+            DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False),
+        )
+        self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
+        self.foo = np.ones((len(self.system), 2, 4))
+        self.system.data["foo"] = self.foo
+        self.system.check_data()
+
+    def tearDown(self) -> None:
+        dpdata.system.LabeledSystem.DTYPES = self.backup
+        # remove the artifacts written by the to_deepmd_* calls in the tests
+        shutil.rmtree("data_foo", ignore_errors=True)
+        if os.path.isfile("data_foo.h5"):
+            os.remove("data_foo.h5")
+
+    def test_to_deepmd_raw(self):
+        self.system.to_deepmd_raw("data_foo")
+        foo = np.loadtxt("data_foo/foo.raw")
+        np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo)
+
+    def test_from_deepmd_raw(self):
+        self.system.to_deepmd_raw("data_foo")
+        x = dpdata.LabeledSystem("data_foo", fmt="deepmd/raw")
+        np.testing.assert_allclose(x.data["foo"], self.foo)
+
+    def test_to_deepmd_npy(self):
+        self.system.to_deepmd_npy("data_foo")
+        foo = np.load("data_foo/set.000/foo.npy")
+        np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo)
+
+    def test_from_deepmd_npy(self):
+        self.system.to_deepmd_npy("data_foo")
+        x = dpdata.LabeledSystem("data_foo", fmt="deepmd/npy")
+        np.testing.assert_allclose(x.data["foo"], self.foo)
+
+    def test_to_deepmd_hdf5(self):
+        self.system.to_deepmd_hdf5("data_foo.h5")
+        with h5py.File("data_foo.h5") as f:
+            foo = f["set.000/foo.npy"][:]
+        np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo)
+
+    def test_from_deepmd_hdf5(self):
+        self.system.to_deepmd_hdf5("data_foo.h5")
+        x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5")
+        np.testing.assert_allclose(x.data["foo"], self.foo)