Merged
16 changes: 8 additions & 8 deletions applications/Chat/coati/dataset/prompt_dataset.py
@@ -35,14 +35,14 @@ def __init__(self,
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]

for data_dict in list_data_dict:
token = tokenizer(data_dict["instruction"],
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in token.items():
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

def __len__(self):
return len(self.keyed_prompt["input_ids"])
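Batching the tokenizer call is the point of this change: the old loop invoked the tokenizer once per prompt, while a single call tokenizes the whole list at once and pads it in one pass. A minimal sketch of the pattern, assuming a generic Hugging Face tokenizer (gpt2 and the lengths here are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

instructions = ["Write a haiku about spring.", "Explain RLHF in one sentence."]
tokens = tokenizer(instructions,
                   return_tensors='pt',
                   max_length=96,
                   padding='max_length',
                   truncation=True)
# As in the new code above, unbind() splits the (batch, max_length) tensor into per-example rows.
rows = tokens["input_ids"].unbind()
assert len(rows) == len(instructions) and rows[0].shape == (96,)
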
30 changes: 15 additions & 15 deletions applications/Chat/coati/dataset/sft_dataset.py
@@ -74,21 +74,18 @@ def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int
) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
tokenized_list = tokenizer(
strings, return_tensors="pt", padding="longest",
max_length=max_length, truncation=True
)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
return dict(
input_ids=input_ids,
labels=labels,
@@ -105,7 +102,10 @@ def preprocess(
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length)
for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
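The loop body is elided above; in the usual Alpaca-style recipe it overwrites the source span of each label with an ignore index so the LM loss only covers target tokens. A toy illustration under that assumption (IGNORE_INDEX = -100, the value torch.nn.CrossEntropyLoss skips by default):

import torch

IGNORE_INDEX = -100  # assumed convention; ignored by torch.nn.CrossEntropyLoss

input_ids = torch.tensor([[10, 11, 12, 13, 14],
                          [20, 21, 22, 23, 24]])
labels = input_ids.clone()
for label, source_len in zip(labels, [3, 2]):
    label[:source_len] = IGNORE_INDEX  # mask the prompt; train only on the response
# labels is now:
# tensor([[-100, -100, -100,   13,   14],
#         [-100, -100,   22,   23,   24]])

One behavioral shift worth noting: the batched _tokenize_fn returns a 2-D tensor padded to the longest string in the batch, whereas the old per-string calls produced unpadded variable-length tensors, so pad tokens can now appear inside input_ids.
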
12 changes: 8 additions & 4 deletions applications/Chat/coati/experience_maker/naive.py
@@ -1,5 +1,6 @@
import torch
from coati.models.utils import compute_reward, normalize
from coati.models.generation import generate_with_actor
from coati.models.utils import calc_action_log_probs, compute_reward, normalize

from .base import Experience, ExperienceMaker

@@ -16,13 +17,16 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.initial_model.eval()
self.reward_model.eval()

sequences, attention_mask, action_mask = self.actor.generate(input_ids,
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
input_ids,
return_action_mask=True,
**generate_kwargs)
num_actions = action_mask.size(1)

action_log_probs = self.actor(sequences, num_actions, attention_mask)
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
actor_output = self.actor(sequences, attention_mask)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask)
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
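make_experience then folds a KL penalty against the frozen initial_model into the scalar reward from the reward model. A toy sketch of that shaping (kl_coef and the mean-over-actions approximation are illustrative assumptions; coati's exact compute_reward may differ):

import torch

action_log_probs = torch.randn(2, 4)       # actor log-probs of the 4 generated tokens
base_action_log_probs = torch.randn(2, 4)  # same tokens under the frozen initial model
r = torch.randn(2)                         # scalar reward per sequence
kl_coef = 0.1

approx_kl = (action_log_probs - base_action_log_probs).mean(dim=-1)
reward = r - kl_coef * approx_kl           # penalize drift from the initial policy
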
14 changes: 8 additions & 6 deletions applications/Chat/coati/models/base/__init__.py
@@ -1,24 +1,26 @@
from typing import Union

import torch.nn as nn

from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel


def get_base_model(model: nn.Module) -> nn.Module:
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
For Critic and RewardModel, it's base model is itself.
For Actor, Critic and RewardModel, this returns ``model.model``,
which is usually a ``transformers.PreTrainedModel``.

Args:
model (nn.Module): model to get base model from

Returns:
nn.Module: the base model
"""
if isinstance(model, Actor):
return model.get_base_model()
return model
assert isinstance(model, (Actor, Critic, RewardModel)), \
f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
return model.model


__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
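
After this change, get_base_model is a uniform one-step unwrap for all three wrapper classes. A hypothetical usage sketch (GPTActor and its pretrained argument are assumed from coati.models.gpt):

from coati.models.base import get_base_model
from coati.models.gpt import GPTActor  # assumed import path

actor = GPTActor(pretrained='gpt2')
hf_model = get_base_model(actor)  # the wrapped transformers model
assert hf_model is actor.model
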
53 changes: 12 additions & 41 deletions applications/Chat/coati/models/base/actor.py
@@ -1,12 +1,9 @@
from typing import Optional, Tuple, Union
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from ..lora import LoRAModule
from ..utils import log_probs_from_logits


class Actor(LoRAModule):
@@ -24,42 +24,16 @@ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str =
self.model = model
self.convert_to_lora()

@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

def forward(self,
sequences: torch.LongTensor,
num_actions: int,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns action log probs
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
) -> torch.Tensor:
"""Returns model output.
"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]

def get_base_model(self):
return self.model
output = self.model(
input_ids,
attention_mask=attention_mask,
**model_kwargs
)
return output
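
The **model_kwargs passthrough matters because generate drives the Actor through prepare_inputs_fn, which can return keys beyond input_ids and attention_mask (e.g. use_cache or position_ids). A minimal stand-in showing the pattern (TinyLM is illustrative, not part of coati):

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Stand-in for a transformers model: forward returns a dict with 'logits'."""

    def __init__(self, vocab_size: int = 50):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, 16)
        self.head = nn.Linear(16, vocab_size)

    def forward(self, input_ids, attention_mask=None, use_cache=False):
        return {'logits': self.head(self.emb(input_ids))}

# A prepare_inputs_fn-style dict with an extra key; **kwargs lets it flow through.
model_inputs = {'input_ids': torch.randint(0, 50, (1, 4)), 'use_cache': True}
logits = TinyLM()(**model_inputs)['logits']  # shape (1, 4, 50)
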
41 changes: 37 additions & 4 deletions applications/Chat/coati/models/generation.py
@@ -1,8 +1,10 @@
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


try:
from transformers.generation_logits_process import (
@@ -55,9 +57,8 @@ def sample(model: nn.Module,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
'input_ids': input_ids
}
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
if prepare_inputs_fn is not None else {'input_ids': input_ids}
outputs = model(**model_inputs)

next_token_logits = outputs['logits'][:, -1, :]
@@ -144,3 +145,35 @@ def generate(model: nn.Module,
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")


@torch.no_grad()
def generate_with_actor(actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
sequences = generate(actor_model, input_ids, **kwargs)

# calculate auxiliary tensors
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
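
The mask arithmetic is compact enough to warrant a worked toy (batch of one, a 2-token prompt, 4 generated tokens with eos_token_id=9 followed by padding):

import torch
import torch.nn.functional as F

sequences = torch.tensor([[5, 6, 7, 9, 0, 0]])  # prompt [5, 6]; generated [7, 9, 0, 0]
input_len, eos_token_id = 2, 9

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
# [[True, False, False, False]] -- True strictly before the first eos
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
final = action_mask[:, -(sequences.size(1) - input_len):]
# [[True, True, False, False]]: both real actions (7 and the eos 9) kept, padding masked
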
19 changes: 19 additions & 0 deletions applications/Chat/coati/models/utils.py
@@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return log_probs_labels.squeeze(-1)


def calc_action_log_probs(output: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int
) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.

Returns:
torch.Tensor: Action log probs.
"""
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
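Shape-wise, calc_action_log_probs pairs the logits at position t with the token at position t+1, then keeps the trailing num_actions entries. A self-contained check, with log_probs_from_logits expanded to its usual gather-over-log-softmax form (assumed here):

import torch

batch, seq_len, vocab, num_actions = 2, 5, 10, 3
logits = torch.randn(batch, seq_len, vocab)            # as in output['logits']
sequences = torch.randint(0, vocab, (batch, seq_len))

log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = token_log_probs[:, -num_actions:]
assert action_log_probs.shape == (batch, num_actions)  # one log-prob per action token
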
23 changes: 13 additions & 10 deletions applications/Chat/coati/trainer/ppo.py
@@ -3,8 +3,9 @@
import torch
import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.base import Actor, Critic, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
@@ -165,7 +166,8 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
self.critic.train()
# policy loss
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
@@ -175,8 +177,8 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
if self.ptx_coef != 0:
batch = next(iter(self.pretrain_dataloader))
batch = to_device(batch, self.device)
ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

@@ -200,14 +202,15 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
return {'reward': experience.reward.mean().item()}


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy.unwrap_model(actor)
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapper_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation

if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation

return new_kwargs
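
For reference, the actor_loss_fn called in training_step is a PPO clipped surrogate; a toy of that loss shape (the clip epsilon of 0.2 and the unmasked mean are assumptions, not necessarily coati's exact PolicyLoss):

import torch

eps = 0.2
log_probs = torch.randn(2, 4)        # new policy log-probs of the taken actions
old_log_probs = torch.randn(2, 4)    # recorded in the Experience at rollout time
advantages = torch.randn(2, 1)

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1.0 - eps, 1.0 + eps) * advantages
actor_loss = -torch.min(surr1, surr2).mean()  # maximize the clipped objective
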
21 changes: 7 additions & 14 deletions applications/Chat/coati/trainer/strategies/base.py
@@ -4,7 +4,6 @@

import torch
import torch.nn as nn
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ def prepare(
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
"""

def prepare_model(model: nn.Module):
if isinstance(model, Actor):
return Actor(self.setup_model(model.get_base_model()))
return self.setup_model(model)

rets = []
for arg in models_or_model_optim_pairs:
if isinstance(arg, tuple):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = prepare_model(model)
optimizer = self.setup_optimizer(optimizer, get_base_model(model))
model = self.setup_model(model)
optimizer = self.setup_optimizer(optimizer, model)
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(prepare_model(arg))
rets.append(self.setup_model(arg))
else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')

@@ -93,16 +87,15 @@ def prepare_model(model: nn.Module):

@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
"""Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
For Actor, it will unwrap `actor.model`.
"""Get the unwrapped model from a wrapped model made by Strategy.prepare.

Args:
model (nn.Module): the model to unwrap

Returns:
nn.Module: the original model (usually a huggingface model)
nn.Module: the original model
"""
return get_base_model(model)
return model

@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
@@ -133,4 +126,4 @@ def save_pretrained(self,

@abstractmethod
def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
pass