diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index d5f579a3b5..b0bf1f9dbe 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -510,9 +510,9 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix=""): if self.run_opt.init_mode == "init_from_frz_model": self._init_from_frz_model() elif self.run_opt.init_mode == "init_model": - self.ckpt_meta = self.run_opt.init_model + self._init_from_ckpt(self.run_opt.init_model) elif self.run_opt.init_mode == "restart": - self.ckpt_meta = self.run_opt.restart + self._init_from_ckpt(self.run_opt.restart) elif self.run_opt.init_mode == "finetune": self._init_from_pretrained_model( data=data, origin_type_map=origin_type_map @@ -1181,6 +1181,19 @@ def _init_from_frz_model(self): self.frz_model = self.run_opt.init_frz_model self.model.init_variables(graph, graph_def, model_type=self.model_type) + def _init_from_ckpt(self, ckpt_meta: str): + with tf.Graph().as_default() as graph: + tf.train.import_meta_graph(f"{ckpt_meta}.meta", clear_devices=True) + # get the model type from the model + try: + t_model_type = get_tensor_by_name_from_graph(graph, "model_type") + except GraphWithoutTensorError as e: + self.model_type = "original_model" + else: + self.model_type = bytes.decode(t_model_type) + if self.model_type == "compressed_model": + self.ckpt_meta = ckpt_meta + def _init_from_pretrained_model( self, data, origin_type_map=None, bias_shift="delta" ):