diff --git a/source/train/Trainer.py b/source/train/Trainer.py index f4830a2eba..b7112370ae 100644 --- a/source/train/Trainer.py +++ b/source/train/Trainer.py @@ -209,7 +209,10 @@ def build (self, data, stop_batch = 0) : self.ntypes = self.model.get_ntypes() - assert (self.ntypes == data.get_ntypes()), "ntypes should match that found in data" + # Usually, the type number of the model should be equal to that of the data + # However, nt_model > nt_data should be allowed, since users may only want to + # train using a dataset that only have some of elements + assert (self.ntypes >= data.get_ntypes()), "ntypes should match that found in data" self.stop_batch = stop_batch self.batch_size = data.get_batch_size() @@ -492,4 +495,4 @@ def test_on_the_fly (self, feed_dict_batch) print_str += " %8.1e\n" % current_lr fp.write(print_str) - fp.flush () \ No newline at end of file + fp.flush ()