From 189ae4137ed5c326d729f81b58d9effc0c1bd15f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 24 Jul 2023 16:03:41 -0400 Subject: [PATCH 1/2] support `-1` shape in DataType (#62) Sometimes, the shape is unknown and can be any integer. This PR supports this situation. Note: Only one `-1` can be used. --------- Signed-off-by: Jinzhe Zeng (cherry picked from commit 3bbb8253bd2e4af531d8bfc728e7b7a36006a9f3) --- dpdata/data_type.py | 8 ++++ dpdata/deepmd/comp.py | 3 +- dpdata/deepmd/raw.py | 3 +- tests/plugin/dpdata_plugin_test/__init__.py | 4 ++ tests/test_custom_data_type.py | 41 +++++++++++++++++++++ 5 files changed, 57 insertions(+), 2 deletions(-) diff --git a/dpdata/data_type.py b/dpdata/data_type.py index c0bd944fb..90ab7c06e 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -19,6 +19,12 @@ class Axis(Enum): NBONDS = "nbonds" +class AnyInt(int): + """AnyInt equals to any other integer.""" + def __eq__(self, other): + return True + + class DataError(Exception): """Data is not correct.""" @@ -64,6 +70,8 @@ def real_shape(self, system: "System") -> Tuple[int]: elif ii is Axis.NBONDS: # BondOrderSystem shape.append(system.get_nbonds()) + elif ii == -1: + shape.append(AnyInt(-1)) elif isinstance(ii, int): shape.append(ii) else: diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 9d63e1dd4..597e889d3 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -87,8 +87,9 @@ def to_system_data(folder, type_map=None, labels=True): f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format." ) continue + natoms = data["coords"].shape[1] shape = [ - -1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] + natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] ] all_data = [] for ii in sets: diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index fdb2fc649..c7a64ec47 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -86,8 +86,9 @@ def to_system_data(folder, type_map=None, labels=True): f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format." ) continue + natoms = data["coords"].shape[1] shape = [ - -1 if xx == dpdata.system.Axis.NATOMS else xx + natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] ] if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")): diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py index 29911b5f1..4c15115ad 100644 --- a/tests/plugin/dpdata_plugin_test/__init__.py +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -8,4 +8,8 @@ DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True ) +register_data_type( + DataType("bar", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, -1), required=False), labeled=True +) + ep = None diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 49c99586f..05569d8aa 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -43,3 +43,44 @@ def test_from_deepmd_hdf5(self): self.system.to_deepmd_hdf5("data_foo.h5") x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5") np.testing.assert_allclose(x.data["foo"], self.foo) + + +class TestDeepmdLoadDumpCompAny(unittest.TestCase): + def setUp(self): + self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar") + self.bar = np.ones((len(self.system), self.system.get_natoms(), 2)) + self.system.data["bar"] = self.bar + self.system.check_data() + + def test_to_deepmd_raw(self): + self.system.to_deepmd_raw("data_bar") + bar = np.loadtxt("data_bar/bar.raw") + np.testing.assert_allclose(bar.reshape(self.bar.shape), self.bar) + + def test_from_deepmd_raw(self): + self.system.to_deepmd_raw("data_bar") + x = dpdata.LabeledSystem("data_bar", fmt="deepmd/raw") + np.testing.assert_allclose(x.data["bar"], self.bar) + + def test_to_deepmd_npy(self): + self.system.to_deepmd_npy("data_bar") + bar = np.load("data_bar/set.000/bar.npy") + np.testing.assert_allclose(bar.reshape(self.bar.shape), self.bar) + + def test_from_deepmd_npy(self): + self.system.to_deepmd_npy("data_bar") + x = dpdata.LabeledSystem("data_bar", fmt="deepmd/npy") + np.testing.assert_allclose(x.data["bar"], self.bar) + + def test_to_deepmd_hdf5(self): + self.system.to_deepmd_hdf5("data_bar.h5") + with h5py.File("data_bar.h5") as f: + bar = f["set.000/bar.npy"][:] + np.testing.assert_allclose(bar.reshape(self.bar.shape), self.bar) + + def test_from_deepmd_hdf5(self): + self.system.to_deepmd_hdf5("data_bar.h5") + x = dpdata.LabeledSystem("data_bar.h5", fmt="deepmd/hdf5") + np.testing.assert_allclose(x.data["bar"], self.bar) + + From d92ed20c423544a8f5bb7a8ac838d6a64de037ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:47:08 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/data_type.py | 1 + dpdata/deepmd/comp.py | 3 ++- tests/plugin/dpdata_plugin_test/__init__.py | 3 ++- tests/test_custom_data_type.py | 2 -- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 90ab7c06e..b5141a4e0 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -21,6 +21,7 @@ class Axis(Enum): class AnyInt(int): """AnyInt equals to any other integer.""" + def __eq__(self, other): return True diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 597e889d3..7b909b162 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -89,7 +89,8 @@ def to_system_data(folder, type_map=None, labels=True): continue natoms = data["coords"].shape[1] shape = [ - natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] + natoms if xx == dpdata.system.Axis.NATOMS else xx + for xx in dtype.shape[1:] ] all_data = [] for ii in sets: diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py index 4c15115ad..b3821cb34 100644 --- a/tests/plugin/dpdata_plugin_test/__init__.py +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -9,7 +9,8 @@ ) register_data_type( - DataType("bar", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, -1), required=False), labeled=True + DataType("bar", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, -1), required=False), + labeled=True, ) ep = None diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 05569d8aa..5a0e2bab9 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -82,5 +82,3 @@ def test_from_deepmd_hdf5(self): self.system.to_deepmd_hdf5("data_bar.h5") x = dpdata.LabeledSystem("data_bar.h5", fmt="deepmd/hdf5") np.testing.assert_allclose(x.data["bar"], self.bar) - -