-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[coati] fix RM & MDP #3645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
[coati] fix RM & MDP #3645
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
282f831
first commit
ht-zhou 85dda25
running mean rewards
ht-zhou 957f21d
update
ht-zhou 972e485
revert: remove debug config
cwher 38f218e
revert: undo unnecessary changes
cwher d6934a9
revert: undo unnecessary changes
cwher e0a2c5f
revert: undo unnecessary changes
cwher File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from .base import Experience, ExperienceMaker | ||
| from .naive import NaiveExperienceMaker | ||
| from .multi_step import MultiStepExperienceMaker | ||
|
|
||
| __all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker'] | ||
| __all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker', 'MultiStepExperienceMaker'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| import torch | ||
| from coati.models.utils import compute_reward, normalize, compute_approx_kl | ||
| import torch.nn.functional as F | ||
| from .base import Experience, ExperienceMaker | ||
| from typing import Any, Callable, Optional | ||
| import torch.distributed as dist | ||
| from coati.models.base import ActorCritic | ||
| import numpy as np | ||
|
|
||
| class MultiStepExperienceMaker(ExperienceMaker): | ||
| """ | ||
| Multi Step experience maker. | ||
| """ | ||
| @torch.no_grad() | ||
| def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: | ||
| self.actor.eval() | ||
| self.critic.eval() | ||
| self.initial_model.eval() | ||
| self.reward_model.eval() | ||
| self.ac = ActorCritic(self.actor, self.critic) | ||
| self.buffer = [] | ||
|
|
||
| self.gamma = 1 | ||
| self.lamda = 0.25 | ||
| self.kl_coef = 1/128 | ||
| in_len = input_ids.size(1) | ||
|
|
||
| # generate the sequence and get value at the same time | ||
| sequences, values , attention_mask, action_mask = self.ac.generate(input_ids, | ||
| **generate_kwargs) | ||
|
|
||
| # compute action log probs | ||
| action_logits = self.actor.model(sequences)['logits'] | ||
| base_action_logits = self.initial_model.model(sequences)['logits'] | ||
|
|
||
| action_log_probs = F.log_softmax(action_logits, dim=-1) | ||
| base_log_probs = F.log_softmax(base_action_logits, dim=-1) | ||
|
|
||
| # compute kl_div | ||
| kl_list = compute_approx_kl(action_log_probs, base_log_probs) | ||
| # clip kl | ||
| kl_list = torch.clamp(kl_list, max=10, min=1e-4) | ||
|
|
||
| # add eos token to the end of sequence and compute reward | ||
| eos_tensor = torch.tensor([self.eos_token_id], device=input_ids.device).repeat(input_ids.size(0), 1) | ||
| sequence_with_eos = torch.cat([sequences, eos_tensor], dim=-1) | ||
| rewards = self.reward_model(sequence_with_eos) | ||
|
|
||
| # reward clip | ||
| rewards = torch.clamp(rewards, max=10, min=-10) | ||
|
|
||
| # running mean reward | ||
| for i in range(rewards.size(0)): | ||
| value = rewards[i] | ||
| self.reward_count += 1 | ||
| delta = value - self.reward_mean | ||
| self.reward_mean += delta / self.reward_count | ||
| delta2 = value - self.reward_mean | ||
| self.reward_M2 += delta * delta2 | ||
|
|
||
| std = self.reward_M2 / (self.reward_count - 1) | ||
| rewards = (rewards - self.reward_mean)/std | ||
|
|
||
| print('rewards: ', rewards) | ||
| rewards = rewards * (1 - self.kl_coef) | ||
|
|
||
| # get action mask | ||
| action_mask = action_mask[:, in_len:] | ||
| kl_list = kl_list[:, in_len:] | ||
|
|
||
| # compute the advantages | ||
| advantages, returns = self.compute_gae(kl_list, rewards, values, action_mask) | ||
|
|
||
| for i in range(in_len, sequences.size(1) - 1): | ||
| for j in range(sequences.size(0)): | ||
| if sequences[j, i] != self.eos_token_id: | ||
| _state = sequences[j, :i] | ||
| _action_log_prob = action_log_probs[j, i] | ||
| _value = values[j, i-in_len] | ||
| _return = returns[j, i-in_len] | ||
| _adv = advantages[j, i-in_len] | ||
| _attention_mask = attention_mask[j, :i] | ||
| _action_mask = action_mask[j, :i-in_len] | ||
| exp = Experience(_state, _action_log_prob, _value, _return, _adv, _attention_mask, _action_mask) | ||
| self.buffer.append(exp) | ||
| buffer = self.buffer | ||
| return buffer | ||
|
|
||
| @torch.no_grad() | ||
| def compute_gae(self, kl_list: torch.Tensor, | ||
| reward: torch.Tensor, | ||
| values: torch.Tensor, | ||
| action_mask: torch.Tensor) -> torch.Tensor: | ||
| kl = -kl_list * action_mask * self.kl_coef | ||
| values = values * action_mask | ||
| T = torch.sum(values.ne(0), dim=1) | ||
| self.total_len = sum(T) | ||
| max_len = max(T) | ||
| gae_values = torch.zeros_like(values) | ||
| delta_list = torch.zeros_like(values) | ||
|
|
||
| # add reward to kl[:, -1] | ||
| for i in range(len(T)): | ||
| kl[i, T[i]-1] += reward[i] | ||
|
|
||
| # compute delta | ||
| for t in range(max_len - 1): | ||
| next_v = values[:,t + 1] if t + 1 < max_len else 0 | ||
| delta_list[:, t] = kl[:, t] + self.gamma * next_v - values[:, t] | ||
|
|
||
| # compute gae | ||
| gae_values[:, max_len - 1] = delta_list[:, max_len - 1] | ||
| for t in range(max_len - 2, -1, -1): | ||
| gae_values[:, t] = delta_list[:, t] + self.gamma * self.lamda * gae_values[:, t + 1] | ||
|
|
||
| # compute return | ||
| returns = gae_values + values | ||
| return gae_values, returns |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from .base import Actor, Critic, RewardModel | ||
| from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss | ||
| from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss, MPolicyLoss | ||
|
|
||
| __all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss'] | ||
| __all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss', 'MPolicyLoss'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from typing import Optional, Tuple, Union | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
|
|
||
| from ..generation import generate_with_value | ||
| from ..lora import LoRAModule | ||
| from ..utils import log_probs_from_logits | ||
| from coati.models.base import Actor, Critic | ||
|
|
||
|
|
||
|
|
||
| class ActorCritic(nn.Module): | ||
| """ | ||
| ActorCritic model class. | ||
|
|
||
| Args: | ||
| model (nn.Module): Actor Model. | ||
| lora_rank (int): LoRA rank. | ||
| lora_train_bias (str): LoRA bias training mode. | ||
| """ | ||
|
|
||
| def __init__(self, actor:Actor, critic:Critic) -> None: | ||
| super().__init__() | ||
| self.actor = actor | ||
| self.critic = critic | ||
|
|
||
| @torch.no_grad() | ||
| def generate( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| **kwargs | ||
| ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: | ||
| sequences, values = generate_with_value(self.actor.model, self.critic, input_ids, **kwargs) | ||
|
|
||
| pad_token_id = kwargs.get('pad_token_id', None) | ||
| if pad_token_id is not None: | ||
| attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) | ||
| input_mask = input_ids.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) | ||
| # pad input_mask to be as long as attention_mask | ||
| input_mask = F.pad(input_mask, (0, attention_mask.shape[-1]-input_mask.shape[-1], 0, 0), value=0) | ||
| action_mask = attention_mask - input_mask | ||
|
|
||
| return sequences, values, attention_mask, action_mask | ||
|
|
||
| # def forward(self, | ||
| # sequences: torch.LongTensor, | ||
| # num_actions: int, | ||
| # attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| # """Returns action log probs | ||
| # """ | ||
| # output = self.actor.model(sequences, attention_mask=attention_mask) | ||
| # logits = output['logits'] | ||
| # log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) | ||
| # return log_probs[:, -num_actions:] | ||
|
|
||
| # def get_base_model(self): | ||
| # return self.model |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are all generation kwargs. Is it essential to declare these kwargs here explicitly?