17 changes: 17 additions & 0 deletions deepmd/pt/train/training.py
@@ -1037,6 +1037,23 @@ def log_loss_valid(_task_key="Default"):
f"The profiling trace have been saved to: {self.profiling_file}"
)

def delete_dataloader(self):
    """Drop dataloader references so background prefetch threads can stop."""
    if self.multi_task:
        for model_key in self.model_keys:
            del (
                self.training_data[model_key],
                self.training_dataloader[model_key],
                self.validation_data[model_key],
                self.validation_dataloader[model_key],
            )
    else:
        del (
            self.training_data,
            self.training_dataloader,
            self.validation_data,
            self.validation_dataloader,
        )
Comment on lines +1050 to +1055 (Check warning, Code scanning / CodeQL):
Unnecessary delete statement in function: unnecessary deletion of local variable `Tuple` in function `delete_dataloader`.
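A possible way to address this alert (a sketch, not part of this PR; CodeQL reads the parenthesized `del` target as a throwaway tuple) is to rebind the attributes instead of deleting them. The loader references are dropped just the same, so the prefetch machinery can still be reclaimed:

def delete_dataloader(self):
    """Release dataloader references without a parenthesized del target."""
    if self.multi_task:
        for model_key in self.model_keys:
            # Rebinding to None drops the last reference to each loader.
            self.training_data[model_key] = None
            self.training_dataloader[model_key] = None
            self.validation_data[model_key] = None
            self.validation_dataloader[model_key] = None
    else:
        self.training_data = None
        self.training_dataloader = None
        self.validation_data = None
        self.validation_dataloader = None

One difference worth noting: the dict keys remain present and map to None rather than being removed, which is usually acceptable when the trainer is about to be discarded anyway.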

def save_model(self, save_path, lr=0.0, step=0):
module = (
self.wrapper.module
20 changes: 17 additions & 3 deletions source/tests/pt/test_training.py
@@ -2,6 +2,7 @@
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
@@ -34,6 +35,7 @@ def test_dp_train(self):
# test training from scratch
trainer = get_trainer(deepcopy(self.config))
trainer.run()
trainer.delete_dataloader()
state_dict_trained = trainer.wrapper.model.state_dict()

# test fine-tuning using same input
@@ -100,6 +102,11 @@ def test_dp_train(self):
trainer_finetune_empty.run()
trainer_finetune_random.run()

# delete dataloader to stop buffer fetching
trainer_finetune.delete_dataloader()
trainer_finetune_empty.delete_dataloader()
trainer_finetune_random.delete_dataloader()

def test_trainable(self):
fix_params = deepcopy(self.config)
fix_params["model"]["descriptor"]["trainable"] = False
@@ -195,18 +202,25 @@ def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.original_data_path = Path(__file__).parent / "water/data/data_0"
# Create a temporary directory for this test
self.temp_dir = Path(tempfile.mkdtemp())
self.temp_data_path = self.temp_dir / "data_0"
shutil.copytree(self.original_data_path, self.temp_data_path)

data_file = [str(self.temp_data_path)]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_se_e2_a)
self.config["model"]["fitting_net"]["numb_fparam"] = 1
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
self.set_path = self.temp_data_path / "set.000"
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")

def tearDown(self) -> None:
(self.set_path / "fparam.npy").unlink(missing_ok=True)
# Remove the temporary directory and all its contents
shutil.rmtree(self.temp_dir)
DPTrainTest.tearDown(self)
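The setUp/tearDown changes above keep the generated fparam.npy inside a throwaway copy of the data, so the shared water/data/data_0 fixture is never mutated. A self-contained sketch of the same isolation pattern (the IsolatedDataTest name is hypothetical, not from the PR):

import shutil
import tempfile
import unittest
from pathlib import Path

class IsolatedDataTest(unittest.TestCase):
    def setUp(self):
        # Work on a throwaway copy so test writes never touch the shared fixture.
        self.temp_dir = Path(tempfile.mkdtemp())
        self.data_path = self.temp_dir / "data_0"
        shutil.copytree(Path(__file__).parent / "water/data/data_0", self.data_path)

    def tearDown(self):
        # Safe only once dataloader workers are gone (see delete_dataloader).
        shutil.rmtree(self.temp_dir)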

