Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions deepmd/pd/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,41 @@
if isinstance(batch_size, str):
if batch_size == "auto":
rule = 32
ceiling = True

Check warning on line 129 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L129

Added line #L129 was not covered by tests
elif batch_size.startswith("auto:"):
rule = int(batch_size.split(":")[1])
ceiling = True
elif batch_size.startswith("max:"):
rule = int(batch_size.split(":")[1])
ceiling = False
elif batch_size.startswith("filter:"):

Check warning on line 136 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L132-L136

Added lines #L132 - L136 were not covered by tests
# remove system with more than `filter` atoms
rule = int(batch_size.split(":")[1])
len_before = len(self.systems)
self.systems = [

Check warning on line 140 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L138-L140

Added lines #L138 - L140 were not covered by tests
system for system in self.systems if system._natoms <= rule
]
len_after = len(self.systems)
if len_before != len_after:
log.warning(

Check warning on line 145 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L143-L145

Added lines #L143 - L145 were not covered by tests
f"Remove {len_before - len_after} systems with more than {rule} atoms"
)
if len(self.systems) == 0:
raise ValueError(

Check warning on line 149 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L148-L149

Added lines #L148 - L149 were not covered by tests
f"No system left after removing systems with more than {rule} atoms"
)
ceiling = False

Check warning on line 152 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L152

Added line #L152 was not covered by tests
else:
rule = None
log.error("Unsupported batch size type")
raise ValueError(f"Unsupported batch size rule: {batch_size}")

Check warning on line 154 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L154

Added line #L154 was not covered by tests
for ii in self.systems:
ni = ii._natoms
bsi = rule // ni
if bsi * ni < rule:
bsi += 1
if ceiling:
if bsi * ni < rule:
bsi += 1

Check warning on line 160 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L158-L160

Added lines #L158 - L160 were not covered by tests
else:
if bsi == 0:
bsi = 1

Check warning on line 163 in deepmd/pd/utils/dataloader.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pd/utils/dataloader.py#L162-L163

