From 537c35d0991cc7adc3e2153713bd3dc249130a87 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 13 Feb 2023 21:12:36 -0500 Subject: [PATCH 1/3] fix restarting from the original model Signed-off-by: Jinzhe Zeng --- deepmd/train/trainer.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index d5f579a3b5..592f3b1a3c 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -510,9 +510,9 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix=""): if self.run_opt.init_mode == "init_from_frz_model": self._init_from_frz_model() elif self.run_opt.init_mode == "init_model": - self.ckpt_meta = self.run_opt.init_model + self._init_from_ckpt(self.run_opt.init_model) elif self.run_opt.init_mode == "restart": - self.ckpt_meta = self.run_opt.restart + self._init_from_ckpt(self.run_opt.restart) elif self.run_opt.init_mode == "finetune": self._init_from_pretrained_model( data=data, origin_type_map=origin_type_map @@ -1181,6 +1181,24 @@ def _init_from_frz_model(self): self.frz_model = self.run_opt.init_frz_model self.model.init_variables(graph, graph_def, model_type=self.model_type) + def _init_from_ckpt(self, ckpt_meta: str): + with tf.Graph().as_default() as graph: + tf.train.import_meta_graph(f"{ckpt_meta}.meta", clear_devices=True) + # get the model type from the model + try: + t_model_type = get_tensor_by_name_from_graph(graph, "model_type") + except GraphWithoutTensorError as e: + # throw runtime error if the frozen_model has no model type information... + raise RuntimeError( + "The input frozen model: %s has no 'model_type' information, " + "which is not supported by the 'dp train init-frz-model' interface. " + % self.run_opt.init_frz_model + ) from e + else: + self.model_type = bytes.decode(t_model_type) + if self.model_type == "compressed_model": + self.ckpt_meta = ckpt_meta + def _init_from_pretrained_model( self, data, origin_type_map=None, bias_shift="delta" ): From 8f76b7d88825d5aadb6d60105754ef68923ddcc7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 13 Feb 2023 21:15:38 -0500 Subject: [PATCH 2/3] compatibility Signed-off-by: Jinzhe Zeng --- deepmd/train/trainer.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 592f3b1a3c..80a5266312 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -1188,12 +1188,7 @@ def _init_from_ckpt(self, ckpt_meta: str): try: t_model_type = get_tensor_by_name_from_graph(graph, "model_type") except GraphWithoutTensorError as e: - # throw runtime error if the frozen_model has no model type information... - raise RuntimeError( - "The input frozen model: %s has no 'model_type' information, " - "which is not supported by the 'dp train init-frz-model' interface. " - % self.run_opt.init_frz_model - ) from e + t_model_type = "original_model" else: self.model_type = bytes.decode(t_model_type) if self.model_type == "compressed_model": From 20a0864ace52890254b0a71912afc108235d694b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 13 Feb 2023 21:16:47 -0500 Subject: [PATCH 3/3] typo Signed-off-by: Jinzhe Zeng --- deepmd/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 80a5266312..b0bf1f9dbe 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -1188,7 +1188,7 @@ def _init_from_ckpt(self, ckpt_meta: str): try: t_model_type = get_tensor_by_name_from_graph(graph, "model_type") except GraphWithoutTensorError as e: - t_model_type = "original_model" + self.model_type = "original_model" else: self.model_type = bytes.decode(t_model_type) if self.model_type == "compressed_model":