Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions applications/Chat/benchmarks/benchmark_gpt_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,10 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator])

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
trainer.fit(random_prompts,
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
trainer.fit(random_prompts, random_pretrain,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
Expand Down
6 changes: 4 additions & 2 deletions applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator])

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
trainer.fit(random_prompts,
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 400), device=torch.cuda.current_device())
random_attention_mask = torch.randint(1, (1000, 1, 400), device=torch.cuda.current_device()).to(torch.bool)
random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
trainer.fit(random_prompts, random_pretrain,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
Expand Down
94 changes: 2 additions & 92 deletions applications/Chat/coati/trainer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from coati.experience_maker import Experience, ExperienceMaker
from coati.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from coati.experience_maker import Experience

from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0


class Trainer(ABC):
Expand All @@ -19,113 +14,28 @@ class Trainer(ABC):

Args:
strategy (Strategy): the strategy to use for training
experience_maker (ExperienceMaker): the experience maker used to produce experiences that fill the replay buffer
replay_buffer (ReplayBuffer): the replay buffer to use for training
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""

def __init__(self,
strategy: Strategy,
experience_maker: ExperienceMaker,
replay_buffer: ReplayBuffer,
experience_batch_size: int = 8,
max_epochs: int = 1,
tokenizer: Optional[Callable[[Any], dict]] = None,
sample_replay_buffer: bool = False,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
super().__init__()
self.strategy = strategy
self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.experience_batch_size = experience_batch_size
self.max_epochs = max_epochs
self.tokenizer = tokenizer
self.generate_kwargs = generate_kwargs
self.sample_replay_buffer = sample_replay_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks

@abstractmethod
def training_step(self, experience: Experience) -> Dict[str, Any]:
pass

def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')

def _sample_prompts(self, prompts) -> list:
indices = list(range(len(prompts)))
sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
return [prompts[i] for i in sampled_indices]

def _learn(self):
# replay buffer may be empty at first, we should rebuild at each training
if not self.sample_replay_buffer:
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
device = torch.cuda.current_device()
if self.sample_replay_buffer:
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for _ in pbar:
experience = self.replay_buffer.sample()
metrics = self.training_step(experience)
pbar.set_postfix(metrics)
else:
for epoch in range(self.max_epochs):
self._on_learn_epoch_start(epoch)
if isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(epoch)
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(device)
metrics = self.training_step(experience)
self._on_learn_batch_end(metrics, experience)
pbar.set_postfix(metrics)
self._on_learn_epoch_end(epoch)

def fit(self,
prompt_dataloader,
pretrain_dataloader,
num_episodes: int = 50000,
max_timesteps: int = 500,
update_timesteps: int = 5000) -> None:
time = 0
self.pretrain_dataloader = pretrain_dataloader
self.prompt_dataloader = prompt_dataloader
self._on_fit_start()
for episode in range(num_episodes):
self._on_episode_start(episode)
for timestep in tqdm(range(max_timesteps),
desc=f'Episode [{episode+1}/{num_episodes}]',
disable=not is_rank_0()):
time += 1
prompts = next(iter(self.prompt_dataloader))
self._on_make_experience_start()
self.experience_maker.initial_model.to(torch.cuda.current_device())
self.experience_maker.reward_model.to(torch.cuda.current_device())
experience = self._make_experience(prompts)
self._on_make_experience_end(experience)
self.replay_buffer.append(experience)
if time % update_timesteps == 0:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self._learn()
self.replay_buffer.clear()
self._on_episode_end(episode)
self._on_fit_end()

# TODO(ver217): maybe simplify these code using context
def _on_fit_start(self) -> None:
for callback in self.callbacks:
Expand Down
91 changes: 88 additions & 3 deletions applications/Chat/coati/trainer/ppo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
Expand All @@ -7,12 +7,16 @@
from coati.models.generation_utils import update_model_kwargs_fn
from coati.models.loss import PolicyLoss, ValueLoss
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from tqdm import tqdm

from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0


class PPOTrainer(Trainer):
Expand All @@ -33,6 +37,7 @@ class PPOTrainer(Trainer):
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
Expand Down Expand Up @@ -69,8 +74,13 @@ def __init__(self,
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
super().__init__(strategy, max_epochs, tokenizer, dataloader_pin_memory, callbacks, **generate_kwargs)

self.experience_maker = experience_maker
self.replay_buffer = replay_buffer
self.experience_batch_size = experience_batch_size
self.sample_replay_buffer = sample_replay_buffer

self.actor = actor
self.critic = critic

Expand All @@ -82,6 +92,81 @@ def __init__(self,
self.actor_optim = actor_optim
self.critic_optim = critic_optim

def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> "Experience":
    """Build one experience from a batch of prompts.

    Args:
        inputs: either a tensor, which is handed to the experience maker as
            its single positional argument, or a mapping, whose items are
            expanded into keyword arguments.

    Returns:
        Whatever ``self.experience_maker.make_experience`` returns.

    Raises:
        ValueError: if ``inputs`` is neither a tensor nor a mapping.
    """
    # Local import: keeps this method self-contained without touching the
    # file-level import block.
    from collections.abc import Mapping

    if isinstance(inputs, Tensor):
        return self.experience_maker.make_experience(inputs, **self.generate_kwargs)

    # Accept every mapping, not only ``dict``: dict-like batch containers
    # built on ``collections.UserDict`` are not ``dict`` subclasses, yet
    # ``**`` expansion works for them just the same.
    if isinstance(inputs, Mapping):
        return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)

    raise ValueError(f'Unsupported input type "{type(inputs)}"')

def _sample_prompts(self, prompts) -> list:
    """Pick ``self.experience_batch_size`` prompts without replacement.

    The positions are drawn by ``self.strategy.experience_sampler.choice``;
    the prompts at those positions are returned in the order the sampler
    produced them.
    """
    positions = list(range(len(prompts)))
    sampler = self.strategy.experience_sampler
    chosen = sampler.choice(positions, self.experience_batch_size, replace=False)

    picked = []
    for position in chosen:
        picked.append(prompts[position])
    return picked

def _learn(self):
    """Run ``self.max_epochs`` rounds of training on the replay buffer.

    Two modes, selected by ``self.sample_replay_buffer``:

    * sampling mode: each round draws one experience with
      ``self.replay_buffer.sample()`` and trains on it; no learn callbacks
      are invoked in this mode.
    * dataloader mode: the buffer is wrapped in a dataloader and every epoch
      iterates over all of its batches, firing the epoch/batch callbacks
      around each training step.
    """
    sampling = self.sample_replay_buffer

    # The buffer may have been empty earlier, so the dataloader is rebuilt on
    # every call instead of being cached on the trainer.
    loader = None
    if not sampling:
        loader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
    target_device = torch.cuda.current_device()

    if sampling:
        progress = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
        for _ in progress:
            sampled = self.replay_buffer.sample()
            progress.set_postfix(self.training_step(sampled))
        return

    for epoch in range(self.max_epochs):
        self._on_learn_epoch_start(epoch)
        if isinstance(loader.sampler, DistributedSampler):
            # Give the distributed sampler the current epoch number.
            loader.sampler.set_epoch(epoch)
        progress = tqdm(loader,
                        desc=f'Train epoch [{epoch+1}/{self.max_epochs}]',
                        disable=not is_rank_0())
        for batch in progress:
            self._on_learn_batch_start()
            batch.to_device(target_device)
            step_metrics = self.training_step(batch)
            self._on_learn_batch_end(step_metrics, batch)
            progress.set_postfix(step_metrics)
        self._on_learn_epoch_end(epoch)

def fit(self,
        prompt_dataloader,
        pretrain_dataloader,
        num_episodes: int = 50000,
        max_timesteps: int = 500,
        update_timesteps: int = 5000) -> None:
    """Alternate between collecting experience and learning from it.

    Every timestep pulls one batch of prompts, turns it into an experience
    and appends it to the replay buffer. Each time the global step counter
    reaches a multiple of ``update_timesteps`` the trainer runs ``_learn``
    and then clears the buffer.

    Args:
        prompt_dataloader: iterable yielding prompt batches; stored on
            ``self.prompt_dataloader``.
        pretrain_dataloader: stored on ``self.pretrain_dataloader``; it is
            not read anywhere inside this method.
        num_episodes: number of outer episodes.
        max_timesteps: number of experience-making steps per episode.
        update_timesteps: learn once every this many global steps.
    """
    self.pretrain_dataloader = pretrain_dataloader
    self.prompt_dataloader = prompt_dataloader

    total_steps = 0
    # One iterator is kept alive across timesteps. Building a new one on
    # every step would restart the source each time, so a non-shuffled
    # source would hand out its first batch over and over.
    prompt_iter = None

    self._on_fit_start()
    for episode in range(num_episodes):
        self._on_episode_start(episode)
        for _ in tqdm(range(max_timesteps),
                      desc=f'Episode [{episode+1}/{num_episodes}]',
                      disable=not is_rank_0()):
            total_steps += 1

            if prompt_iter is None:
                prompt_iter = iter(self.prompt_dataloader)
            try:
                prompts = next(prompt_iter)
            except StopIteration:
                # Source exhausted: begin another pass over it. An empty
                # source still raises StopIteration here.
                prompt_iter = iter(self.prompt_dataloader)
                prompts = next(prompt_iter)

            self._on_make_experience_start()
            device = torch.cuda.current_device()
            self.experience_maker.initial_model.to(device)
            self.experience_maker.reward_model.to(device)
            experience = self._make_experience(prompts)
            self._on_make_experience_end(experience)
            self.replay_buffer.append(experience)

            if total_steps % update_timesteps == 0:
                # Move the two models to CPU before the learning phase.
                self.experience_maker.initial_model.to('cpu')
                self.experience_maker.reward_model.to('cpu')
                self._learn()
                self.replay_buffer.clear()
        self._on_episode_end(episode)
    self._on_fit_end()

def training_step(self, experience: Experience) -> Dict[str, float]:
self.actor.train()
self.critic.train()
Expand Down
40 changes: 18 additions & 22 deletions applications/Chat/coati/trainer/rm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC
from datetime import datetime
from typing import Optional
from typing import Optional, List

import pandas as pd
import torch
Expand All @@ -10,11 +9,13 @@
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .callbacks import Callback
from .base import Trainer
from .strategies import Strategy
from .utils import is_rank_0


class RewardModelTrainer(ABC):
class RewardModelTrainer(Trainer):
"""
Trainer to use while training reward model.

Expand All @@ -23,11 +24,12 @@ class RewardModelTrainer(ABC):
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
loss_fn (callable): the loss function to use for training
train_dataset (Dataset): the dataset to use for training
valid_dataset (Dataset): the dataset to use for validation
eval_dataset (Dataset): the dataset to use for evaluation
train_dataloader (DataLoader): the dataloader to use for training
valid_dataloader (DataLoader): the dataloader to use for validation
eval_dataloader (DataLoader): the dataloader to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""

def __init__(
Expand All @@ -36,25 +38,19 @@ def __init__(
strategy: Strategy,
optim: Optimizer,
loss_fn,
train_dataset: Dataset,
valid_dataset: Dataset,
eval_dataset: Dataset,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader,
batch_size: int = 1,
max_epochs: int = 1,
callbacks: List[Callback] = [],
) -> None:
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
super().__init__(strategy, max_epochs, callbacks=callbacks)
train_sampler = None

if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
self.train_dataloader = DataLoader(train_dataset,
shuffle=(train_sampler is None),
sampler=train_sampler,
batch_size=batch_size)
self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader

self.model = strategy.setup_model(model)
self.loss_fn = loss_fn
Expand Down Expand Up @@ -86,8 +82,8 @@ def eval_acc(self, dataloader):

def fit(self):
time = datetime.now()
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.epochs):
epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.max_epochs):
step_bar = tqdm(range(self.train_dataloader.__len__()),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0())
Expand Down
Loading