From e974921a89926784470cda035b2448db1315299f Mon Sep 17 00:00:00 2001 From: Dr-Corgi Date: Fri, 31 Mar 2023 17:19:36 +0800 Subject: [PATCH] Move the function "save_model" into PPOTrainer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The function save_model should be a part of PPOTrainer. save_model函数被错误地外置到模型外部,导致rlhf train阶段无法保存模型。 --- applications/Chat/coati/trainer/ppo.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 84254d50d7e7..873d97fdf6b5 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -116,6 +116,9 @@ def training_step(self, experience: Experience) -> Dict[str, float]: self.critic_optim.zero_grad() return {'reward': experience.reward.mean().item()} + + def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: @@ -129,7 +132,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn return new_kwargs - - -def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)