From 5301b537e4b89ac51ac483154feb8ac4124663e9 Mon Sep 17 00:00:00 2001 From: robinzyb <38876805+robinzyb@users.noreply.github.com> Date: Thu, 1 Feb 2024 21:00:06 +0100 Subject: [PATCH 1/2] fix invalid condition in compLabeledSys in unittest --- tests/comp_sys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/comp_sys.py b/tests/comp_sys.py index cfa92c497..f4663780b 100644 --- a/tests/comp_sys.py +++ b/tests/comp_sys.py @@ -92,8 +92,8 @@ def test_virial(self): # if len(self.system_1['virials']) == 0: # self.assertEqual(len(self.system_1['virials']), 0) # return - if "virials" not in self.system_1: - self.assertFalse("virials" in self.system_2) + if not self.system_1.has_virial(): + self.assertFalse(self.system_2.has_virial()) return np.testing.assert_almost_equal( self.system_1["virials"], From dd606fba096237f29b21818269ceb259a1815e2f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 1 Feb 2024 17:03:17 -0500 Subject: [PATCH 2/2] fix virial in HybridDriver --- dpdata/driver.py | 2 ++ tests/test_predict.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dpdata/driver.py b/dpdata/driver.py index de72d3fe7..71a1e0b35 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -163,6 +163,8 @@ def label(self, data: dict) -> dict: else: labeled_data["energies"] += lb_data["energies"] labeled_data["forces"] += lb_data["forces"] + if "virials" in labeled_data and "virials" in lb_data: + labeled_data["virials"] += lb_data["virials"] return labeled_data diff --git a/tests/test_predict.py b/tests/test_predict.py index ad85464a2..f08125ab2 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -72,7 +72,7 @@ def setUp(self): self.system_2 = dpdata.LabeledSystem( "poscars/deepmd.h2o.md", fmt="deepmd/raw", type_map=["O", "H"] ) - for pp in ("energies", "forces"): + for pp in ("energies", "forces", "virials"): self.system_2.data[pp][:] = 3.0 self.places = 6