diff --git a/deepmd/pd/utils/dataloader.py b/deepmd/pd/utils/dataloader.py index 221a5e776d..0cb8adbc63 100644 --- a/deepmd/pd/utils/dataloader.py +++ b/deepmd/pd/utils/dataloader.py @@ -126,16 +126,41 @@ def construct_dataset(system): if isinstance(batch_size, str): if batch_size == "auto": rule = 32 + ceiling = True elif batch_size.startswith("auto:"): rule = int(batch_size.split(":")[1]) + ceiling = True + elif batch_size.startswith("max:"): + rule = int(batch_size.split(":")[1]) + ceiling = False + elif batch_size.startswith("filter:"): + # remove system with more than `filter` atoms + rule = int(batch_size.split(":")[1]) + len_before = len(self.systems) + self.systems = [ + system for system in self.systems if system._natoms <= rule + ] + len_after = len(self.systems) + if len_before != len_after: + log.warning( + f"Remove {len_before - len_after} systems with more than {rule} atoms" + ) + if len(self.systems) == 0: + raise ValueError( + f"No system left after removing systems with more than {rule} atoms" + ) + ceiling = False else: - rule = None - log.error("Unsupported batch size type") + raise ValueError(f"Unsupported batch size rule: {batch_size}") for ii in self.systems: ni = ii._natoms bsi = rule // ni - if bsi * ni < rule: - bsi += 1 + if ceiling: + if bsi * ni < rule: + bsi += 1 + else: + if bsi == 0: + bsi = 1 self.batch_sizes.append(bsi) elif isinstance(batch_size, list): self.batch_sizes = batch_size diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 1c7a1884d3..851a4713e9 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -120,16 +120,41 @@ def construct_dataset(system): if isinstance(batch_size, str): if batch_size == "auto": rule = 32 + ceiling = True elif batch_size.startswith("auto:"): rule = int(batch_size.split(":")[1]) + ceiling = True + elif batch_size.startswith("max:"): + rule = int(batch_size.split(":")[1]) + ceiling = False + elif batch_size.startswith("filter:"): + # remove 
system with more than `filter` atoms + rule = int(batch_size.split(":")[1]) + len_before = len(self.systems) + self.systems = [ + system for system in self.systems if system._natoms <= rule + ] + len_after = len(self.systems) + if len_before != len_after: + log.warning( + f"Remove {len_before - len_after} systems with more than {rule} atoms" + ) + if len(self.systems) == 0: + raise ValueError( + f"No system left after removing systems with more than {rule} atoms" + ) + ceiling = False else: - rule = None - log.error("Unsupported batch size type") + raise ValueError(f"Unsupported batch size rule: {batch_size}") for ii in self.systems: ni = ii._natoms bsi = rule // ni - if bsi * ni < rule: - bsi += 1 + if ceiling: + if bsi * ni < rule: + bsi += 1 + else: + if bsi == 0: + bsi = 1 self.batch_sizes.append(bsi) elif isinstance(batch_size, list): self.batch_sizes = batch_size diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 47071066ae..a1afcaf1d0 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2831,6 +2831,8 @@ def training_data_args(): # ! added by Ziyao: new specification style for data - string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\ - string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\ - string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\ +- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\ +- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\ If MPI is used, the value should be considered as the batch size per task.' 
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\ - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\ diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md index d21feb2126..174c39d6d9 100644 --- a/doc/train/training-advanced.md +++ b/doc/train/training-advanced.md @@ -106,7 +106,9 @@ The sections {ref}`training_data ` and {ref}`validation_ - `list`: the length of which is the same as the {ref}`systems`. The batch size of each system is given by the elements of the list. - `int`: all systems use the same batch size. - `"auto"`: the same as `"auto:32"`, see `"auto:N"` - - `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is no less than `N`. + - `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no less than** `N`. + - `"max:N"`: automatically determines the batch size so that the {ref}`batch_size ` times the number of atoms in the system is **no more than** `N`. The minimum batch size is 1. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }} + - `"filter:N"`: the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set. Throws an error if no system is left in a dataset. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }} - The key {ref}`numb_batch ` in {ref}`validate_data ` gives the number of batches of model validation. Note that the batches may not be from the same system The section {ref}`mixed_precision ` specifies the mixed precision settings, which will enable the mixed precision training workflow for DeePMD-kit. 
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import unittest
from pathlib import (
    Path,
)

from deepmd.common import (
    expand_sys_str,
)
from deepmd.pt.utils.dataloader import (
    DpLoaderSet,
)


class TestSampler(unittest.TestCase):
    """Check how ``DpLoaderSet`` resolves the ``batch_size`` specification.

    The integer, list, ``"auto[:N]"``, ``"max:N"`` and ``"filter:N"`` forms
    are exercised against a single water system, and an unrecognised string
    rule is expected to be rejected with ``ValueError``.
    """

    def setUp(self) -> None:
        """Load the water example config and point it at one data system."""
        here = Path(__file__).parent
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding when decoding the fixture.
        with open(here / "water/se_e2_a.json", encoding="utf-8") as fin:
            config = json.load(fin)
        data_file = [
            str(here / "model/water/data/data_0"),
        ]
        config["training"]["training_data"]["systems"] = data_file
        config["training"]["validation_data"]["systems"] = data_file
        model_config = config["model"]
        self.rcut = model_config["descriptor"]["rcut"]
        self.rcut_smth = model_config["descriptor"]["rcut_smth"]
        self.sel = model_config["descriptor"]["sel"]
        self.batch_size = config["training"]["training_data"]["batch_size"]
        self.systems = config["training"]["validation_data"]["systems"]
        self.type_map = model_config["type_map"]
        if isinstance(self.systems, str):
            self.systems = expand_sys_str(self.systems)

    def get_batch_sizes(self, batch_size) -> int:
        """Return the batch size resolved for the first system.

        Parameters
        ----------
        batch_size
            The specification forwarded unchanged to ``DpLoaderSet``; the
            tests pass an ``int``, a ``list`` of ``int`` or a rule string.

        Returns
        -------
        int
            ``batch_sizes[0]`` of the constructed ``DpLoaderSet``.
        """
        dataset = DpLoaderSet(
            self.systems,
            batch_size,
            self.type_map,
            seed=10,
            shuffle=False,
        )
        return dataset.batch_sizes[0]

    def test_batchsize(self) -> None:
        """Verify the resolved batch size for every supported specification."""
        # 192 atoms, 1 system
        # Use a unittest assertion rather than a bare ``assert`` so the check
        # still runs under ``python -O`` and reports both values on failure.
        self.assertEqual(len(self.systems), 1)

        # (specification, expected batch size of the first system)
        expected_batch_sizes = [
            # test: batch_size:int
            (3, 3),
            # test: batch_size:list[int]
            ([3], 3),
            # test: batch_size:str = "auto"
            ("auto:384", 2),
            ("auto:383", 2),
            ("auto:193", 2),
            ("auto:192", 1),
            ("auto:191", 1),
            ("auto:32", 1),
            ("auto", 1),
            # test: batch_size:str = "max"
            ("max:384", 2),
            ("max:383", 1),
            ("max:193", 1),
            ("max:192", 1),
            ("max:191", 1),
            # test: batch_size:str = "filter"
            ("filter:193", 1),
            ("filter:192", 1),
        ]
        for spec, expected in expected_batch_sizes:
            # subTest keeps going after a failure, so one wrong case does not
            # hide the results of the remaining ones.
            with self.subTest(batch_size=spec):
                self.assertEqual(self.get_batch_sizes(spec), expected)

        # test: "filter" drops the only system -> warning is logged, then error
        with self.assertLogs(logger="deepmd") as cm:
            with self.assertRaises(ValueError):
                self.get_batch_sizes("filter:191")
        self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1])

        # test: unknown batch_size: str
        with self.assertRaises(ValueError) as context:
            self.get_batch_sizes("unknown")
        self.assertIn("Unsupported batch size rule: unknown", str(context.exception))


if __name__ == "__main__":
    unittest.main()