Merged

Ra #66
16 changes: 8 additions & 8 deletions applications/Chat/coati/dataset/prompt_dataset.py
@@ -35,14 +35,14 @@ def __init__(self,
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]

for data_dict in list_data_dict:
token = tokenizer(data_dict["instruction"],
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in token.items():
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

def __len__(self):
return len(self.keyed_prompt["input_ids"])
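Note: the per-example tokenization loop is replaced by one batched tokenizer call, which is both faster and simpler. A minimal sketch of the pattern, assuming a Hugging Face tokenizer (the model name and sample strings are illustrative, not from this PR):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

instructions = ["Explain RLHF in one sentence.", "Name three PPO hyperparameters."]
tokens = tokenizer(instructions,
                   return_tensors='pt',
                   max_length=64,
                   padding='max_length',
                   truncation=True)

# tokens["input_ids"] is a (batch, max_length) tensor; unbind() splits it into
# one 1-D tensor per example, matching what the old per-example loop collected.
per_example = tokens["input_ids"].unbind()
assert len(per_example) == len(instructions)
```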
30 changes: 15 additions & 15 deletions applications/Chat/coati/dataset/sft_dataset.py
@@ -74,21 +74,18 @@ def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int
) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
tokenized_list = tokenizer(
strings, return_tensors="pt", padding="longest",
max_length=max_length, truncation=True
)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
return dict(
input_ids=input_ids,
labels=labels,
@@ -105,7 +102,10 @@ def preprocess(
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length)
for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
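Note: `_tokenize_fn` now tokenizes the whole list in a single call and computes per-example lengths with one vectorized comparison instead of a Python loop. A self-contained toy check of the length computation (pad id and token values made up):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])

# Count non-pad tokens per row: one tensor op replaces the old
# per-example `.ne(pad_token_id).sum().item()` loop.
lens = input_ids.ne(pad_token_id).sum(dim=-1)
print(lens)  # tensor([3, 2])
```

With `padding="longest"` the batch is padded to its longest example, so `input_ids` becomes a single 2-D tensor rather than a list of 1-D tensors; the `preprocess` loop below iterates over its rows unchanged.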
12 changes: 8 additions & 4 deletions applications/Chat/coati/experience_maker/naive.py
@@ -1,5 +1,6 @@
import torch
from coati.models.utils import compute_reward, normalize
from coati.models.generation import generate_with_actor
from coati.models.utils import calc_action_log_probs, compute_reward, normalize

from .base import Experience, ExperienceMaker

@@ -16,13 +17,16 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.initial_model.eval()
self.reward_model.eval()

sequences, attention_mask, action_mask = self.actor.generate(input_ids,
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
input_ids,
return_action_mask=True,
**generate_kwargs)
num_actions = action_mask.size(1)

action_log_probs = self.actor(sequences, num_actions, attention_mask)
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
actor_output = self.actor(sequences, attention_mask)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask)
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
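Note: generation and log-prob extraction move out of `Actor` into the free functions `generate_with_actor` and `calc_action_log_probs`, so one actor forward pass can serve both the PPO and pretraining losses. A sketch of the resulting call sequence (names are from this PR; `actor`, `input_ids`, and `generate_kwargs` are assumed to exist as in `make_experience`):

```python
sequences, attention_mask, action_mask = generate_with_actor(
    actor, input_ids, return_action_mask=True, **generate_kwargs)
num_actions = action_mask.size(1)

# One forward pass per model; the helper slices out the log-probs
# of the generated (action) tokens afterwards.
actor_output = actor(sequences, attention_mask)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
```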
14 changes: 8 additions & 6 deletions applications/Chat/coati/models/base/__init__.py
@@ -1,24 +1,26 @@
from typing import Union

import torch.nn as nn

from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel


def get_base_model(model: nn.Module) -> nn.Module:
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
For Actor, its base model is ``actor.model`` and it is usually a ``transformers.PreTrainedModel``.
For Critic and RewardModel, its base model is itself.
For Actor, Critic and RewardModel, return ``model.model``,
which is usually a ``transformers.PreTrainedModel``.

Args:
model (nn.Module): model to get base model from

Returns:
nn.Module: the base model
"""
if isinstance(model, Actor):
return model.get_base_model()
return model
assert isinstance(model, (Actor, Critic, RewardModel)), \
f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
return model.model


__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
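Note: `get_base_model` no longer special-cases `Actor`; all three wrappers are expected to hold the underlying transformers model at `.model`, and anything else now trips the assert. A hedged sketch of the contract (the `actor` object is assumed):

```python
from coati.models.base import get_base_model

hf_model = get_base_model(actor)  # -> actor.model, usually a PreTrainedModel
# Passing a plain nn.Module now raises an AssertionError whose message
# points you at Strategy.unwrap_model first.
```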
53 changes: 12 additions & 41 deletions applications/Chat/coati/models/base/actor.py
@@ -1,12 +1,9 @@
from typing import Optional, Tuple, Union
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..generation import generate
from ..lora import LoRAModule
from ..utils import log_probs_from_logits


class Actor(LoRAModule):
@@ -24,42 +21,16 @@ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str =
self.model = model
self.convert_to_lora()

@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]

def forward(self,
sequences: torch.LongTensor,
num_actions: int,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns action log probs
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
) -> torch.Tensor:
"""Returns model output.
"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]

def get_base_model(self):
return self.model
output = self.model(
input_ids,
attention_mask=attention_mask,
**model_kwargs
)
return output
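Note: `Actor.forward` is now a thin passthrough that returns the wrapped model's raw output dict; the extra `**model_kwargs` exists so HF-style generation utilities can forward cache arguments. Callers fetch `'logits'` themselves, as sketched here (tensors assumed):

```python
output = actor(sequences, attention_mask=attention_mask)
logits = output['logits']  # (batch, seq_len, vocab_size)
```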
41 changes: 37 additions & 4 deletions applications/Chat/coati/models/generation.py
@@ -1,8 +1,10 @@
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


try:
from transformers.generation_logits_process import (
@@ -55,9 +57,8 @@ def sample(model: nn.Module,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
'input_ids': input_ids
}
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
if prepare_inputs_fn is not None else {'input_ids': input_ids}
outputs = model(**model_inputs)

next_token_logits = outputs['logits'][:, -1, :]
@@ -144,3 +145,35 @@ def generate(model: nn.Module,
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")


@torch.no_grad()
def generate_with_actor(actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
sequences = generate(actor_model, input_ids, **kwargs)

# calculate auxiliary tensors
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
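Note: the action-mask construction is the subtle part of `generate_with_actor`, so here is a self-contained worked example (token ids made up). Generated tokens up to and including the first EOS count as actions; padding after it does not:

```python
import torch
import torch.nn.functional as F

eos_token_id = 2
input_len = 3  # prompt length
# prompt tokens | generated tokens: EOS is the 3rd generated token, then padding
sequences = torch.tensor([[7, 8, 9, 5, 6, 2, 0, 0]])

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
print(action_mask[:, -(sequences.size(1) - input_len):])
# tensor([[ True,  True,  True, False, False]])
```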
19 changes: 19 additions & 0 deletions applications/Chat/coati/models/utils.py
@@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return log_probs_labels.squeeze(-1)


def calc_action_log_probs(output: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int
) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.

Returns:
torch.Tensor: Action log probs.
"""
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
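Note: `calc_action_log_probs` pairs `logits[:, :-1]` with `sequences[:, 1:]` because position t's logits predict token t+1, then keeps only the last `num_actions` positions. A runnable re-derivation with plain torch ops (shapes are illustrative):

```python
import torch

batch, seq_len, vocab, num_actions = 2, 6, 11, 3
logits = torch.randn(batch, seq_len, vocab)
sequences = torch.randint(vocab, (batch, seq_len))

# log-prob of each realized next token, then slice off the generated tail
log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
token_log_probs = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = token_log_probs[:, -num_actions:]
print(action_log_probs.shape)  # torch.Size([2, 3])
```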
23 changes: 13 additions & 10 deletions applications/Chat/coati/trainer/ppo.py
@@ -3,8 +3,9 @@
import torch
import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.base import Actor, Critic, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
@@ -165,7 +166,8 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
self.critic.train()
# policy loss
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
@@ -175,8 +177,8 @@
if self.ptx_coef != 0:
batch = next(iter(self.pretrain_dataloader))
batch = to_device(batch, self.device)
ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

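Note: since `Actor.forward` now returns the raw output dict, the pretraining (ptx) branch calls the actor directly instead of going through `get_base_model()`. Despite the variable name, the fetched tensor holds logits, which is what the language-modeling loss consumes; a sketch of the mixing step (objects assumed as in `training_step`):

```python
ptx_logits = actor(batch['input_ids'],
                   attention_mask=batch['attention_mask'])['logits']
ptx_loss = ptx_loss_fn(ptx_logits, batch['labels'])  # LM cross-entropy over logits
actor_loss = ptx_loss * ptx_coef + actor_loss * (1 - ptx_coef)
```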
@@ -200,14 +202,15 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
return {'reward': experience.reward.mean().item()}


def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy.unwrap_model(actor)
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation

if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation

return new_kwargs
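Note: `_set_default_generate_kwargs` now unwraps in two explicit steps: `Strategy.unwrap_model` removes the strategy's wrapping, then `get_base_model` reaches the transformers model whose generation hooks are borrowed. Usage sketch (assuming `strategy` and `actor` as in the trainer):

```python
generate_kwargs = _set_default_generate_kwargs(strategy, {}, actor)
# generate_kwargs now carries prepare_inputs_fn / update_model_kwargs_fn
# taken from the underlying HF model, when it defines them.
```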
21 changes: 7 additions & 14 deletions applications/Chat/coati/trainer/strategies/base.py
@@ -4,7 +4,6 @@

import torch
import torch.nn as nn
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ def prepare(
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
"""

def prepare_model(model: nn.Module):
if isinstance(model, Actor):
return Actor(self.setup_model(model.get_base_model()))
return self.setup_model(model)

rets = []
for arg in models_or_model_optim_pairs:
if isinstance(arg, tuple):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = prepare_model(model)
optimizer = self.setup_optimizer(optimizer, get_base_model(model))
model = self.setup_model(model)
optimizer = self.setup_optimizer(optimizer, model)
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(prepare_model(arg))
rets.append(self.setup_model(arg))
else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')

@@ -93,16 +87,15 @@ def prepare_model(model: nn.Module):

@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
"""Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
For Actor, it will unwrap `actor.model`.
"""Get the unwrapped model from a wrapped model made by Strategy.prepare.

Args:
model (nn.Module): the model to unwrap

Returns:
nn.Module: the original model (usually a huggingface model)
nn.Module: the original model
"""
return get_base_model(model)
return model

@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
@@ -133,4 +126,4 @@ def save_pretrained(self,

@abstractmethod
def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
pass
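Note: the base `Strategy.unwrap_model` is now an identity function: each strategy subclass is responsible for undoing only its own wrapping, and digging out the transformers model is a separate, explicit `get_base_model` call. A sketch of the resulting division of labor (model objects assumed):

```python
model = strategy.unwrap_model(wrapped_actor)  # strategy-level unwrap (base: no-op)
hf_model = get_base_model(model)              # coati-level unwrap -> HF model
```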