applications/Chat/coati/ray/example/1mmt_dummy.py (121 changes: 47 additions & 74 deletions)
@@ -13,37 +13,8 @@
     get_reward_model_from_args,
     get_strategy_from_args,
 )
-from transformers import AutoTokenizer, BloomTokenizerFast
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-
-def get_gpt_config(model_name: str) -> GPT2Config:
-    model_map = {
-        's': GPT2Config(),
-        'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
-        'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
-        'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
-        '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
-        '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
-        '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
-        '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
-        '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
-        '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
-        '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
-        '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
-        '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
-        '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
-        '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
-        '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
-        '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
-        '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
-        '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
-    }
-    try:
-        return model_map[model_name]
-    except KeyError:
-        raise ValueError(f'Unknown model "{model_name}"')
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer


 def get_free_port():
@@ -81,34 +52,16 @@ def main(args):
     }

     # configure tokenizer
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
     tokenizer.pad_token = tokenizer.eos_token

-    def trainer_model_fn():
-        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
-        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
-        return actor, critic
-
-    # configure Trainer
-    trainer_refs = [
-        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
-            experience_maker_holder_name_list=["maker1"],
-            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
-            model_fn=trainer_model_fn,
-            env_info=env_info_trainer,
-            train_batch_size=args.train_batch_size,
-            buffer_limit=16,
-            max_epochs=args.max_epochs,
-            eval_performance=True,
-            debug=args.debug,
-        ) for i, env_info_trainer in enumerate(env_info_trainers)
-    ]
-
     def model_fn():
-        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
-        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
-        reward_model = get_reward_model_from_args(args.model, args.pretrain).half().cuda()
-        initial_model = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+        critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
+        actor = get_actor_from_args(args.model, config=actor_cfg).half().cuda()
+        critic = get_critic_from_args(args.model, config=critic_cfg).half().cuda()
+        reward_model = get_reward_model_from_args(args.model, config=critic_cfg).half().cuda()
+        initial_model = get_actor_from_args(args.model, config=actor_cfg).half().cuda()
         return actor, critic, reward_model, initial_model

     # configure Experience Maker
@@ -117,7 +70,6 @@ def model_fn():
         strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
         model_fn=model_fn,
         env_info=env_info_maker,
-        experience_batch_size=args.experience_batch_size,
         kl_coef=0.1,
         debug=args.debug,
         # sync_models_from_trainers=True,
@@ -132,14 +84,37 @@ def model_fn():
         use_cache=True,
     )

-    # configure sampler
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400))
+    def trainer_model_fn():
+        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
+        critic = get_critic_from_args(args.model, config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+        return actor, critic
+
+    # configure Trainer
+    trainer_refs = [
+        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+            experience_maker_holder_name_list=["maker1"],
+            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+            model_fn=trainer_model_fn,
+            env_info=env_info_trainer,
+            train_batch_size=args.train_batch_size,
+            buffer_limit=16,
+            eval_performance=True,
+            debug=args.debug,
+        ) for i, env_info_trainer in enumerate(env_info_trainers)
+    ]
+
+    dataset_size = args.experience_batch_size * 4

-    def tokenize_fn(texts):
-        input_ids = torch.stack(texts).cuda()
+    def data_gen_fn():
+        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
         attn_mask = torch.ones_like(input_ids)
         return {'input_ids': input_ids, 'attention_mask': attn_mask}

+    def build_dataloader(size):
+        dataset = [data_gen_fn() for _ in range(size)]
+        dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
+        return dataloader
+
     # uncomment this function if sync_models_from_trainers is True
     # ray.get([
     #     trainer_ref.sync_models_to_remote_makers.remote()
@@ -148,15 +123,13 @@ def tokenize_fn(texts):

     wait_tasks = []

-    for trainer_ref in trainer_refs:
-        wait_tasks.append(
-            trainer_ref.fit.remote(num_episodes=args.num_episodes,
-                                   max_timesteps=args.max_timesteps,
-                                   update_timesteps=args.update_timesteps))
+    wait_tasks.append(
+        experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
+                                                 num_steps=args.experience_steps))

-    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * \
-        args.max_epochs * args.num_trainers + 3    # +3 for fault tolerance
-    wait_tasks.append(experience_holder_ref.workingloop.remote(random_prompts, tokenize_fn, times=num_exp_per_maker))
+    total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
+    for trainer_ref in trainer_refs:
+        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))

     ray.get(wait_tasks)

@@ -170,12 +143,12 @@ def tokenize_fn(texts):
     parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
-    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--critic_pretrain', type=str, default=None)
+    parser.add_argument('--experience_steps', type=int, default=4)
+    parser.add_argument('--experience_batch_size', type=int, default=8)
+    parser.add_argument('--train_epochs', type=int, default=1)
+    parser.add_argument('--update_steps', type=int, default=2)
+    parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")

     parser.add_argument('--debug', action='store_true')
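
Note: the new flags are tied together by simple arithmetic; the maker produces experience in fixed-size batches and the trainers consume it in training-batch-sized steps. Below is a worked sketch of that schedule using the argparse defaults above (num_trainers=2 is an assumed value here; the flag itself is parsed outside the visible diff).

# Worked example of the schedule arithmetic in the new 1mmt_dummy.py.
# All values except num_trainers are the argparse defaults from the diff above.
experience_batch_size = 8    # --experience_batch_size
experience_steps = 4         # --experience_steps
train_batch_size = 8         # --train_batch_size
update_steps = 2             # --update_steps
num_trainers = 2             # assumed for illustration

# The maker's dummy dataloader holds experience_batch_size * 4 prompts ...
dataset_size = experience_batch_size * 4                             # 32
# ... and the maker produces experience_steps batches of experience_batch_size
# samples, shared across all trainers.
total_samples = experience_batch_size * experience_steps             # 32
# Each trainer consumes train_batch_size samples per training step:
total_steps = total_samples // (num_trainers * train_batch_size)     # 2
# fit() syncs weights to the maker once per update_steps training steps:
num_syncs = total_steps // update_steps                              # 1
print(dataset_size, total_samples, total_steps, num_syncs)           # 32 32 2 1
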
applications/Chat/coati/ray/src/detached_replay_buffer.py (16 changes: 8 additions & 8 deletions)
@@ -26,12 +26,7 @@ class DetachedReplayBuffer:
         cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
     '''

-    def __init__(self,
-                 sample_batch_size: int,
-                 tp_world_size: int = 1,
-                 limit: int = 0,
-                 cpu_offload: bool = True) -> None:
-        self.cpu_offload = cpu_offload
+    def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit: int = 0) -> None:
         self.sample_batch_size = sample_batch_size
         self.limit = limit
         self.items = Queue(self.limit, actor_options={"num_cpus": 1})
@@ -51,9 +46,14 @@ def append(self, experience: Experience) -> None:
         '''
         Expected to be called remotely.
         '''
-        if self.cpu_offload:
-            experience.to_device(torch.device('cpu'))
         items = split_experience_batch(experience)
+        self.extend(items)
+
+    @torch.no_grad()
+    def extend(self, items: List[BufferItem]) -> None:
+        '''
+        Expected to be called remotely.
+        '''
         self.batch_collector.extend(items)
         while len(self.batch_collector) >= self.sample_batch_size:
             items = self.batch_collector[:self.sample_batch_size]
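
Note: with cpu_offload removed, append reduces to split-plus-extend, and extend owns the batching invariant: incoming BufferItems accumulate in batch_collector, and each time a full sample batch is available it is sliced off and queued. A minimal standalone sketch of that collection logic (an illustration of the invariant, not the class itself):

from typing import Any, List

def collect_batches(collector: List[Any], incoming: List[Any], batch_size: int) -> List[List[Any]]:
    """Append incoming items, then emit as many full batches as possible.

    Mirrors the batch_collector loop in DetachedReplayBuffer.extend: items
    accumulate until at least `batch_size` are buffered, and each full slice
    is popped off the front and handed to the queue.
    """
    collector.extend(incoming)
    batches = []
    while len(collector) >= batch_size:
        batches.append(collector[:batch_size])
        del collector[:batch_size]
    return batches

# Usage: three items arrive, batch size two -> one batch leaves, one item waits.
buf: List[int] = []
print(collect_batches(buf, [1, 2, 3], 2))    # [[1, 2]]
print(buf)                                   # [3]
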
applications/Chat/coati/ray/src/detached_trainer_base.py (64 changes: 41 additions & 23 deletions)
@@ -1,10 +1,13 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

 import ray
 import torch
 from coati.experience_maker import Experience
+from coati.replay_buffer.utils import BufferItem
 from coati.trainer.callbacks import Callback
+from torch.utils.data import DataLoader
 from tqdm import tqdm

 from .detached_replay_buffer import DetachedReplayBuffer
@@ -21,7 +24,6 @@ class DetachedTrainer(ABC):
     Args:
         detached_strategy (DetachedStrategy): the strategy to use for training
         detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
-        max_epochs (int, defaults to 1): the number of epochs of training process
         data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
@@ -32,16 +34,11 @@ def __init__(self,
                  experience_maker_holder_name_list: List[str],
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
-                 buffer_cpu_offload: bool = True,
-                 max_epochs: int = 1,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  debug: bool = False) -> None:
         super().__init__()
-        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size,
-                                                           limit=buffer_limit,
-                                                           cpu_offload=buffer_cpu_offload)
-        self.max_epochs = max_epochs
+        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
         self.target_holder_name_list = experience_maker_holder_name_list
@@ -66,31 +63,45 @@ def sync_models_to_remote_makers(self, **kwargs):
     def training_step(self, experience: Experience) -> Dict[str, Any]:
         pass

-    def _learn(self):
-        pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-        for _ in pbar:
-            if self._debug:
-                print("[trainer] sampling exp")
-            experience = self._buffer_sample()
+    def _learn(self, update_steps: int, train_epochs: int) -> None:
+        data = []
+        # warmup
+        pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
+        self._learn_epoch(pbar, data)
+        # item is already a batch
+        dataloader = DataLoader(data,
+                                batch_size=1,
+                                shuffle=True,
+                                pin_memory=self.dataloader_pin_memory,
+                                collate_fn=lambda x: x[0])
+        for epoch in range(1, train_epochs):
+            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
+            self._learn_epoch(pbar, data)
+
+    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
+        is_warmup = len(data) == 0
+        for x in pbar:
             if self._debug:
                 print("[trainer] training step")
+            # sample a batch and then train to avoid waiting
+            experience = x if not is_warmup else self._buffer_sample()
+            experience.to_device(torch.cuda.current_device())
             self._on_learn_batch_start()
             metrics = self.training_step(experience)
             self._on_learn_batch_end(metrics, experience)
             if self._debug:
                 print("[trainer] step over")
+            experience.to_device("cpu")
+            if is_warmup:
+                data.append(experience)
             pbar.set_postfix(metrics)

-    def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
+    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
         self._on_fit_start()
-        for episode in range(num_episodes):
-            self._on_episode_start(episode)
-            for timestep in tqdm(range(max_timesteps // update_timesteps),
-                                 desc=f'Episode [{episode+1}/{num_episodes}]',
-                                 disable=not is_rank_0()):
-                self._learn()
-                self._update_remote_makers()
-            self._on_episode_end(episode)
+        for _ in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
+            self._learn(update_steps, train_epochs)
+            self._update_remote_makers()
         self._on_fit_end()
         self._on_finish()

@@ -108,6 +119,13 @@ def buffer_append(self, experience: Experience):
             print(f"[trainer] receiving exp.")
         self.detached_replay_buffer.append(experience)

+    @ray.method(concurrency_group="buffer_append")
+    def buffer_extend(self, items: List[BufferItem]):
+        # called by ExperienceMakerHolder
+        if self._debug:
+            print(f"[trainer] receiving exp.")
+        self.detached_replay_buffer.extend(items)
+
     @ray.method(concurrency_group="buffer_sample")
     def _buffer_sample(self):
         return self.detached_replay_buffer.sample()
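
Note: the reworked fit/_learn pair moves the training loop from episode counting to step counting. Each outer iteration pulls update_steps fresh batches from the replay buffer in a warmup pass, caches them on CPU, and replays the cache through a DataLoader for the remaining epochs. A standalone sketch of that warmup-then-replay control flow, with plain dict batches standing in for Experience objects (sample_fresh and train_step are injected stand-ins for the buffer and the optimizer step):

from typing import Callable, Dict, List

import torch
from torch.utils.data import DataLoader

def learn(update_steps: int, train_epochs: int,
          sample_fresh: Callable[[], Dict[str, torch.Tensor]],
          train_step: Callable[[Dict[str, torch.Tensor]], None]) -> None:
    # Epoch 1 (warmup): pull fresh batches and cache them, as
    # DetachedTrainer._learn_epoch does while `data` is empty.
    data: List[Dict[str, torch.Tensor]] = []
    for _ in range(update_steps):
        batch = sample_fresh()
        train_step(batch)
        data.append(batch)
    # Each cached item is already a batch, hence batch_size=1 plus an
    # unwrapping collate_fn, mirroring the DataLoader built in _learn.
    loader = DataLoader(data, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
    # Epochs 2..n: replay the cached batches instead of waiting on the buffer.
    for _ in range(1, train_epochs):
        for batch in loader:
            train_step(batch)

# Toy usage: 4 warmup steps + 2 replay epochs x 4 batches = 12 steps in total.
steps = []
learn(update_steps=4, train_epochs=3,
      sample_fresh=lambda: {'x': torch.randn(8, 4)},
      train_step=lambda b: steps.append(b['x'].shape))
print(len(steps))    # 12
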
applications/Chat/coati/ray/src/detached_trainer_ppo.py (5 changes: 0 additions & 5 deletions)
@@ -60,10 +60,8 @@
         env_info: Dict[str, str] = None,
         train_batch_size: int = 8,
         buffer_limit: int = 0,
-        buffer_cpu_offload: bool = True,
         eps_clip: float = 0.2,
         value_clip: float = 0.4,
-        max_epochs: int = 10,
         dataloader_pin_memory: bool = True,
         callbacks: List[Callback] = [],
         eval_performance: bool = False,
@@ -101,8 +99,6 @@ def __init__(
         super().__init__(experience_maker_holder_name_list,
                          train_batch_size=train_batch_size,
                          buffer_limit=buffer_limit,
-                         buffer_cpu_offload=buffer_cpu_offload,
-                         max_epochs=max_epochs,
                          dataloader_pin_memory=dataloader_pin_memory,
                          callbacks=callbacks,
                          debug=debug)
@@ -144,7 +140,6 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
         self.critic.train()

-        experience.to_device(torch.cuda.current_device())
         num_actions = experience.action_mask.size(1)
         action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
         actor_loss = self.actor_loss_fn(action_log_probs,
Expand Down
Loading