diff --git a/dpdata/data_type.py b/dpdata/data_type.py index c0bd944fb..b5141a4e0 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -19,6 +19,13 @@ class Axis(Enum): NBONDS = "nbonds" +class AnyInt(int): + """AnyInt equals to any other integer.""" + + def __eq__(self, other): + return True + + class DataError(Exception): """Data is not correct.""" @@ -64,6 +71,8 @@ def real_shape(self, system: "System") -> Tuple[int]: elif ii is Axis.NBONDS: # BondOrderSystem shape.append(system.get_nbonds()) + elif ii == -1: + shape.append(AnyInt(-1)) elif isinstance(ii, int): shape.append(ii) else: diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 9d63e1dd4..7b909b162 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -87,8 +87,10 @@ def to_system_data(folder, type_map=None, labels=True): f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format." ) continue + natoms = data["coords"].shape[1] shape = [ - -1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] + natoms if xx == dpdata.system.Axis.NATOMS else xx + for xx in dtype.shape[1:] ] all_data = [] for ii in sets: diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index fdb2fc649..c7a64ec47 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -86,8 +86,9 @@ def to_system_data(folder, type_map=None, labels=True): f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format." ) continue + natoms = data["coords"].shape[1] shape = [ - -1 if xx == dpdata.system.Axis.NATOMS else xx + natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] ] if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")): diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py index 29911b5f1..b3821cb34 100644 --- a/tests/plugin/dpdata_plugin_test/__init__.py +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -8,4 +8,9 @@ DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True ) +register_data_type( + DataType("bar", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, -1), required=False), + labeled=True, +) + ep = None diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 49c99586f..5a0e2bab9 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -43,3 +43,42 @@ def test_from_deepmd_hdf5(self): self.system.to_deepmd_hdf5("data_foo.h5") x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5") np.testing.assert_allclose(x.data["foo"], self.foo) + + +class TestDeepmdLoadDumpCompAny(unittest.TestCase): + def setUp(self): + self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar") + self.bar = np.ones((len(self.system), self.system.get_natoms(), 2)) + self.system.data["bar"] = self.bar + self.system.check_data() + + def test_to_deepmd_raw(self): + self.system.to_deepmd_raw("data_bar") + bar = np.loadtxt("data_bar/bar.raw") + np.testing.assert_allclose(bar.reshape(self.bar.shape), self.bar) + + def test_from_deepmd_raw(self): + self.system.to_deepmd_raw("data_bar") + x = dpdata.LabeledSystem("data_bar", fmt="deepmd/raw") + np.testing.assert_allclose(x.data["bar"], self.bar) + + def test_to_deepmd_npy(self): + self.system.to_deepmd_npy("data_bar") + bar = np.load("data_bar/set.000/bar.npy") + np.testing.assert_allclose(bar.reshape(self.bar.shape), self.bar) + + def test_from_deepmd_npy(self): + self.system.to_deepmd_npy("data_bar") + x = dpdata.LabeledSystem("data_bar", fmt="deepmd/npy") + np.testing.assert_allclose(x.data["bar"], self.bar) + + def test_to_deepmd_hdf5(self): + self.system.to_deepmd_hdf5("data_bar.h5") + with h5py.File("data_bar.h5") as f: + bar = f["set.000/bar.npy"][:] + np.testing.assert_allclose(bar.reshape(self.bar.shape), self.bar) + + def test_from_deepmd_hdf5(self): + self.system.to_deepmd_hdf5("data_bar.h5") + x = dpdata.LabeledSystem("data_bar.h5", fmt="deepmd/hdf5") + np.testing.assert_allclose(x.data["bar"], self.bar)