Added lines #L162 - L163 were not covered by tests
self.batch_sizes.append(bsi)
elif isinstance(batch_size, list):
self.batch_sizes = batch_size
Expand Down
33 changes: 29 additions & 4 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,41 @@ def construct_dataset(system):
if isinstance(batch_size, str):
if batch_size == "auto":
rule = 32
ceiling = True
elif batch_size.startswith("auto:"):
rule = int(batch_size.split(":")[1])
ceiling = True
elif batch_size.startswith("max:"):
rule = int(batch_size.split(":")[1])
ceiling = False
elif batch_size.startswith("filter:"):
# remove system with more than `filter` atoms
rule = int(batch_size.split(":")[1])
len_before = len(self.systems)
self.systems = [
system for system in self.systems if system._natoms <= rule
]
len_after = len(self.systems)
if len_before != len_after:
log.warning(
f"Remove {len_before - len_after} systems with more than {rule} atoms"
)
if len(self.systems) == 0:
raise ValueError(
f"No system left after removing systems with more than {rule} atoms"
)
ceiling = False
else:
rule = None
log.error("Unsupported batch size type")
raise ValueError(f"Unsupported batch size rule: {batch_size}")
for ii in self.systems:
ni = ii._natoms
bsi = rule // ni
if bsi * ni < rule:
bsi += 1
if ceiling:
if bsi * ni < rule:
bsi += 1
else:
if bsi == 0:
bsi = 1
self.batch_sizes.append(bsi)
elif isinstance(batch_size, list):
self.batch_sizes = batch_size
Expand Down
2 changes: 2 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2831,6 +2831,8 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
If MPI is used, the value should be considered as the batch size per task.'
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
Expand Down
4 changes: 3 additions & 1 deletion doc/train/training-advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ The sections {ref}`training_data <training/training_data>` and {ref}`validation_
- `list`: the length of which is the same as the {ref}`systems`. The batch size of each system is given by the elements of the list.
- `int`: all systems use the same batch size.
- `"auto"`: the same as `"auto:32"`, see `"auto:N"`
- `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size <training/training_data/batch_size>` times the number of atoms in the system is no less than `N`.
- `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size <training/training_data/batch_size>` times the number of atoms in the system is **no less than** `N`.
- `"max:N"`: automatically determines the batch size so that the {ref}`batch_size <training/training_data/batch_size>` times the number of atoms in the system is **no more than** `N`. The minimum batch size is 1. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }}
- `"filter:N"`: the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set. Throws an error if no system is left after filtering. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }}
- The key {ref}`numb_batch <training/validation_data/numb_btch>` in {ref}`validate_data <training/validation_data>` gives the number of batches of model validation. Note that the batches may not be from the same system

The section {ref}`mixed_precision <training/mixed_precision>` specifies the mixed precision settings, which will enable the mixed precision training workflow for DeePMD-kit. The keys are explained below:
Expand Down
86 changes: 86 additions & 0 deletions source/tests/pt/test_dploaderset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import unittest
from pathlib import (
Path,
)

from deepmd.common import (
expand_sys_str,
)
from deepmd.pt.utils.dataloader import (
DpLoaderSet,
)


class TestSampler(unittest.TestCase):
    """Tests for the DpLoaderSet batch-size rules ("auto", "max", "filter")."""

    def setUp(self) -> None:
        """Load the water example config and point it at a single data system."""
        base_dir = Path(__file__).parent
        with open(str(base_dir / "water/se_e2_a.json")) as fin:
            config = json.loads(fin.read())
        data_file = [
            str(base_dir / "model/water/data/data_0"),
        ]
        config["training"]["training_data"]["systems"] = data_file
        config["training"]["validation_data"]["systems"] = data_file
        model_config = config["model"]
        descriptor = model_config["descriptor"]
        self.rcut = descriptor["rcut"]
        self.rcut_smth = descriptor["rcut_smth"]
        self.sel = descriptor["sel"]
        self.batch_size = config["training"]["training_data"]["batch_size"]
        self.systems = config["training"]["validation_data"]["systems"]
        self.type_map = model_config["type_map"]
        if isinstance(self.systems, str):
            self.systems = expand_sys_str(self.systems)

    def get_batch_sizes(self, batch_size) -> int:
        """Build a DpLoaderSet with *batch_size* and return the first system's batch size."""
        loader = DpLoaderSet(
            self.systems,
            batch_size,
            self.type_map,
            seed=10,
            shuffle=False,
        )
        return loader.batch_sizes[0]

    def test_batchsize(self) -> None:
        """Check batch sizes derived from int, list, "auto", "max", and "filter" rules."""
        # The data set holds exactly one system with 192 atoms.
        assert len(self.systems) == 1

        # int and list[int] batch sizes are taken verbatim.
        self.assertEqual(self.get_batch_sizes(3), 3)
        self.assertEqual(self.get_batch_sizes([3]), 3)

        # "auto:N": batch_size * natoms is no less than N (rounds up).
        for rule, expected in (
            ("auto:384", 2),
            ("auto:383", 2),
            ("auto:193", 2),
            ("auto:192", 1),
            ("auto:191", 1),
            ("auto:32", 1),
            ("auto", 1),
        ):
            self.assertEqual(self.get_batch_sizes(rule), expected)

        # "max:N": batch_size * natoms is no more than N (rounds down, minimum 1).
        for rule, expected in (
            ("max:384", 2),
            ("max:383", 1),
            ("max:193", 1),
            ("max:192", 1),
            ("max:191", 1),
        ):
            self.assertEqual(self.get_batch_sizes(rule), expected)

        # "filter:N": like "max:N" but drops systems with more than N atoms;
        # removing the only system must warn and then raise.
        self.assertEqual(self.get_batch_sizes("filter:193"), 1)
        self.assertEqual(self.get_batch_sizes("filter:192"), 1)
        with self.assertLogs(logger="deepmd") as cm:
            self.assertRaises(ValueError, self.get_batch_sizes, "filter:191")
        self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1])

        # Any other string rule is rejected.
        with self.assertRaises(ValueError) as context:
            self.get_batch_sizes("unknown")
        self.assertIn("Unsupported batch size rule: unknown", str(context.exception))


# Allow running this test module directly, outside a pytest/unittest collection run.
if __name__ == "__main__":
    unittest.main()