From 33f4c1598a500b0f6433360283f72f8e32965418 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 17 Mar 2025 15:11:33 +0800 Subject: [PATCH 01/17] feat: add new batch size rules for large systems --- deepmd/pt/utils/dataloader.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 1c7a1884d3..44ab20046b 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -120,16 +120,37 @@ def construct_dataset(system): if isinstance(batch_size, str): if batch_size == "auto": rule = 32 + ceiling = True elif batch_size.startswith("auto:"): rule = int(batch_size.split(":")[1]) + ceiling = True + elif batch_size.startswith("max:"): + rule = int(batch_size.split(":")[1]) + ceiling = False + elif batch_size.startswith("cap:"): + # remove system with more than `cap` atoms + rule = int(batch_size.split(":")[1]) + len_before = len(self.systems) + self.systems = [ + system for system in self.systems if system._natoms <= rule + ] + len_after = len(self.systems) + if len_before != len_after: + logging.warning( + f"Remove {len_before - len_after} systems with more than {rule} atoms" + ) + ceiling = False else: - rule = None - log.error("Unsupported batch size type") + raise ValueError(f"Unsupported batch size rule: {batch_size}") for ii in self.systems: ni = ii._natoms bsi = rule // ni - if bsi * ni < rule: - bsi += 1 + if ceiling: + if bsi * ni < rule: + bsi += 1 + else: + if bsi == 0: + bsi = 1 self.batch_sizes.append(bsi) elif isinstance(batch_size, list): self.batch_sizes = batch_size From 8ba1d28f993d603e930fc4d295aabb07b5a34217 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 17 Mar 2025 15:22:39 +0800 Subject: [PATCH 02/17] use logger for class scope --- deepmd/pt/utils/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 44ab20046b..a725b1b0c4 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -136,7 +136,7 @@ def construct_dataset(system): ] len_after = len(self.systems) if len_before != len_after: - logging.warning( + log.warning( f"Remove {len_before - len_after} systems with more than {rule} atoms" ) ceiling = False From ec47d3170b189e2db29f1ac1b999b7ddf539a4ff Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 15:23:34 +0800 Subject: [PATCH 03/17] ensure at least one system left --- deepmd/pt/utils/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index a725b1b0c4..0022ec4d8d 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -139,6 +139,7 @@ def construct_dataset(system): log.warning( f"Remove {len_before - len_after} systems with more than {rule} atoms" ) + assert len(self.systems) > 0, "No system left after removing" ceiling = False else: raise ValueError(f"Unsupported batch size rule: {batch_size}") From 18a0f13ae2132d2aaff591b312496a70794e75ec Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 15:25:22 +0800 Subject: [PATCH 04/17] change keyword name to `filter` --- deepmd/pt/utils/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 0022ec4d8d..fcfe9f5629 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -127,8 +127,8 @@ def construct_dataset(system): elif batch_size.startswith("max:"): rule = int(batch_size.split(":")[1]) ceiling = False - elif batch_size.startswith("cap:"): - # remove system with more than `cap` atoms + elif batch_size.startswith("filter:"): + # remove system with more than `filter` atoms rule = int(batch_size.split(":")[1]) len_before = len(self.systems) self.systems = [ From f169a98bd8766e9b1845fb9a489e4219e03ef54c Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 15:25:27 +0800 Subject: [PATCH 05/17] add ut --- source/tests/pt/test_dploaderset.py | 75 +++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 source/tests/pt/test_dploaderset.py diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py new file mode 100644 index 0000000000..e9ae744713 --- /dev/null +++ b/source/tests/pt/test_dploaderset.py @@ -0,0 +1,75 @@ +import json +import unittest +from pathlib import Path + + +from deepmd.common import ( + expand_sys_str, +) +from deepmd.pt.utils.dataloader import ( + DpLoaderSet, +) + +class TestSampler(unittest.TestCase): + def setUp(self) -> None: + with open( + str(Path(__file__).parent / "water/se_e2_a.json") + ) as fin: + content = fin.read() + config = json.loads(content) + data_file = [str(Path(__file__).parent / "model/water/data/data_0"),] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + model_config = config["model"] + self.rcut = model_config["descriptor"]["rcut"] + self.rcut_smth = model_config["descriptor"]["rcut_smth"] + self.sel = model_config["descriptor"]["sel"] + self.batch_size = config["training"]["training_data"]["batch_size"] + self.systems = config["training"]["validation_data"]["systems"] + self.type_map = model_config["type_map"] + if isinstance(self.systems, str): + self.systems = expand_sys_str(self.systems) + def get_batch_sizes(self,batch_size) -> list[int]: + dataset = DpLoaderSet( + self.systems, + batch_size, + self.type_map, + seed=10, + shuffle=False, + ) + return dataset.batch_sizes[0] + def test_batchsize(self) -> None: + # 192 atoms, 1 system + assert len(self.systems) == 1 + + # test: batch_size:int + self.assertEqual(self.get_batch_sizes(3), 3) + + # test: batch_size:list[int] + self.assertEqual(self.get_batch_sizes([3]), 3) + + # test: batch_size:str = "auto" + self.assertEqual(self.get_batch_sizes("auto:384"), 2) + self.assertEqual(self.get_batch_sizes("auto:383"), 2) + self.assertEqual(self.get_batch_sizes("auto:193"), 2) + self.assertEqual(self.get_batch_sizes("auto:192"), 1) + self.assertEqual(self.get_batch_sizes("auto:191"), 1) + self.assertEqual(self.get_batch_sizes("auto:32"), 1) + self.assertEqual(self.get_batch_sizes("auto"), 1) + + # test: batch_size:str = "max" + self.assertEqual(self.get_batch_sizes("max:384"), 2) + self.assertEqual(self.get_batch_sizes("max:383"), 1) + self.assertEqual(self.get_batch_sizes("max:193"), 1) + self.assertEqual(self.get_batch_sizes("max:192"), 1) + self.assertEqual(self.get_batch_sizes("max:191"), 1) + + # test: batch_size:str = "filter" + self.assertEqual(self.get_batch_sizes("filter:193"), 1) + self.assertEqual(self.get_batch_sizes("filter:192"), 1) + with self.assertLogs() as cm: + self.assertRaises(AssertionError, self.get_batch_sizes, "filter:191") + self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) + +if __name__ == "__main__": + unittest.main() From 7ade4af3efb734b6439175344bc7b83e7770a274 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Mar 2025 07:28:02 +0000 Subject: [PATCH 06/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_dploaderset.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index e9ae744713..75cb31ac03 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -1,7 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later import json import unittest -from pathlib import Path - +from pathlib import ( + Path, +) from deepmd.common import ( expand_sys_str, @@ -10,14 +12,15 @@ DpLoaderSet, ) + class TestSampler(unittest.TestCase): def setUp(self) -> None: - with open( - str(Path(__file__).parent / "water/se_e2_a.json") - ) as fin: + with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin: content = fin.read() config = json.loads(content) - data_file = [str(Path(__file__).parent / "model/water/data/data_0"),] + data_file = [ + str(Path(__file__).parent / "model/water/data/data_0"), + ] config["training"]["training_data"]["systems"] = data_file config["training"]["validation_data"]["systems"] = data_file model_config = config["model"] @@ -29,7 +32,8 @@ def setUp(self) -> None: self.type_map = model_config["type_map"] if isinstance(self.systems, str): self.systems = expand_sys_str(self.systems) - def get_batch_sizes(self,batch_size) -> list[int]: + + def get_batch_sizes(self, batch_size) -> list[int]: dataset = DpLoaderSet( self.systems, batch_size, @@ -38,6 +42,7 @@ def get_batch_sizes(self,batch_size) -> list[int]: shuffle=False, ) return dataset.batch_sizes[0] + def test_batchsize(self) -> None: # 192 atoms, 1 system assert len(self.systems) == 1 @@ -71,5 +76,6 @@ def test_batchsize(self) -> None: self.assertRaises(AssertionError, self.get_batch_sizes, "filter:191") self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) + if __name__ == "__main__": unittest.main() From 10defcbb5a06a0d4bd5a9c023367b56a0fc38511 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 15:40:25 +0800 Subject: [PATCH 07/17] add test for unknown batch size handling --- source/tests/pt/test_dploaderset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index 75cb31ac03..f5b93aedfa 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -76,6 +76,10 @@ def test_batchsize(self) -> None: self.assertRaises(AssertionError, self.get_batch_sizes, "filter:191") self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) + # test: unknown batch_size: str + with self.assertRaises(ValueError, msg="Unsupported batch size rule: unknown"): + self.get_batch_sizes("unknown") + if __name__ == "__main__": unittest.main() From 199d9b797222f7aee0f6ec631c0a5a593b7d6cb8 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 16:58:34 +0800 Subject: [PATCH 08/17] add debug info --- source/tests/pt/test_dploaderset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index f5b93aedfa..6d44c45077 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -74,6 +74,7 @@ def test_batchsize(self) -> None: self.assertEqual(self.get_batch_sizes("filter:192"), 1) with self.assertLogs() as cm: self.assertRaises(AssertionError, self.get_batch_sizes, "filter:191") + print(cm.output) # DEBUG self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) # test: unknown batch_size: str From 63392f688e6d170b48fb9bb68c39af71a54295d6 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 17:07:53 +0800 Subject: [PATCH 09/17] Update source/tests/pt/test_dploaderset.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Chun Cai --- source/tests/pt/test_dploaderset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index 6d44c45077..3c92da4b54 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -78,8 +78,9 @@ def test_batchsize(self) -> None: self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) # test: unknown batch_size: str - with self.assertRaises(ValueError, msg="Unsupported batch size rule: unknown"): + with self.assertRaises(ValueError) as context: self.get_batch_sizes("unknown") + self.assertIn("Unsupported batch size rule: unknown", str(context.exception)) if __name__ == "__main__": From 5bf8fffc6d702442db50abc024f3c2304679474f Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 20 Mar 2025 17:08:05 +0800 Subject: [PATCH 10/17] Update source/tests/pt/test_dploaderset.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Chun Cai --- source/tests/pt/test_dploaderset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index 3c92da4b54..2d913b8a7f 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -33,7 +33,7 @@ def setUp(self) -> None: if isinstance(self.systems, str): self.systems = expand_sys_str(self.systems) - def get_batch_sizes(self, batch_size) -> list[int]: + def get_batch_sizes(self, batch_size) -> int: dataset = DpLoaderSet( self.systems, batch_size, From 5fa35820b612e49e1b84623cf9fe402b4ff00da6 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Fri, 21 Mar 2025 10:10:57 +0800 Subject: [PATCH 11/17] test: capture log by explicitly setting logger name `deepmd` --- source/tests/pt/test_dploaderset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index 2d913b8a7f..f8708f708d 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -72,10 +72,9 @@ def test_batchsize(self) -> None: # test: batch_size:str = "filter" self.assertEqual(self.get_batch_sizes("filter:193"), 1) self.assertEqual(self.get_batch_sizes("filter:192"), 1) - with self.assertLogs() as cm: + with self.assertLogs(logger="deepmd") as cm: self.assertRaises(AssertionError, self.get_batch_sizes, "filter:191") - print(cm.output) # DEBUG - self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) + self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) # test: unknown batch_size: str with self.assertRaises(ValueError) as context: From b43a4ab5d52f95b5bd1328ed83bbceddfd51b5bb Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Fri, 21 Mar 2025 14:01:06 +0800 Subject: [PATCH 12/17] add docs --- doc/train/training-advanced.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md index d21feb2126..a8884d8900 100644 --- a/doc/train/training-advanced.md +++ b/doc/train/training-advanced.md @@ -106,7 +106,9 @@ The sections {ref}`training_data ` and {ref}`validation_ - `list`: the length of which is the same as the {ref}`systems`. The batch size of each system is given by the elements of the list. - `int`: all systems use the same batch size. - `"auto"`: the same as `"auto:32"`, see `"auto:N"` - - `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is no less than `N`. + - `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no less than** `N`. + - `"max:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no more than** `N`. The minimum batch size is 1. + - `"filter:N"`: the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set. Throws an error if no system is left in a dataset. - The key {ref}`numb_batch ` in {ref}`validate_data ` gives the number of batches of model validation. Note that the batches may not be from the same system The section {ref}`mixed_precision ` specifies the mixed precision settings, which will enable the mixed precision training workflow for DeePMD-kit. The keys are explained below: From 456868ffbdf9e5f085db4e6d19f00007df87fe95 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Fri, 21 Mar 2025 14:56:21 +0800 Subject: [PATCH 13/17] update argcheck params --- deepmd/utils/argcheck.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 47071066ae..a1afcaf1d0 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2831,6 +2831,8 @@ def training_data_args(): # ! added by Ziyao: new specification style for data - string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\ - string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\ - string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\ +- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\ +- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\ If MPI is used, the value should be considered as the batch size per task.' doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\ - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\ From 8a5fc8026e7037b708dbdfe32c1f55d7189912dd Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 24 Mar 2025 10:29:42 +0800 Subject: [PATCH 14/17] avoid using assert --- deepmd/pt/utils/dataloader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index fcfe9f5629..851a4713e9 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -139,7 +139,10 @@ def construct_dataset(system): log.warning( f"Remove {len_before - len_after} systems with more than {rule} atoms" ) - assert len(self.systems) > 0, "No system left after removing" + if len(self.systems) == 0: + raise ValueError( + f"No system left after removing systems with more than {rule} atoms" + ) ceiling = False else: raise ValueError(f"Unsupported batch size rule: {batch_size}") From 281959e3818570a828202b79183cea893ad7f7f8 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 24 Mar 2025 10:31:37 +0800 Subject: [PATCH 15/17] port the impl to paddle backend --- deepmd/pd/utils/dataloader.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/deepmd/pd/utils/dataloader.py b/deepmd/pd/utils/dataloader.py index 221a5e776d..0cb8adbc63 100644 --- a/deepmd/pd/utils/dataloader.py +++ b/deepmd/pd/utils/dataloader.py @@ -126,16 +126,41 @@ def construct_dataset(system): if isinstance(batch_size, str): if batch_size == "auto": rule = 32 + ceiling = True elif batch_size.startswith("auto:"): rule = int(batch_size.split(":")[1]) + ceiling = True + elif batch_size.startswith("max:"): + rule = int(batch_size.split(":")[1]) + ceiling = False + elif batch_size.startswith("filter:"): + # remove system with more than `filter` atoms + rule = int(batch_size.split(":")[1]) + len_before = len(self.systems) + self.systems = [ + system for system in self.systems if system._natoms <= rule + ] + len_after = len(self.systems) + if len_before != len_after: + log.warning( + f"Remove {len_before - len_after} systems with more than {rule} atoms" + ) + if len(self.systems) == 0: + raise ValueError( + f"No system left after removing systems with more than {rule} atoms" + ) + ceiling = False else: - rule = None - log.error("Unsupported batch size type") + raise ValueError(f"Unsupported batch size rule: {batch_size}") for ii in self.systems: ni = ii._natoms bsi = rule // ni - if bsi * ni < rule: - bsi += 1 + if ceiling: + if bsi * ni < rule: + bsi += 1 + else: + if bsi == 0: + bsi = 1 self.batch_sizes.append(bsi) elif isinstance(batch_size, list): self.batch_sizes = batch_size From aa07bf1082b7234483f2287488d3f78037c101e0 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 24 Mar 2025 10:45:36 +0800 Subject: [PATCH 16/17] update docs --- doc/train/training-advanced.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md index a8884d8900..174c39d6d9 100644 --- a/doc/train/training-advanced.md +++ b/doc/train/training-advanced.md @@ -107,8 +107,8 @@ The sections {ref}`training_data ` and {ref}`validation_ - `int`: all systems use the same batch size. - `"auto"`: the same as `"auto:32"`, see `"auto:N"` - `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no less than** `N`. - - `"max:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no more than** `N`. The minimum batch size is 1. - - `"filter:N"`: the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set. Throws an error if no system is left in a dataset. + - `"max:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no more than** `N`. The minimum batch size is 1. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }} + - `"filter:N"`: the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set. Throws an error if no system is left in a dataset. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }} - The key {ref}`numb_batch ` in {ref}`validate_data ` gives the number of batches of model validation. Note that the batches may not be from the same system The section {ref}`mixed_precision ` specifies the mixed precision settings, which will enable the mixed precision training workflow for DeePMD-kit. The keys are explained below: From bb4184784b6025a3b9ef3512e9ef21725438e849 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Mon, 24 Mar 2025 15:55:34 +0800 Subject: [PATCH 17/17] fix UT --- source/tests/pt/test_dploaderset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt/test_dploaderset.py b/source/tests/pt/test_dploaderset.py index f8708f708d..5d0382dce5 100644 --- a/source/tests/pt/test_dploaderset.py +++ b/source/tests/pt/test_dploaderset.py @@ -73,7 +73,7 @@ def test_batchsize(self) -> None: self.assertEqual(self.get_batch_sizes("filter:193"), 1) self.assertEqual(self.get_batch_sizes("filter:192"), 1) with self.assertLogs(logger="deepmd") as cm: - self.assertRaises(AssertionError, self.get_batch_sizes, "filter:191") + self.assertRaises(ValueError, self.get_batch_sizes, "filter:191") self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1]) # test: unknown batch_size: str