From b0b62b0d16c2bffbca80d7936d3f84952f4e5072 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 20 Nov 2020 14:54:28 -0500 Subject: [PATCH] allow ntypes_model > ntypes_data (fix #261) Usually, the type number of the model should be equal to that of the data However, nt_model > nt_data should be allowed, since users may only want to train using a dataset that only have some of elements --- source/train/Trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/source/train/Trainer.py b/source/train/Trainer.py index f4830a2eba..b7112370ae 100644 --- a/source/train/Trainer.py +++ b/source/train/Trainer.py @@ -209,7 +209,10 @@ def build (self, data, stop_batch = 0) : self.ntypes = self.model.get_ntypes() - assert (self.ntypes == data.get_ntypes()), "ntypes should match that found in data" + # Usually, the type number of the model should be equal to that of the data + # However, nt_model > nt_data should be allowed, since users may only want to + # train using a dataset that only have some of elements + assert (self.ntypes >= data.get_ntypes()), "ntypes should match that found in data" self.stop_batch = stop_batch self.batch_size = data.get_batch_size() @@ -492,4 +495,4 @@ def test_on_the_fly (self, feed_dict_batch) print_str += " %8.1e\n" % current_lr fp.write(print_str) - fp.flush () \ No newline at end of file + fp.flush ()