1 change: 1 addition & 0 deletions applications/Chat/coati/dataset/reward_dataset.py
@@ -80,6 +80,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
tokenizer.padding_side='left'
for data in tqdm(dataset, disable=not is_rank_0()):
chosen = data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
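A note on the added line: tokenizer.padding_side = 'left' makes padded batches end with real tokens rather than padding, presumably so the last position always holds the appended end token, which the last-position scoring in the Critic/RewardModel changes below relies on. A minimal sketch of the effect, assuming a Hugging Face tokenizer (the gpt2 checkpoint here is only an illustration, not part of this PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default
tokenizer.padding_side = "left"

batch = tokenizer(["short prompt", "a somewhat longer prompt"],
                  padding=True, return_tensors="pt")
# With left padding, each row of batch["input_ids"] starts with pad tokens and
# ends with real tokens, so index [:, -1] always points at a real token.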
3 changes: 2 additions & 1 deletion applications/Chat/coati/experience_maker/__init__.py
@@ -1,4 +1,5 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker
from .multi_step import MultiStepExperienceMaker

__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker', 'MultiStepExperienceMaker']
18 changes: 17 additions & 1 deletion applications/Chat/coati/experience_maker/base.py
@@ -64,13 +64,29 @@ def __init__(self,
critic: nn.Module,
reward_model: nn.Module,
initial_model: Actor,
kl_coef: float = 0.1) -> None:
kl_coef: float = 0.1,
max_length: int = 128,
Contributor comment:
These are all generation kwargs. Is it essential to declare these kwargs here explicitly?

eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
**generate_kwargs) -> None:
super().__init__()
self.actor = actor
self.critic = critic
self.reward_model = reward_model
self.initial_model = initial_model
self.kl_coef = kl_coef
self.max_length = max_length
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.reward_count = 0
self.reward_mean = 0.0
self.reward_M2 = 0.0

@abstractmethod
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
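Regarding the review comment above, one hypothetical alternative is to avoid declaring each generation argument explicitly and instead keep them in a single dict that is forwarded to generation. A sketch of that idea, not code from this PR (names and signatures here are assumptions):

import torch.nn as nn


class ExperienceMakerSketch:
    """Illustrative only; not the class defined in base.py."""

    def __init__(self, actor: nn.Module, critic: nn.Module,
                 kl_coef: float = 0.1, **generate_kwargs) -> None:
        self.actor = actor
        self.critic = critic
        self.kl_coef = kl_coef
        # max_length, eos_token_id, pad_token_id, top_k, top_p, temperature, ...
        # travel together and are passed straight through to generation.
        self.generate_kwargs = generate_kwargs

    def make_experience(self, input_ids, **extra_kwargs):
        kwargs = {**self.generate_kwargs, **extra_kwargs}
        # assumes the actor exposes a generate() method accepting these kwargs
        return self.actor.generate(input_ids, **kwargs)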
118 changes: 118 additions & 0 deletions applications/Chat/coati/experience_maker/multi_step.py
@@ -0,0 +1,118 @@
import torch
from coati.models.utils import compute_reward, normalize, compute_approx_kl
import torch.nn.functional as F
from .base import Experience, ExperienceMaker
from typing import Any, Callable, Optional
import torch.distributed as dist
from coati.models.base import ActorCritic
import numpy as np

class MultiStepExperienceMaker(ExperienceMaker):
"""
Multi Step experience maker.
"""
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval()
self.critic.eval()
self.initial_model.eval()
self.reward_model.eval()
self.ac = ActorCritic(self.actor, self.critic)
self.buffer = []

self.gamma = 1
self.lamda = 0.25
self.kl_coef = 1/128
in_len = input_ids.size(1)

# generate the sequence and get value at the same time
sequences, values , attention_mask, action_mask = self.ac.generate(input_ids,
**generate_kwargs)

# compute action log probs
action_logits = self.actor.model(sequences)['logits']
base_action_logits = self.initial_model.model(sequences)['logits']

action_log_probs = F.log_softmax(action_logits, dim=-1)
base_log_probs = F.log_softmax(base_action_logits, dim=-1)

# compute kl_div
kl_list = compute_approx_kl(action_log_probs, base_log_probs)
# clip kl
kl_list = torch.clamp(kl_list, max=10, min=1e-4)

# add eos token to the end of sequence and compute reward
eos_tensor = torch.tensor([self.eos_token_id], device=input_ids.device).repeat(input_ids.size(0), 1)
sequence_with_eos = torch.cat([sequences, eos_tensor], dim=-1)
rewards = self.reward_model(sequence_with_eos)

# reward clip
rewards = torch.clamp(rewards, max=10, min=-10)

# running mean reward
for i in range(rewards.size(0)):
value = rewards[i]
self.reward_count += 1
delta = value - self.reward_mean
self.reward_mean += delta / self.reward_count
delta2 = value - self.reward_mean
self.reward_M2 += delta * delta2

std = self.reward_M2 / (self.reward_count - 1)
rewards = (rewards - self.reward_mean)/std

print('rewards: ', rewards)
rewards = rewards * (1 - self.kl_coef)

# get action mask
action_mask = action_mask[:, in_len:]
kl_list = kl_list[:, in_len:]

# compute the advantages
advantages, returns = self.compute_gae(kl_list, rewards, values, action_mask)

for i in range(in_len, sequences.size(1) - 1):
for j in range(sequences.size(0)):
if sequences[j, i] != self.eos_token_id:
_state = sequences[j, :i]
_action_log_prob = action_log_probs[j, i]
_value = values[j, i-in_len]
_return = returns[j, i-in_len]
_adv = advantages[j, i-in_len]
_attention_mask = attention_mask[j, :i]
_action_mask = action_mask[j, :i-in_len]
exp = Experience(_state, _action_log_prob, _value, _return, _adv, _attention_mask, _action_mask)
self.buffer.append(exp)
buffer = self.buffer
return buffer

@torch.no_grad()
def compute_gae(self, kl_list: torch.Tensor,
reward: torch.Tensor,
values: torch.Tensor,
action_mask: torch.Tensor) -> torch.Tensor:
kl = -kl_list * action_mask * self.kl_coef
values = values * action_mask
T = torch.sum(values.ne(0), dim=1)
self.total_len = sum(T)
max_len = max(T)
gae_values = torch.zeros_like(values)
delta_list = torch.zeros_like(values)

# add reward to kl[:, -1]
for i in range(len(T)):
kl[i, T[i]-1] += reward[i]

# compute delta
for t in range(max_len - 1):
next_v = values[:,t + 1] if t + 1 < max_len else 0
delta_list[:, t] = kl[:, t] + self.gamma * next_v - values[:, t]

# compute gae
gae_values[:, max_len - 1] = delta_list[:, max_len - 1]
for t in range(max_len - 2, -1, -1):
gae_values[:, t] = delta_list[:, t] + self.gamma * self.lamda * gae_values[:, t + 1]

# compute return
returns = gae_values + values
return gae_values, returns
2 changes: 1 addition & 1 deletion applications/Chat/coati/experience_maker/naive.py
@@ -23,7 +23,7 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:

action_log_probs = self.actor(sequences, num_actions, attention_mask)
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
value = self.critic(sequences, action_mask, attention_mask)
value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

4 changes: 2 additions & 2 deletions applications/Chat/coati/models/__init__.py
@@ -1,4 +1,4 @@
from .base import Actor, Critic, RewardModel
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss, MPolicyLoss

__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss', 'MPolicyLoss']
3 changes: 2 additions & 1 deletion applications/Chat/coati/models/base/__init__.py
@@ -2,5 +2,6 @@
from .critic import Critic
from .lm import LM
from .reward_model import RewardModel
from .ac import ActorCritic

__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM', 'ActorCritic']
59 changes: 59 additions & 0 deletions applications/Chat/coati/models/base/ac.py
@@ -0,0 +1,59 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate_with_value
from ..lora import LoRAModule
from ..utils import log_probs_from_logits
from coati.models.base import Actor, Critic



class ActorCritic(nn.Module):
"""
ActorCritic model class.

Args:
model (nn.Module): Actor Model.
lora_rank (int): LoRA rank.
lora_train_bias (str): LoRA bias training mode.
"""

def __init__(self, actor:Actor, critic:Critic) -> None:
super().__init__()
self.actor = actor
self.critic = critic

@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences, values = generate_with_value(self.actor.model, self.critic, input_ids, **kwargs)

pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_mask = input_ids.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
# pad input_mask to be as long as attention_mask
input_mask = F.pad(input_mask, (0, attention_mask.shape[-1]-input_mask.shape[-1], 0, 0), value=0)
action_mask = attention_mask - input_mask

return sequences, values, attention_mask, action_mask

# def forward(self,
# sequences: torch.LongTensor,
# num_actions: int,
# attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# """Returns action log probs
# """
# output = self.actor.model(sequences, attention_mask=attention_mask)
# logits = output['logits']
# log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
# return log_probs[:, -num_actions:]

# def get_base_model(self):
# return self.model
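A small standalone illustration of how generate above derives action_mask from sequences and input_ids (toy tensors and a pad_token_id of 0 are assumed; this is not code from the PR): response-token positions end up as 1, prompt and padding positions as 0.

import torch
import torch.nn.functional as F

pad_token_id = 0
input_ids = torch.tensor([[0, 0, 5, 6],              # left-padded prompt
                          [7, 8, 9, 10]])
sequences = torch.tensor([[0, 0, 5, 6, 11, 12, 0],   # prompt + response (+ trailing pad)
                          [7, 8, 9, 10, 13, 14, 15]])

attention_mask = sequences.not_equal(pad_token_id).long()
input_mask = input_ids.not_equal(pad_token_id).long()
# right-pad the prompt mask so it lines up with the full sequences
input_mask = F.pad(input_mask, (0, attention_mask.shape[-1] - input_mask.shape[-1]), value=0)
action_mask = attention_mask - input_mask
# action_mask -> [[0, 0, 0, 0, 1, 1, 0],
#                 [0, 0, 0, 0, 1, 1, 1]]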
21 changes: 5 additions & 16 deletions applications/Chat/coati/models/base/critic.py
@@ -23,32 +23,21 @@ def __init__(
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
lora_train_bias: str = 'none',
use_action_mask: bool = False,
lora_train_bias: str = 'none'
) -> None:

super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora()

def forward(self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# we have added an <eos> token at the end of each sequence
# so we can use the last hidden state as the input to the scoring linear layer to get the value score
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state']

values = self.value_head(last_hidden_states).squeeze(-1)

if action_mask is not None and self.use_action_mask:
num_actions = action_mask.size(1)
prompt_mask = attention_mask[:, :-num_actions]
values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value

values = values[:, :-1]
value = values.mean(dim=1)
value_prob = last_hidden_states[:, -1]
value = self.value_head(value_prob).squeeze(1) # ensure shape is (B)
return value
8 changes: 5 additions & 3 deletions applications/Chat/coati/models/base/reward_model.py
@@ -34,8 +34,10 @@ def __init__(self,
self.value_head = nn.Linear(model.config.n_embd, 1)

def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
# we have added an <eos> token at the end of each sequence while preparing the dataset
# so we can use the last hidden state as the input to the scoring linear layer to get the reward score
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state']
values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value
reward_prob = last_hidden_states[:, -1]
reward = self.value_head(reward_prob).squeeze(1) # ensure shape is (B)
return reward
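Both the Critic change above and this RewardModel change now score a sequence by feeding the hidden state at the last position into the value head. A toy sketch of the pattern, with a random tensor standing in for the transformer's last_hidden_state (illustrative only, not code from this PR):

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 2, 8, 16
last_hidden_states = torch.randn(batch_size, seq_len, hidden_size)
value_head = nn.Linear(hidden_size, 1)

last_token_state = last_hidden_states[:, -1]       # (B, hidden_size)
score = value_head(last_token_state).squeeze(1)    # (B,), one scalar per sequence
# valid because the dataset appends an end token and batches are left-padded,
# so position -1 always holds a real token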