Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
# init data
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier)
train_data.print_summary("training")
if jdata["training"]["validation_data"] is not None:
if jdata["training"].get("validation_data", None) is not None:
valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier)
valid_data.print_summary("validation")
else:
Expand Down
32 changes: 20 additions & 12 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,14 @@ def train (self, train_data, valid_data=None) :
while cur_batch < stop_batch :

# first round validation:
train_batch = train_data.get_batch()
if self.display_in_training and is_first_step:
self.valid_on_the_fly(fp, train_data, valid_data, print_header=True)
valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None
self.valid_on_the_fly(fp, [train_batch], valid_batches, print_header=True)
is_first_step = False

if self.timing_in_training: tic = time.time()
train_feed_dict = self.get_feed_dict(train_data.get_batch(), is_training=True)
train_feed_dict = self.get_feed_dict(train_batch, is_training=True)
# use tensorboard to visualize the training of deepmd-kit
# it will take some extra execution time to generate the tensorboard data
if self.tensorboard :
Expand All @@ -501,7 +503,8 @@ def train (self, train_data, valid_data=None) :
if self.display_in_training and (cur_batch % self.disp_freq == 0):
if self.timing_in_training:
tic = time.time()
self.valid_on_the_fly(fp, train_data, valid_data)
valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None
self.valid_on_the_fly(fp, [train_batch], valid_batches)
if self.timing_in_training:
toc = time.time()
test_time = toc - tic
Expand Down Expand Up @@ -550,11 +553,11 @@ def get_global_step(self):

def valid_on_the_fly(self,
fp,
train_data,
valid_data,
train_batches,
valid_batches,
print_header=False):
train_results = self.get_evaluation_results(train_data, self.valid_numb_batch)
valid_results = self.get_evaluation_results(valid_data, self.valid_numb_batch)
train_results = self.get_evaluation_results(train_batches)
valid_results = self.get_evaluation_results(valid_batches)

cur_batch = self.cur_batch
current_lr = self.sess.run(self.learning_rate)
Expand Down Expand Up @@ -595,17 +598,22 @@ def print_on_training(fp, train_results, valid_results, cur_batch, cur_lr):
fp.write(print_str)
fp.flush()

def get_evaluation_results(self, batch_list):
    """Evaluate the loss on every batch and return natoms-weighted averages.

    Parameters
    ----------
    batch_list
        Iterable of batches, or ``None``. Each batch must provide a
        ``"natoms_vec"`` entry and be accepted by ``self.get_feed_dict``.

    Returns
    -------
    dict or None
        ``None`` when ``batch_list`` is ``None``. Otherwise a dict mapping
        every key returned by ``self.loss.eval`` (except ``"natoms"``) to
        ``sum(value * natoms) / sum(natoms)`` over all batches, where
        ``natoms`` is the ``"natoms"`` entry of each evaluation result.
        An empty ``batch_list`` yields an empty dict.
    """
    if batch_list is None:
        return None

    sum_results = {}  # sum of losses on all atoms
    # The atom count is accumulated separately: folding it into sum_results
    # would scale it by natoms as well and make it a wrong denominator.
    sum_natoms = 0
    for batch in batch_list:
        natoms = batch["natoms_vec"]
        feed_dict = self.get_feed_dict(batch, is_training=False)
        results = self.loss.eval(self.sess, feed_dict, natoms)

        for k, v in results.items():
            if k == "natoms":
                sum_natoms += v
            else:
                sum_results[k] = sum_results.get(k, 0.) + v * results["natoms"]
    # "natoms" is never stored in sum_results, so no key filtering is needed.
    return {k: v / sum_natoms for k, v in sum_results.items()}
5 changes: 2 additions & 3 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,15 +426,14 @@ def print_summary(self, name) :
log.info(f"---Summary of DataSystem: {name:13s}-----------------------------------------------")
log.info("found %d system(s):" % self.nsystems)
log.info(("%s " % self._format_name_length('system', sys_width)) +
("%6s %6s %6s %6s %5s %3s" % ('natoms', 'bch_sz', 'n_bch', "n_test", 'prob', 'pbc')))
("%6s %6s %6s %5s %3s" % ('natoms', 'bch_sz', 'n_bch', 'prob', 'pbc')))
for ii in range(self.nsystems) :
log.info("%s %6d %6d %6d %6d %5.3f %3s" %
log.info("%s %6d %6d %6d %5.3f %3s" %
(self._format_name_length(self.system_dirs[ii], sys_width),
self.natoms[ii],
# TODO batch size * nbatches = number of structures
self.batch_size[ii],
self.nbatches[ii],
self.test_size[ii],
self.sys_probs[ii],
"T" if self.data_systems[ii].pbc else "F"
) )
Expand Down