2 changes: 1 addition & 1 deletion applications/Chat/README.md
@@ -83,7 +83,7 @@ More details can be found in the latest news.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ColossalChat%20Speed.jpg" width=450/>
</p>

> DeepSpeedChat performance is taken from its blog post of April 12, 2023; ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs using the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --max_timesteps 1 --update_timesteps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32
> DeepSpeedChat performance is taken from its blog post of April 12, 2023; ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs using the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32

## Install

26 changes: 12 additions & 14 deletions applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -137,6 +137,12 @@ def main(args):

(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
dataloader = DataLoader(random_prompts,
batch_size=args.experience_batch_size,
shuffle=True,
collate_fn=preprocess_batch)

trainer = PPOTrainer(strategy,
actor,
critic,
@@ -145,7 +151,6 @@ def main(args):
actor_optim,
critic_optim,
ptx_coef=0,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models,
max_length=512,
@@ -157,17 +162,11 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator])

random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
dataloader = DataLoader(random_prompts,
batch_size=args.experience_batch_size,
shuffle=True,
collate_fn=preprocess_batch)

trainer.fit(dataloader,
None,
trainer.fit(prompt_dataloader=dataloader,
pretrain_dataloader=None,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps)

print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')

@@ -183,9 +182,8 @@ def main(args):
],
default='ddp')
parser.add_argument('--num_episodes', type=int, default=3)
parser.add_argument('--max_timesteps', type=int, default=8)
parser.add_argument('--update_timesteps', type=int, default=8)
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--num_collect_steps', type=int, default=8)
parser.add_argument('--num_update_steps', type=int, default=1)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0)
8 changes: 6 additions & 2 deletions applications/Chat/coati/trainer/__init__.py
@@ -1,6 +1,10 @@
from .base import Trainer
from .base import OnPolicyTrainer, SLTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
__all__ = [
'SLTrainer', 'OnPolicyTrainer',
'RewardModelTrainer', 'SFTTrainer',
'PPOTrainer'
]
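With the `Trainer` base class split into `SLTrainer` and `OnPolicyTrainer`, downstream code imports the concrete base class it needs instead of the removed `Trainer` export. A minimal sketch of the updated import surface (assumed usage, not part of this PR; the helper below is hypothetical):

```python
# Assumed usage after this PR: `Trainer` is no longer exported; import the
# supervised (SLTrainer) or on-policy (OnPolicyTrainer) base class directly.
from coati.trainer import OnPolicyTrainer, PPOTrainer, RewardModelTrainer, SFTTrainer, SLTrainer


def is_rl_trainer(trainer) -> bool:
    # Hypothetical helper reflecting the split this PR suggests: PPOTrainer
    # builds on OnPolicyTrainer, while SFTTrainer and RewardModelTrainer are
    # expected to build on SLTrainer.
    return isinstance(trainer, OnPolicyTrainer)
```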
168 changes: 145 additions & 23 deletions applications/Chat/coati/trainer/base.py
@@ -1,54 +1,108 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from contextlib import contextmanager
from typing import List

import torch
import torch.nn as nn
import tqdm
from coati.experience_maker import Experience
from coati.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from .callbacks import Callback
from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0


class Trainer(ABC):
class SLTrainer(ABC):
"""
Base class for rlhf trainers.
Base class for supervised learning trainers.

Args:
strategy (Strategy): the strategy to use for training
max_epochs (int, defaults to 1): the number of training epochs
model (nn.Module): the model to train
optimizer (Optimizer): the optimizer to use for training
"""

def __init__(self,
strategy: Strategy,
max_epochs: int,
model: nn.Module,
optimizer: Optimizer,
) -> None:
super().__init__()
self.strategy = strategy
self.max_epochs = max_epochs
self.model = model
self.optimizer = optimizer

@abstractmethod
def _train(self, epoch):
raise NotImplementedError()

@abstractmethod
def _eval(self, epoch):
raise NotImplementedError()

def _before_fit(self):
self.no_epoch_bar = False

def fit(self, *args, **kwargs):
self._before_fit(*args, **kwargs)
for epoch in tqdm.trange(self.max_epochs,
desc="Epochs",
disable=not is_rank_0() or self.no_epoch_bar
):
self._train(epoch)
self._eval(epoch)


class OnPolicyTrainer(ABC):
"""
Base class for on-policy RL trainers, e.g. PPO.

Args:
strategy (Strategy): the strategy to use for training
buffer (NaiveReplayBuffer): the buffer to collect experiences
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""

def __init__(self,
strategy: Strategy,
max_epochs: int = 1,
dataloader_pin_memory: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs) -> None:
buffer: NaiveReplayBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []
) -> None:
super().__init__()
self.strategy = strategy
self.max_epochs = max_epochs
self.generate_kwargs = generate_kwargs
self.buffer = buffer
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks

# TODO(ver217): maybe simplify these code using context
def _on_fit_start(self) -> None:
@contextmanager
def _fit_ctx(self) -> None:
for callback in self.callbacks:
callback.on_fit_start()

def _on_fit_end(self) -> None:
for callback in self.callbacks:
callback.on_fit_end()

def _on_episode_start(self, episode: int) -> None:
try:
yield
finally:
for callback in self.callbacks:
callback.on_fit_end()

@contextmanager
def _episode_ctx(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_start(episode)

def _on_episode_end(self, episode: int) -> None:
for callback in self.callbacks:
callback.on_episode_end(episode)
try:
yield
finally:
for callback in self.callbacks:
callback.on_episode_end(episode)

def _on_make_experience_start(self) -> None:
for callback in self.callbacks:
@@ -73,3 +127,71 @@ def _on_learn_batch_start(self) -> None:
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
for callback in self.callbacks:
callback.on_learn_batch_end(metrics, experience)

@abstractmethod
def _make_experience(self, collect_step: int):
"""
Implement this method to make experience.
"""
raise NotImplementedError()

@abstractmethod
def _learn(self, update_step: int):
"""
Implement this method to learn from experience, either by sampling
from the buffer or by iterating over a dataloader built from it.
"""
raise NotImplementedError()

def _collect_phase(self, collect_step: int):
self._on_make_experience_start()
experience = self._make_experience(collect_step)
self._on_make_experience_end(experience)
self.buffer.append(experience)

def _update_phase(self, update_step: int):
self._on_learn_epoch_start(update_step)
self._learn(update_step)
self._on_learn_epoch_end(update_step)

def fit(self,
prompt_dataloader: DataLoader,
pretrain_dataloader: DataLoader,
num_episodes: int,
num_collect_steps: int,
num_update_steps: int,
):
"""
The main training loop of on-policy rl trainers.

Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
num_episodes (int): the number of episodes to train
num_collect_steps (int): the number of collect steps per episode
num_update_steps (int): the number of update steps per episode
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

with self._fit_ctx():
for episode in tqdm.trange(num_episodes,
desc="Episodes",
disable=not is_rank_0()):
with self._episode_ctx(episode):
for collect_step in tqdm.trange(num_collect_steps,
desc="Collect steps",
disable=not is_rank_0()):
self._collect_phase(collect_step)
if not self.sample_buffer:
# HACK(cwher): according to the design of the boost API, the dataloader should also be boosted,
# but it is impractical to adopt this pattern in RL training. Thus, I leave the dataloader
# unboosted and only call strategy.setup_dataloader() to set it up.
self.dataloader = self.strategy.setup_dataloader(self.buffer,
self.dataloader_pin_memory)
for update_step in tqdm.trange(num_update_steps,
desc="Update steps",
disable=not is_rank_0()):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.buffer.clear()
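To make the new collect/update flow concrete, here is a minimal sketch of a subclass written against the refactored `OnPolicyTrainer` interface. `ToyTrainer` and the values passed to `fit` are illustrative assumptions; only the hook names, the collect/update phases, and the `fit` keyword arguments come from the diff above.

```python
# Minimal sketch (assumed usage, not part of this PR) of a trainer plugging
# into the refactored OnPolicyTrainer loop.
from coati.trainer import OnPolicyTrainer


class ToyTrainer(OnPolicyTrainer):  # hypothetical subclass for illustration
    def _make_experience(self, collect_step: int):
        # Called once per collect step by _collect_phase(); the returned
        # experience is appended to self.buffer by the base class.
        ...

    def _learn(self, update_step: int):
        # Called once per update step by _update_phase(); when sample_buffer
        # is False, self.dataloader (built via strategy.setup_dataloader)
        # holds the collected experiences.
        ...


# Each episode runs `num_collect_steps` collect phases followed by
# `num_update_steps` update phases, then clears the replay buffer:
#
# trainer.fit(prompt_dataloader=prompt_loader,  # prompt_loader: any DataLoader of prompts
#             pretrain_dataloader=None,
#             num_episodes=3,
#             num_collect_steps=8,
#             num_update_steps=1)
```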