From 2e1e993aa1dfbc8456360b5c69a11c33b9861f59 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 5 Jun 2024 22:56:16 +0800 Subject: [PATCH 1/5] fix: the replicate will fail if the atom types of system is not sorted --- dpdata/system.py | 6 +++--- tests/test_replicate.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index 2614bc23b..2de983685 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -776,9 +776,9 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]): tmp.data["atom_numbs"] = list( np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy) ) - tmp.data["atom_types"] = np.sort( - np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable" - ) + tmp.data["atom_types"] = np.tile(np.copy(data["atom_types"]), (np.prod(ncopy),) + (1,)) + tmp.data["atom_types"] = np.transpose(tmp.data["atom_types"]).reshape([-1]) + tmp.data["cells"] = np.copy(data["cells"]) for ii in range(3): tmp.data["cells"][:, ii, :] *= ncopy[ii] diff --git a/tests/test_replicate.py b/tests/test_replicate.py index 3add2dc02..bdd8c5bb5 100644 --- a/tests/test_replicate.py +++ b/tests/test_replicate.py @@ -2,6 +2,7 @@ import unittest +import numpy as np from comp_sys import CompSys, IsPBC from context import dpdata @@ -35,6 +36,24 @@ def setUp(self): self.system_2 = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar") self.places = 6 +class TestReplicateTriclinicBox(unittest.TestCase, CompSys, IsPBC): + def setUp(self): + self.system_1 = dpdata.System() + self.system_1.data["atom_names"] = ["foo", "bar"] + self.system_1.data["atom_types"] = np.array([1, 0], dtype=int) + self.system_1.data["atom_numbs"] = [1, 1] + self.system_1.data["cells"] = np.array([10, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float).reshape(1,3,3) + self.system_1.data["coords"] = np.array([0, 0, 0, 0, 0, 1], dtype=float).reshape(1,2,3) + self.system_1 = self.system_1.replicate([2,1,1]) + + self.system_2 = dpdata.System() + self.system_2.data["atom_names"] = ["foo", "bar"] + self.system_2.data["atom_types"] = np.array([1, 1, 0, 0], dtype=int) + self.system_2.data["atom_numbs"] = [2, 2] + self.system_2.data["cells"] = np.array([20, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float).reshape(1,3,3) + self.system_2.data["coords"] = np.array([0, 0, 0, 10, 0, 0, 0, 0, 1, 10, 0, 1], dtype=float).reshape(1,4,3) + self.places = 6 + if __name__ == "__main__": unittest.main() From 54c7c985cfa96eac7ca6aa93f48ff57783e60441 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Jun 2024 14:58:12 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/system.py | 4 +++- tests/test_replicate.py | 19 ++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index 2de983685..ae5fd97ce 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -776,7 +776,9 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]): tmp.data["atom_numbs"] = list( np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy) ) - tmp.data["atom_types"] = np.tile(np.copy(data["atom_types"]), (np.prod(ncopy),) + (1,)) + tmp.data["atom_types"] = np.tile( + np.copy(data["atom_types"]), (np.prod(ncopy),) + (1,) + ) tmp.data["atom_types"] = np.transpose(tmp.data["atom_types"]).reshape([-1]) tmp.data["cells"] = np.copy(data["cells"]) diff --git a/tests/test_replicate.py b/tests/test_replicate.py index bdd8c5bb5..ded164c17 100644 --- a/tests/test_replicate.py +++ b/tests/test_replicate.py @@ -36,22 +36,31 @@ def setUp(self): self.system_2 = dpdata.System("poscars/POSCAR.SiC", fmt="vasp/poscar") self.places = 6 + class TestReplicateTriclinicBox(unittest.TestCase, CompSys, IsPBC): def setUp(self): self.system_1 = dpdata.System() self.system_1.data["atom_names"] = ["foo", "bar"] self.system_1.data["atom_types"] = np.array([1, 0], dtype=int) self.system_1.data["atom_numbs"] = [1, 1] - self.system_1.data["cells"] = np.array([10, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float).reshape(1,3,3) - self.system_1.data["coords"] = np.array([0, 0, 0, 0, 0, 1], dtype=float).reshape(1,2,3) - self.system_1 = self.system_1.replicate([2,1,1]) + self.system_1.data["cells"] = np.array( + [10, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float + ).reshape(1, 3, 3) + self.system_1.data["coords"] = np.array( + [0, 0, 0, 0, 0, 1], dtype=float + ).reshape(1, 2, 3) + self.system_1 = self.system_1.replicate([2, 1, 1]) self.system_2 = dpdata.System() self.system_2.data["atom_names"] = ["foo", "bar"] self.system_2.data["atom_types"] = np.array([1, 1, 0, 0], dtype=int) self.system_2.data["atom_numbs"] = [2, 2] - self.system_2.data["cells"] = np.array([20, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float).reshape(1,3,3) - self.system_2.data["coords"] = np.array([0, 0, 0, 10, 0, 0, 0, 0, 1, 10, 0, 1], dtype=float).reshape(1,4,3) + self.system_2.data["cells"] = np.array( + [20, 0, 0, 0, 10, 0, 0, 0, 10], dtype=float + ).reshape(1, 3, 3) + self.system_2.data["coords"] = np.array( + [0, 0, 0, 10, 0, 0, 0, 0, 1, 10, 0, 1], dtype=float + ).reshape(1, 4, 3) self.places = 6 From c7bf8fccfcb7adf548e998629869926a41e252d1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 5 Jun 2024 23:04:02 +0800 Subject: [PATCH 3/5] fix pyright --- dpdata/system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpdata/system.py b/dpdata/system.py index 2de983685..6d4a5904a 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -776,7 +776,7 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]): tmp.data["atom_numbs"] = list( np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy) ) - tmp.data["atom_types"] = np.tile(np.copy(data["atom_types"]), (np.prod(ncopy),) + (1,)) + tmp.data["atom_types"] = np.tile(np.copy(data["atom_types"]), (int(np.prod(ncopy)),) + (1,)) tmp.data["atom_types"] = np.transpose(tmp.data["atom_types"]).reshape([-1]) tmp.data["cells"] = np.copy(data["cells"]) From bfe862fead6a2b6cce1416c3f10dcee783cdce75 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 5 Jun 2024 23:30:45 +0800 Subject: [PATCH 4/5] fix ase bug... --- tests/test_predict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_predict.py b/tests/test_predict.py index 6ab00be36..2f8ddb919 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -106,7 +106,9 @@ def setUp(self): ) zero_driver = ZeroDriver() self.system_1 = ori_sys.predict(driver=zero_driver) - self.system_2 = ori_sys.minimize(driver=zero_driver, minimizer="ase") + self.system_2 = ori_sys.minimize( + driver=zero_driver, minimizer="ase", max_steps=100 + ) self.places = 6 self.e_places = 6 self.f_places = 6 From 42ceb08ebaf700c592b97a2f76379f54404c068c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 5 Jun 2024 23:38:42 +0800 Subject: [PATCH 5/5] fix ase bug... --- tests/test_predict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_predict.py b/tests/test_predict.py index 2f8ddb919..9721323ec 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -125,7 +125,9 @@ def setUp(self): zero_driver = ZeroDriver() self.system_1 = list(multi_sys.predict(driver=zero_driver).systems.values())[0] self.system_2 = list( - multi_sys.minimize(driver=zero_driver, minimizer="ase").systems.values() + multi_sys.minimize( + driver=zero_driver, minimizer="ase", max_steps=100 + ).systems.values() )[0] self.places = 6 self.e_places = 6