diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py
index faa1c94d2728..1c7ba233995d 100644
--- a/applications/Chat/coati/dataset/reward_dataset.py
+++ b/applications/Chat/coati/dataset/reward_dataset.py
@@ -80,6 +80,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=
             self.end_token = tokenizer.eos_token
         else:
             self.end_token = special_token
+        tokenizer.padding_side = 'left'
         for data in tqdm(dataset, disable=not is_rank_0()):
             chosen = data['chosen'] + self.end_token
             chosen_token = tokenizer(chosen,
diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py
index 39ca7576b227..18f38476ea23 100644
--- a/applications/Chat/coati/experience_maker/__init__.py
+++ b/applications/Chat/coati/experience_maker/__init__.py
@@ -1,4 +1,5 @@
 from .base import Experience, ExperienceMaker
 from .naive import NaiveExperienceMaker
+from .multi_step import MultiStepExperienceMaker
 
-__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
+__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker', 'MultiStepExperienceMaker']
diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py
index 61fd4f6744dc..890debcb96ab 100644
--- a/applications/Chat/coati/experience_maker/base.py
+++ b/applications/Chat/coati/experience_maker/base.py
@@ -64,13 +64,29 @@ def __init__(self,
                  critic: nn.Module,
                  reward_model: nn.Module,
                  initial_model: Actor,
-                 kl_coef: float = 0.1) -> None:
+                 kl_coef: float = 0.1,
+                 max_length: int = 128,
+                 eos_token_id: Optional[int] = None,
+                 pad_token_id: Optional[int] = None,
+                 top_k: Optional[int] = None,
+                 top_p: Optional[float] = None,
+                 temperature: Optional[float] = None,
+                 **generate_kwargs) -> None:
         super().__init__()
         self.actor = actor
         self.critic = critic
         self.reward_model = reward_model
         self.initial_model = initial_model
         self.kl_coef = kl_coef
+        self.max_length = max_length
+        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
+        self.top_k = top_k
+        self.top_p = top_p
+        self.temperature = temperature
+        # running statistics of the reward (Welford's online algorithm)
+        self.reward_count = 0
+        self.reward_mean = 0.0
+        self.reward_M2 = 0.0
 
     @abstractmethod
     def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
diff --git a/applications/Chat/coati/experience_maker/multi_step.py b/applications/Chat/coati/experience_maker/multi_step.py
new file mode 100644
index 000000000000..31d7acac6aa0
--- /dev/null
+++ b/applications/Chat/coati/experience_maker/multi_step.py
@@ -0,0 +1,118 @@
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+from coati.models.base import ActorCritic
+from coati.models.utils import compute_approx_kl
+
+from .base import Experience, ExperienceMaker
+
+
+class MultiStepExperienceMaker(ExperienceMaker):
+    """
+    Multi-step experience maker.
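+
+    Unlike NaiveExperienceMaker, it emits one Experience per generated token:
+    the critic is queried for a value at every generation step, rewards are
+    normalized with a running mean/std (Welford's algorithm), and advantages
+    are estimated with GAE over the per-token KL penalty.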
+ """ + @torch.no_grad() + def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience: + self.actor.eval() + self.critic.eval() + self.initial_model.eval() + self.reward_model.eval() + self.ac = ActorCritic(self.actor, self.critic) + self.buffer = [] + + self.gamma = 1 + self.lamda = 0.25 + self.kl_coef = 1/128 + in_len = input_ids.size(1) + + # generate the sequence and get value at the same time + sequences, values , attention_mask, action_mask = self.ac.generate(input_ids, + **generate_kwargs) + + # compute action log probs + action_logits = self.actor.model(sequences)['logits'] + base_action_logits = self.initial_model.model(sequences)['logits'] + + action_log_probs = F.log_softmax(action_logits, dim=-1) + base_log_probs = F.log_softmax(base_action_logits, dim=-1) + + # compute kl_div + kl_list = compute_approx_kl(action_log_probs, base_log_probs) + # clip kl + kl_list = torch.clamp(kl_list, max=10, min=1e-4) + + # add eos token to the end of sequence and compute reward + eos_tensor = torch.tensor([self.eos_token_id], device=input_ids.device).repeat(input_ids.size(0), 1) + sequence_with_eos = torch.cat([sequences, eos_tensor], dim=-1) + rewards = self.reward_model(sequence_with_eos) + + # reward clip + rewards = torch.clamp(rewards, max=10, min=-10) + + # running mean reward + for i in range(rewards.size(0)): + value = rewards[i] + self.reward_count += 1 + delta = value - self.reward_mean + self.reward_mean += delta / self.reward_count + delta2 = value - self.reward_mean + self.reward_M2 += delta * delta2 + + std = self.reward_M2 / (self.reward_count - 1) + rewards = (rewards - self.reward_mean)/std + + print('rewards: ', rewards) + rewards = rewards * (1 - self.kl_coef) + + # get action mask + action_mask = action_mask[:, in_len:] + kl_list = kl_list[:, in_len:] + + # compute the advantages + advantages, returns = self.compute_gae(kl_list, rewards, values, action_mask) + + for i in range(in_len, sequences.size(1) - 1): + for j in range(sequences.size(0)): + if sequences[j, i] != self.eos_token_id: + _state = sequences[j, :i] + _action_log_prob = action_log_probs[j, i] + _value = values[j, i-in_len] + _return = returns[j, i-in_len] + _adv = advantages[j, i-in_len] + _attention_mask = attention_mask[j, :i] + _action_mask = action_mask[j, :i-in_len] + exp = Experience(_state, _action_log_prob, _value, _return, _adv, _attention_mask, _action_mask) + self.buffer.append(exp) + buffer = self.buffer + return buffer + + @torch.no_grad() + def compute_gae(self, kl_list: torch.Tensor, + reward: torch.Tensor, + values: torch.Tensor, + action_mask: torch.Tensor) -> torch.Tensor: + kl = -kl_list * action_mask * self.kl_coef + values = values * action_mask + T = torch.sum(values.ne(0), dim=1) + self.total_len = sum(T) + max_len = max(T) + gae_values = torch.zeros_like(values) + delta_list = torch.zeros_like(values) + + # add reward to kl[:, -1] + for i in range(len(T)): + kl[i, T[i]-1] += reward[i] + + # compute delta + for t in range(max_len - 1): + next_v = values[:,t + 1] if t + 1 < max_len else 0 + delta_list[:, t] = kl[:, t] + self.gamma * next_v - values[:, t] + + # compute gae + gae_values[:, max_len - 1] = delta_list[:, max_len - 1] + for t in range(max_len - 2, -1, -1): + gae_values[:, t] = delta_list[:, t] + self.gamma * self.lamda * gae_values[:, t + 1] + + # compute return + returns = gae_values + values + return gae_values, returns \ No newline at end of file diff --git a/applications/Chat/coati/experience_maker/naive.py 
b/applications/Chat/coati/experience_maker/naive.py index 94546eeb28e7..358736f878ad 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -23,7 +23,7 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie action_log_probs = self.actor(sequences, num_actions, attention_mask) base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask) - value = self.critic(sequences, action_mask, attention_mask) + value = self.critic(sequences, attention_mask) r = self.reward_model(sequences, attention_mask) reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask) diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py index 7489b2e87ca0..55c8eb8c3ae6 100644 --- a/applications/Chat/coati/models/__init__.py +++ b/applications/Chat/coati/models/__init__.py @@ -1,4 +1,4 @@ from .base import Actor, Critic, RewardModel -from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss +from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss, MPolicyLoss -__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss'] +__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss', 'MPolicyLoss'] diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py index 7cf82309af7b..886987f2fbd5 100644 --- a/applications/Chat/coati/models/base/__init__.py +++ b/applications/Chat/coati/models/base/__init__.py @@ -2,5 +2,6 @@ from .critic import Critic from .lm import LM from .reward_model import RewardModel +from .ac import ActorCritic -__all__ = ['Actor', 'Critic', 'RewardModel', 'LM'] +__all__ = ['Actor', 'Critic', 'RewardModel', 'LM', 'ActorCritic'] diff --git a/applications/Chat/coati/models/base/ac.py b/applications/Chat/coati/models/base/ac.py new file mode 100644 index 000000000000..5769e0908945 --- /dev/null +++ b/applications/Chat/coati/models/base/ac.py @@ -0,0 +1,59 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..generation import generate_with_value +from ..lora import LoRAModule +from ..utils import log_probs_from_logits +from coati.models.base import Actor, Critic + + + +class ActorCritic(nn.Module): + """ + ActorCritic model class. + + Args: + model (nn.Module): Actor Model. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. 
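+
+    `generate` samples tokens with the actor while querying the critic for a
+    value at every step, and returns (sequences, values, attention_mask, action_mask).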
+ """ + + def __init__(self, actor:Actor, critic:Critic) -> None: + super().__init__() + self.actor = actor + self.critic = critic + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + sequences, values = generate_with_value(self.actor.model, self.critic, input_ids, **kwargs) + + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + input_mask = input_ids.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + # pad input_mask to be as long as attention_mask + input_mask = F.pad(input_mask, (0, attention_mask.shape[-1]-input_mask.shape[-1], 0, 0), value=0) + action_mask = attention_mask - input_mask + + return sequences, values, attention_mask, action_mask + + # def forward(self, + # sequences: torch.LongTensor, + # num_actions: int, + # attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + # """Returns action log probs + # """ + # output = self.actor.model(sequences, attention_mask=attention_mask) + # logits = output['logits'] + # log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + # return log_probs[:, -num_actions:] + + # def get_base_model(self): + # return self.model diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py index e68a743a7762..69cc75f16e5f 100644 --- a/applications/Chat/coati/models/base/critic.py +++ b/applications/Chat/coati/models/base/critic.py @@ -23,32 +23,21 @@ def __init__( model: nn.Module, value_head: nn.Module, lora_rank: int = 0, - lora_train_bias: str = 'none', - use_action_mask: bool = False, + lora_train_bias: str = 'none' ) -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.value_head = value_head - self.use_action_mask = use_action_mask self.convert_to_lora() def forward(self, sequences: torch.LongTensor, - action_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + # we have added a token at the end of each sequence + # so we can use the last hidden state to be the input of the marking linear layer to get the value score outputs = self.model(sequences, attention_mask=attention_mask) last_hidden_states = outputs['last_hidden_state'] - - values = self.value_head(last_hidden_states).squeeze(-1) - - if action_mask is not None and self.use_action_mask: - num_actions = action_mask.size(1) - prompt_mask = attention_mask[:, :-num_actions] - values = values[:, :-num_actions] - value = masked_mean(values, prompt_mask, dim=1) - return value - - values = values[:, :-1] - value = values.mean(dim=1) + value_prob = last_hidden_states[:, -1] + value = self.value_head(value_prob).squeeze(1) # ensure shape is (B) return value diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py index ce8c0a1d3568..fe01c8945033 100644 --- a/applications/Chat/coati/models/base/reward_model.py +++ b/applications/Chat/coati/models/base/reward_model.py @@ -34,8 +34,10 @@ def __init__(self, self.value_head = nn.Linear(model.config.n_embd, 1) def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + # we have added a token at the end of each sequence while preparing for the dataset + # so we can use the 
last hidden state to be the input of the marking linear layer to get the reward score outputs = self.model(sequences, attention_mask=attention_mask) last_hidden_states = outputs['last_hidden_state'] - values = self.value_head(last_hidden_states)[:, :-1] - value = values.mean(dim=1).squeeze(1) # ensure shape is (B) - return value + reward_prob = last_hidden_states[:, -1] + reward = self.value_head(reward_prob).squeeze(1) # ensure shape is (B) + return reward diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index eb30c36d0f84..08fcd40fc77a 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -36,6 +36,7 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 +@torch.no_grad() def sample(model: nn.Module, input_ids: torch.Tensor, max_length: int, @@ -88,13 +89,13 @@ def sample(model: nn.Module, return input_ids - +@torch.no_grad() def generate(model: nn.Module, input_ids: torch.Tensor, max_length: int, num_beams: int = 1, do_sample: bool = True, - early_stopping: bool = False, + early_stopping: bool = True, eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, top_k: Optional[int] = None, @@ -144,3 +145,74 @@ def generate(model: nn.Module, raise NotImplementedError else: raise ValueError("Unsupported generation mode") + + +@torch.no_grad() +def generate_with_value(actor: nn.Module, + critic: nn.Module, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = True, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + if input_ids.size(1) >= max_length: + return input_ids + + temperature = 1.0 + logits_processor = prepare_logits_processor(top_k, top_p, temperature) + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + values = [] + + for _ in range(input_ids.size(1), max_length): + model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { + 'input_ids': input_ids + } + outputs = actor(**model_inputs) + + next_token_logits = outputs['logits'][:, -1, :] + # pre-process distribution + next_token_logits = logits_processor(input_ids, next_token_logits) + # sample + probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + # if 'nan' in str(probs): + # for name, param in actor.named_parameters(): + # print(name, param) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + + # compute value on the last hidden_state + eos_tensor = torch.tensor([eos_token_id], device=input_ids.device).repeat(input_ids.size(0), 1) + value_input = torch.cat([input_ids, eos_tensor], dim=-1) + value = critic(value_input) + values.append(value) + + # update generated ids, model inputs for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + if update_model_kwargs_fn is not None: + model_kwargs 
= update_model_kwargs_fn(outputs, **model_kwargs) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # stop when each sentence is finished if early_stopping=True + if early_stopping and _is_sequence_finished(unfinished_sequences): + break + # transform values to tensor + values = torch.cat(values, dim=0) + # reshape to (x,4) + values = values.view(4, -1) + return input_ids, values \ No newline at end of file diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py index 926c6e2a4e41..23499247b5cd 100644 --- a/applications/Chat/coati/models/loss.py +++ b/applications/Chat/coati/models/loss.py @@ -46,6 +46,30 @@ def forward(self, return loss +class MPolicyLoss(nn.Module): + """ + Policy Loss for multistep-PPO + """ + + def __init__(self, clip_eps: float = 0.2) -> None: + super().__init__() + self.clip_eps = clip_eps + + def forward(self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + ratio = (log_probs - old_log_probs).exp() + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + loss = -torch.min(surr1, surr2) + if action_mask is not None: + loss = masked_mean(loss, action_mask) + loss = loss.mean() + return loss + + class ValueLoss(nn.Module): """ Value Loss for PPO diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 0ff13181fcd2..49d0e4e6da43 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -19,6 +19,27 @@ def compute_approx_kl(log_probs: torch.Tensor, action_mask: Mask for actions. """ + log_ratio = log_probs - log_probs_base + approx_kl = (log_ratio.exp() - 1) - log_ratio + # if action_mask is not None: + # approx_kl = masked_mean(approx_kl, action_mask, dim=1) + # return approx_kl + approx_kl = approx_kl.sum(dim=-1) + return approx_kl + + +def compute_approx_kl_mean(log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Compute the approximate KL divergence between two distributions. + Schulman blog: http://joschu.net/blog/kl-approx.html + Args: + log_probs: Log probabilities of the new distribution. + log_probs_base: Log probabilities of the base distribution. + action_mask: Mask for actions. 
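+
+    Unlike `compute_approx_kl`, which now returns a per-token estimate summed
+    over the last dimension, this variant keeps the original per-sequence
+    behaviour: a mean of the estimator masked by `action_mask`.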
+ """ + log_ratio = log_probs - log_probs_base approx_kl = (log_ratio.exp() - 1) - log_ratio if action_mask is not None: @@ -28,6 +49,7 @@ def compute_approx_kl(log_probs: torch.Tensor, return approx_kl + def compute_reward(r: Union[torch.Tensor, float], kl_coef: float, log_probs: torch.Tensor, @@ -35,7 +57,7 @@ def compute_reward(r: Union[torch.Tensor, float], action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if kl_coef <= 0.0: return r - kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) + kl = compute_approx_kl_mean(log_probs, log_probs_base, action_mask=action_mask) reward = r - kl_coef * kl return reward diff --git a/applications/Chat/coati/replay_buffer/naive.py b/applications/Chat/coati/replay_buffer/naive.py index 938f500643c9..b445e3959003 100644 --- a/applications/Chat/coati/replay_buffer/naive.py +++ b/applications/Chat/coati/replay_buffer/naive.py @@ -5,7 +5,7 @@ from coati.experience_maker.base import Experience from .base import ReplayBuffer -from .utils import BufferItem, make_experience_batch, split_experience_batch +from .utils import BufferItem, make_experience_batch, split_experience_batch, exp_to_buffer_item class NaiveReplayBuffer(ReplayBuffer): @@ -25,11 +25,19 @@ def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = T self.items: List[BufferItem] = [] @torch.no_grad() - def append(self, experience: Experience) -> None: - if self.cpu_offload: - experience.to_device(torch.device('cpu')) - items = split_experience_batch(experience) - self.items.extend(items) + def append(self, experience) -> None: + if isinstance(experience, Experience): + if self.cpu_offload: + experience.to_device(torch.device('cpu')) + items = split_experience_batch(experience) + self.items.extend(items) + + elif isinstance(experience, list): + for exp in experience: + exp.to_device(torch.device('cpu')) + item = exp_to_buffer_item(exp) + self.items.append(item) + if self.limit > 0: samples_to_remove = len(self.items) - self.limit if samples_to_remove > 0: diff --git a/applications/Chat/coati/replay_buffer/utils.py b/applications/Chat/coati/replay_buffer/utils.py index 55ddb2ae8191..f8c43f987875 100644 --- a/applications/Chat/coati/replay_buffer/utils.py +++ b/applications/Chat/coati/replay_buffer/utils.py @@ -47,6 +47,12 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]: items = [BufferItem(**kwargs) for kwargs in batch_kwargs] return items +def exp_to_buffer_item(exp: Experience) -> BufferItem: + kwargs = {} + keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask') + for key in keys: + kwargs[key] = getattr(exp, key) + return BufferItem(**kwargs) def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: assert side in ('left', 'right') diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py index 525b57bf21d3..3795d30c2355 100644 --- a/applications/Chat/coati/trainer/__init__.py +++ b/applications/Chat/coati/trainer/__init__.py @@ -2,5 +2,6 @@ from .ppo import PPOTrainer from .rm import RewardModelTrainer from .sft import SFTTrainer +from .multi_ppo import MPPOTrainer -__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer'] +__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer', 'MPPOTrainer'] diff --git a/applications/Chat/coati/trainer/multi_ppo.py b/applications/Chat/coati/trainer/multi_ppo.py new file mode 100644 index 000000000000..7b0c806fcac7 --- 
/dev/null +++ b/applications/Chat/coati/trainer/multi_ppo.py @@ -0,0 +1,147 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from coati.experience_maker import Experience, NaiveExperienceMaker, MultiStepExperienceMaker +from coati.models.base import Actor, Critic +from coati.models.generation_utils import update_model_kwargs_fn +from coati.models.loss import PolicyLoss, ValueLoss, MPolicyLoss +from coati.replay_buffer import NaiveReplayBuffer +from torch.optim import Optimizer +from transformers.tokenization_utils_base import PreTrainedTokenizerBase + +from .base import Trainer +from .callbacks import Callback +from .strategies import Strategy + + +class MPPOTrainer(Trainer): + """ + Trainer for PPO algorithm. + + Args: + strategy (Strategy): the strategy to use for training + actor (Actor): the actor model in ppo algorithm + critic (Critic): the critic model in ppo algorithm + reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences + initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor + actor_optim (Optimizer): the optimizer to use for actor model + critic_optim (Optimizer): the optimizer to use for critic model + kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss + train_batch_size (int, defaults to 8): the batch size to use for training + buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer + buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu + eps_clip (float, defaults to 0.2): the clip coefficient of policy loss + vf_coef (float, defaults to 1.0): the coefficient of value loss + value_clip (float, defaults to 0.4): the clip coefficient of value loss + experience_batch_size (int, defaults to 8): the batch size to use for experience generation + max_epochs (int, defaults to 1): the number of epochs of training process + tokenier (Callable, optional): the tokenizer to use for tokenizing the input + sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer + dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader + callbacks (List[Callback], defaults to []): the callbacks to call during training process + generate_kwargs (dict, optional): the kwargs to use while model generating + """ + + def __init__(self, + strategy: Strategy, + actor: Actor, + critic: Critic, + reward_model: nn.Module, + initial_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + kl_coef: float = 0.1, + ptx_coef: float = 0.4, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.4, + experience_batch_size: int = 8, + max_epochs: int = 1, + tokenizer: Optional[Callable[[Any], dict]] = None, + sample_replay_buffer: bool = False, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs) -> None: + experience_maker = MultiStepExperienceMaker(actor, critic, reward_model, initial_model, kl_coef, **generate_kwargs) + # experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef, **generate_kwargs) + replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) + generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) + super().__init__(strategy, experience_maker, replay_buffer, 
experience_batch_size, max_epochs, tokenizer, + sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs) + self.actor = actor + self.critic = critic + + self.actor_loss_fn = MPolicyLoss(eps_clip) + self.critic_loss_fn = ValueLoss(value_clip) + self.vf_coef = vf_coef + self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100) + self.ptx_coef = ptx_coef + self.actor_optim = actor_optim + self.critic_optim = critic_optim + + + def training_step(self, experience: Experience) -> Dict[str, float]: + self.actor.train() + self.critic.train() + # policy loss + num_actions = experience.action_mask.size(1) + action_logits = self.actor.model(experience.sequences)['logits'][:, -1] + action_log_probs = F.log_softmax(action_logits, dim=-1) + actor_loss = self.actor_loss_fn(action_log_probs, + experience.action_log_probs, + experience.advantages) + # ptx loss + if self.ptx_coef != 0: + batch = next(self.pretrain_dataloader) + ptx = batch['input_ids'].to(torch.cuda.current_device()) + label = batch['labels'].to(torch.cuda.current_device())[:, 1:] + attention_mask = batch['attention_mask'].to(torch.cuda.current_device()) + ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :] + ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1)) + actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) + + self.strategy.backward(actor_loss, self.actor.model, self.actor_optim) + + # for name, param in self.actor.named_parameters(): + # print(name, param.grad) + + self.strategy.optimizer_step(self.actor_optim) + self.actor_optim.zero_grad() + + # value loss + values = self.critic(experience.sequences, + attention_mask=experience.attention_mask) + critic_loss = self.critic_loss_fn(values, + experience.values, + experience.reward, + action_mask=experience.action_mask) + critic_loss = critic_loss * self.vf_coef + self.strategy.backward(critic_loss, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim) + self.critic_optim.zero_grad() + + return {'returns': experience.reward.mean().item()} + + def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) + + def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) + + +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: + origin_model = strategy._unwrap_actor(actor) + new_kwargs = {**generate_kwargs} + # use huggingface models method directly + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): + new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + + if 'update_model_kwargs_fn' not in generate_kwargs: + new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn + + return new_kwargs diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index d58e437e6e61..90095df61cbf 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -109,7 +109,6 @@ def training_step(self, experience: Experience) -> Dict[str, float]: # value loss values = self.critic(experience.sequences, - action_mask=experience.action_mask, 
                             attention_mask=experience.attention_mask)
         critic_loss = self.critic_loss_fn(values,
                                           experience.values,
diff --git a/applications/Chat/examples/norm.sh b/applications/Chat/examples/norm.sh
new file mode 100644
index 000000000000..3d85ab57f8f4
--- /dev/null
+++ b/applications/Chat/examples/norm.sh
@@ -0,0 +1,17 @@
+# set_n_least_used_CUDA_VISIBLE_DEVICES 4
+
+# CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --master_port 29500 --nproc_per_node=4 \
+#     train_reward_model.py --pretrain '/data/scratch/alpaca-7B' \
+#     --model 'bloom' \
+#     --strategy colossalai_zero2 \
+#     --loss_fn 'log_exp' \
+#     --save_path '/home/lczht/data2/Coati/examples/rm_bloomz_1b7.pt' \
+#     --dataset 'Anthropic/hh-rlhf' \
+#     --subset 'harmless-base' \
+#     --test True
+
+CUDA_VISIBLE_DEVICES=5 python normalize_rm.py --pretrain /home/lczht/data2/bloom-560m \
+    --model 'bloom' \
+    --model_path '/home/lczht/data2/Coati/examples/rm_bloom560m.pt' \
+    --dataset 'Anthropic/hh-rlhf' \
+    --test True
\ No newline at end of file
diff --git a/applications/Chat/examples/normalize_rm.py b/applications/Chat/examples/normalize_rm.py
new file mode 100644
index 000000000000..b29cf7951584
--- /dev/null
+++ b/applications/Chat/examples/normalize_rm.py
@@ -0,0 +1,128 @@
+import argparse
+
+import torch
+from coati.dataset import HhRlhfDataset, RmStaticDataset
+from coati.models.bloom import BLOOMRM
+from coati.models.deberta import DebertaRM
+from coati.models.gpt import GPTRM
+from coati.models.llama import LlamaRM
+from coati.models.opt import OPTRM
+from coati.models.roberta import RoBERTaRM
+from coati.utils import prepare_llama_tokenizer_and_embedding
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+
+def normalize(args):
+    # configure model
+    if args.model == 'bloom':
+        model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'opt':
+        model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'gpt2':
+        model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'deberta':
+        model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'llama':
+        model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    elif args.model == 'roberta':
+        model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+
+    if args.model_path is not None:
+        state_dict = torch.load(args.model_path)
+        model.load_state_dict(state_dict)
+
+    model = model.to(torch.float16)
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    elif args.model == 'bloom':
+        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'deberta':
+        tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
+    elif args.model == 'llama':
+        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+    max_len = args.max_len
+
+    if args.model == 'llama':
+        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
+    else:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # prepare the data and dataset
+    if args.subset is not None:
+        data = load_dataset(args.dataset, data_dir=args.subset)
+    else:
+        data = load_dataset(args.dataset)
+
+    if args.test:
+        train_data = data['train'].select(range(10000))
+    else:
+        train_data = data['train']
+
+    if args.dataset == 'Dahoas/rm-static':
+        train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
+    elif args.dataset == 'Anthropic/hh-rlhf':
+        train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
+    else:
+        raise ValueError(f'Unsupported dataset "{args.dataset}"')
+
+    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
+
+    # collect chosen and rejected rewards over the dataset and report their
+    # mean and standard deviation
+    model.eval()
+    with torch.no_grad():
+        output = []
+        for chosen_ids, c_mask, reject_ids, r_mask in tqdm(train_dataloader):
+            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+            chosen_reward = model(chosen_ids, attention_mask=c_mask)
+            reject_reward = model(reject_ids, attention_mask=r_mask)
+            output.append(chosen_reward)
+            output.append(reject_reward)
+        output = torch.cat(output)
+        mean = output.mean()
+        std = output.std()
+        print(f'reward mean: {mean.item()}, std: {std.item()}')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
+    parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--model_path', type=str, default=None)
+    parser.add_argument('--dataset',
+                        type=str,
+                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
+                        default='Dahoas/rm-static')
+    parser.add_argument('--subset', type=str, default=None)
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--max_len', type=int, default=256)
+    parser.add_argument('--test', type=bool, default=False)
+    parser.add_argument('--lora_rank', type=int, default=0)
+    args = parser.parse_args()
+    normalize(args)
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index 5ded6d8432ed..688a5160c4bd 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -90,15 +90,15 @@ def main(args):
         raise ValueError(f'Unsupported actor model "{args.model}"')
 
     if rm_model_name == 'gpt2':
-        critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == 'bloom':
-        critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == 'opt':
-        critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == 'llama':
-        critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     elif rm_model_name == 'roberta':
-        critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
+        critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
     else:
         raise ValueError(f'Unsupported reward model "{rm_model_name}"')
 
@@ -143,14 +143,18 @@ def main(args):
     prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
     if dist.is_initialized() and dist.get_world_size() > 1:
         prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
+    else:
+        prompt_sampler = None
     prompt_dataloader = DataLoader(prompt_dataset,
                                    shuffle=(prompt_sampler is None),
                                    sampler=prompt_sampler,
-                                   batch_size=args.train_batch_size)
+                                   batch_size=args.experience_batch_size)
 
     pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
+    else:
+        pretrain_sampler = None
     pretrain_dataloader = DataLoader(pretrain_dataset,
                                      shuffle=(pretrain_sampler is None),
                                      sampler=pretrain_sampler,
@@ -180,7 +184,7 @@ def tokenize_fn(texts):
         train_batch_size=args.train_batch_size,
         experience_batch_size=args.experience_batch_size,
         tokenizer=tokenize_fn,
-        max_length=128,
+        max_length=512,
         do_sample=True,
         temperature=1.0,
         top_k=50,
diff --git a/applications/Chat/examples/train_prompts_m.py b/applications/Chat/examples/train_prompts_m.py
new file mode 100644
index 000000000000..35e08a6f7406
--- /dev/null
+++ b/applications/Chat/examples/train_prompts_m.py
@@ -0,0 +1,243 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
+from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
+from coati.models.gpt import GPTRM, GPTActor, GPTCritic
+from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+from coati.models.opt import OPTRM, OPTActor, OPTCritic
+from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic
+from coati.trainer import MPPOTrainer
+from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from coati.utils import prepare_llama_tokenizer_and_embedding
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+def main(args):
+    # configure strategy
+    if args.strategy == 'naive':
+        strategy = NaiveStrategy()
+    elif args.strategy == 'ddp':
+        strategy = DDPStrategy()
+    elif args.strategy == 'colossalai_gemini':
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+    elif args.strategy == 'colossalai_zero2':
+        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    else:
+        raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
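+    # the reward-model checkpoint is loaded once and reused below to initialize
+    # both the reward model and the critic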
+    if args.rm_path is not None:
+        state_dict = torch.load(args.rm_path, map_location='cpu')
+
+    # configure model
+    if args.model == 'gpt2':
+        initial_model = GPTActor(pretrained=args.pretrain)
+    elif args.model == 'bloom':
+        initial_model = BLOOMActor(pretrained=args.pretrain)
+    elif args.model == 'opt':
+        initial_model = OPTActor(pretrained=args.pretrain)
+    elif args.model == 'llama':
+        initial_model = LlamaActor(pretrained=args.pretrain)
+    elif args.model == 'roberta':
+        initial_model = RoBERTaActor(pretrained=args.pretrain)
+    else:
+        raise ValueError(f'Unsupported actor model "{args.model}"')
+
+    if args.rm_model is None:
+        rm_model_name = args.model
+    else:
+        rm_model_name = args.rm_model
+
+    if rm_model_name == 'gpt2':
+        reward_model = GPTRM(pretrained=args.rm_pretrain)
+    elif rm_model_name == 'bloom':
+        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
+    elif rm_model_name == 'opt':
+        reward_model = OPTRM(pretrained=args.rm_pretrain)
+    elif rm_model_name == 'llama':
+        reward_model = LlamaRM(pretrained=args.rm_pretrain)
+    elif rm_model_name == 'roberta':
+        reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
+    else:
+        raise ValueError(f'Unsupported reward model "{rm_model_name}"')
+
+    if args.rm_path is not None:
+        reward_model.load_state_dict(state_dict)
+
+    # if args.strategy != 'colossalai_gemini':
+    #     initial_model.to(torch.float16).to(torch.cuda.current_device())
+    #     reward_model.to(torch.float16).to(torch.cuda.current_device())
+
+    initial_model.to(torch.cuda.current_device())
+    reward_model.to(torch.cuda.current_device())
+
+    with strategy.model_init_context():
+        if args.model == 'gpt2':
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
+        elif args.model == 'bloom':
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
+        elif args.model == 'opt':
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
+        elif args.model == 'llama':
+            actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
+        elif args.model == 'roberta':
+            actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
+        else:
+            raise ValueError(f'Unsupported actor model "{args.model}"')
+
+        if rm_model_name == 'gpt2':
+            critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+        elif rm_model_name == 'bloom':
+            critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+        elif rm_model_name == 'opt':
+            critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+        elif rm_model_name == 'llama':
+            critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+        elif rm_model_name == 'roberta':
+            critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
+        else:
+            raise ValueError(f'Unsupported reward model "{rm_model_name}"')
+
+        if args.rm_path is not None:
+            critic.load_state_dict(state_dict)
+            del state_dict
+
+    # if args.strategy != 'colossalai_gemini':
+    #     critic.to(torch.float16).to(torch.cuda.current_device())
+    #     actor.to(torch.float16).to(torch.cuda.current_device())
+
+    critic.to(torch.cuda.current_device())
+    actor.to(torch.cuda.current_device())
+
+    # configure optimizer
+    if args.strategy.startswith('colossalai'):
+        actor_optim = HybridAdam(actor.parameters(), lr=1e-5)
+        critic_optim = HybridAdam(critic.parameters(), lr=1e-5)
+    else:
+        actor_optim = Adam(actor.parameters(), lr=7e-6)
+        critic_optim = Adam(critic.parameters(), lr=7e-6)
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    elif args.model == 'bloom':
+        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'llama':
+        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
+        tokenizer.eos_token = '</s>'
+    elif args.model == 'roberta':
+        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+
+    if args.model == 'llama':
+        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
+    else:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+
+    prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
+    else:
+        prompt_sampler = None
+    prompt_dataloader = DataLoader(prompt_dataset,
+                                   shuffle=(prompt_sampler is None),
+                                   sampler=prompt_sampler,
+                                   batch_size=args.experience_batch_size)
+
+    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset)
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
+    else:
+        pretrain_sampler = None
+    pretrain_dataloader = DataLoader(pretrain_dataset,
+                                     shuffle=(pretrain_sampler is None),
+                                     sampler=pretrain_sampler,
+                                     batch_size=args.ptx_batch_size,
+                                     collate_fn=data_collator)
+
+    def tokenize_fn(texts):
+        # MUST pad to max length to ensure inputs of all ranks have the same length
+        # different lengths may lead to a hang when using gemini, as the ranks
+        # would take different numbers of generation steps
+        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
+
+    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
+
+    # configure trainer
+    trainer = MPPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        kl_coef=args.kl_coef,
+        ptx_coef=args.ptx_coef,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        experience_batch_size=args.experience_batch_size,
+        tokenizer=tokenize_fn,
+        max_length=256,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    trainer.fit(prompt_dataloader=prompt_dataloader,
+                pretrain_dataloader=pretrain_dataloader,
+                num_episodes=args.num_episodes,
+                max_timesteps=args.max_timesteps,
+                update_timesteps=args.update_timesteps)
+
+    # save model checkpoint after fitting
+    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
+    # save optimizer checkpoint on all ranks
+    if args.need_optim_ckpt:
+        strategy.save_optimizer(actor_optim,
+                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                                only_rank0=False)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
+    parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
+    parser.add_argument('--strategy',
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        default='naive',
+                        help='strategy to use')
+    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
+    parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
+    parser.add_argument('--rm_path', type=str, default=None)
+    parser.add_argument('--rm_pretrain', type=str, default=None)
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
+    parser.add_argument('--num_episodes', type=int, default=50)
+    parser.add_argument('--max_timesteps', type=int, default=1)
+    parser.add_argument('--update_timesteps', type=int, default=1)
+    parser.add_argument('--max_epochs', type=int, default=1)
+    parser.add_argument('--train_batch_size', type=int, default=1)
+    parser.add_argument('--ptx_batch_size', type=int, default=1)
+    parser.add_argument('--experience_batch_size', type=int, default=4)
+    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument('--kl_coef', type=float, default=0.1)
+    parser.add_argument('--ptx_coef', type=float, default=0.5)
+    args = parser.parse_args()
+    main(args)