diff --git a/applications/Chat/coati/ray/example/1mmt_dummy.py b/applications/Chat/coati/ray/example/1mmt_dummy.py
index fdb742406b26..c7619ea6940b 100644
--- a/applications/Chat/coati/ray/example/1mmt_dummy.py
+++ b/applications/Chat/coati/ray/example/1mmt_dummy.py
@@ -13,37 +13,8 @@
     get_reward_model_from_args,
     get_strategy_from_args,
 )
-from transformers import AutoTokenizer, BloomTokenizerFast
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
-from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
-
-
-def get_gpt_config(model_name: str) -> GPT2Config:
-    model_map = {
-        's': GPT2Config(),
-        'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
-        'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
-        'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
-        '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
-        '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
-        '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
-        '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
-        '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
-        '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
-        '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
-        '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
-        '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
-        '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
-        '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
-        '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
-        '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
-        '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
-        '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
-    }
-    try:
-        return model_map[model_name]
-    except KeyError:
-        raise ValueError(f'Unknown model "{model_name}"')
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer
 
 
 def get_free_port():
@@ -81,34 +52,16 @@ def main(args):
     }
 
     # configure tokenizer
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
     tokenizer.pad_token = tokenizer.eos_token
 
-    def trainer_model_fn():
-        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
-        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
-        return actor, critic
-
-    # configure Trainer
-    trainer_refs = [
-        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
-            experience_maker_holder_name_list=["maker1"],
-            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
-            model_fn=trainer_model_fn,
-            env_info=env_info_trainer,
-            train_batch_size=args.train_batch_size,
-            buffer_limit=16,
-            max_epochs=args.max_epochs,
-            eval_performance=True,
-            debug=args.debug,
-        ) for i, env_info_trainer in enumerate(env_info_trainers)
-    ]
-
     def model_fn():
-        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
-        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
-        reward_model = get_reward_model_from_args(args.model, args.pretrain).half().cuda()
-        initial_model = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+        critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
+        actor = get_actor_from_args(args.model, config=actor_cfg).half().cuda()
+        critic = get_critic_from_args(args.model, config=critic_cfg).half().cuda()
+        reward_model = get_reward_model_from_args(args.model, config=critic_cfg).half().cuda()
+        initial_model = get_actor_from_args(args.model, config=actor_cfg).half().cuda()
         return actor, critic, reward_model, initial_model
 
     # configure Experience Maker
@@ -117,7 +70,6 @@ def model_fn():
         strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
         model_fn=model_fn,
         env_info=env_info_maker,
-        experience_batch_size=args.experience_batch_size,
         kl_coef=0.1,
         debug=args.debug,
         # sync_models_from_trainers=True,
@@ -132,14 +84,37 @@ def model_fn():
         use_cache=True,
     )
 
-    # configure sampler
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400))
+    def trainer_model_fn():
+        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
+        critic = get_critic_from_args(args.model, config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+        return actor, critic
+
+    # configure Trainer
+    trainer_refs = [
+        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+            experience_maker_holder_name_list=["maker1"],
+            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+            model_fn=trainer_model_fn,
+            env_info=env_info_trainer,
+            train_batch_size=args.train_batch_size,
+            buffer_limit=16,
+            eval_performance=True,
+            debug=args.debug,
+        ) for i, env_info_trainer in enumerate(env_info_trainers)
+    ]
+
+    dataset_size = args.experience_batch_size * 4
 
-    def tokenize_fn(texts):
-        input_ids = torch.stack(texts).cuda()
+    def data_gen_fn():
+        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
         attn_mask = torch.ones_like(input_ids)
         return {'input_ids': input_ids, 'attention_mask': attn_mask}
 
+    def build_dataloader(size):
+        dataset = [data_gen_fn() for _ in range(size)]
+        dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
+        return dataloader
+
     # uncomment this function if sync_models_from_trainers is True
     # ray.get([
     #     trainer_ref.sync_models_to_remote_makers.remote()
@@ -148,15 +123,13 @@ def tokenize_fn(texts):
 
     wait_tasks = []
 
-    for trainer_ref in trainer_refs:
-        wait_tasks.append(
-            trainer_ref.fit.remote(num_episodes=args.num_episodes,
-                                   max_timesteps=args.max_timesteps,
-                                   update_timesteps=args.update_timesteps))
+    wait_tasks.append(
+        experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
+                                                 num_steps=args.experience_steps))
 
-    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * \
-        args.max_epochs * args.num_trainers + 3    # +3 for fault tolerance
-    wait_tasks.append(experience_holder_ref.workingloop.remote(random_prompts, tokenize_fn, times=num_exp_per_maker))
+    total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
+    for trainer_ref in trainer_refs:
+        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
 
     ray.get(wait_tasks)
 
@@ -170,12 +143,12 @@ def tokenize_fn(texts):
     parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
-    parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
-    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--critic_pretrain', type=str, default=None)
+    parser.add_argument('--experience_steps', type=int, default=4)
     parser.add_argument('--experience_batch_size', type=int, default=8)
+    parser.add_argument('--train_epochs', type=int, default=1)
+    parser.add_argument('--update_steps', type=int, default=2)
+    parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     parser.add_argument('--debug', action='store_true')
diff --git a/applications/Chat/coati/ray/src/detached_replay_buffer.py b/applications/Chat/coati/ray/src/detached_replay_buffer.py
index 4bc74bb878fd..257b0b072493 100644
--- a/applications/Chat/coati/ray/src/detached_replay_buffer.py
+++ b/applications/Chat/coati/ray/src/detached_replay_buffer.py
@@ -26,12 +26,7 @@ class DetachedReplayBuffer:
         cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
     '''
 
-    def __init__(self,
-                 sample_batch_size: int,
-                 tp_world_size: int = 1,
-                 limit: int = 0,
-                 cpu_offload: bool = True) -> None:
-        self.cpu_offload = cpu_offload
+    def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit: int = 0) -> None:
         self.sample_batch_size = sample_batch_size
         self.limit = limit
         self.items = Queue(self.limit, actor_options={"num_cpus": 1})
@@ -51,9 +46,14 @@ def append(self, experience: Experience) -> None:
         '''
         Expected to be called remotely.
         '''
-        if self.cpu_offload:
-            experience.to_device(torch.device('cpu'))
         items = split_experience_batch(experience)
+        self.extend(items)
+
+    @torch.no_grad()
+    def extend(self, items: List[BufferItem]) -> None:
+        '''
+        Expected to be called remotely.
+        '''
         self.batch_collector.extend(items)
         while len(self.batch_collector) >= self.sample_batch_size:
             items = self.batch_collector[:self.sample_batch_size]
diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py
index 86b60582a614..1137d8f7b491 100644
--- a/applications/Chat/coati/ray/src/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/src/detached_trainer_base.py
@@ -1,10 +1,13 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import ray
+import torch
 from coati.experience_maker import Experience
+from coati.replay_buffer.utils import BufferItem
 from coati.trainer.callbacks import Callback
+from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from .detached_replay_buffer import DetachedReplayBuffer
@@ -21,7 +24,6 @@ class DetachedTrainer(ABC):
     Args:
         detached_strategy (DetachedStrategy): the strategy to use for training
         detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
-        max_epochs (int, defaults to 1): the number of epochs of training process
         data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
@@ -32,16 +34,11 @@ def __init__(self,
                  experience_maker_holder_name_list: List[str],
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
-                 buffer_cpu_offload: bool = True,
-                 max_epochs: int = 1,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  debug: bool = False) -> None:
         super().__init__()
-        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size,
-                                                           limit=buffer_limit,
-                                                           cpu_offload=buffer_cpu_offload)
-        self.max_epochs = max_epochs
+        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
         self.target_holder_name_list = experience_maker_holder_name_list
@@ -66,31 +63,45 @@ def sync_models_to_remote_makers(self, **kwargs):
     def training_step(self, experience: Experience) -> Dict[str, Any]:
         pass
 
-    def _learn(self):
-        pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-        for _ in pbar:
-            if self._debug:
-                print("[trainer] sampling exp")
-            experience = self._buffer_sample()
+    def _learn(self, update_steps: int, train_epochs: int) -> None:
+        data = []
+        # warmup
+        pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
+        self._learn_epoch(pbar, data)
+        # item is already a batch
+        dataloader = DataLoader(data,
+                                batch_size=1,
+                                shuffle=True,
+                                pin_memory=self.dataloader_pin_memory,
+                                collate_fn=lambda x: x[0])
+        for epoch in range(1, train_epochs):
+            pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
+            self._learn_epoch(pbar, data)
+
+    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
+        is_warmup = len(data) == 0
+        for x in pbar:
             if self._debug:
                 print("[trainer] training step")
+            # sample a batch and then train to avoid waiting
+            experience = x if not is_warmup else self._buffer_sample()
+            experience.to_device(torch.cuda.current_device())
             self._on_learn_batch_start()
             metrics = self.training_step(experience)
             self._on_learn_batch_end(metrics, experience)
+
             if self._debug:
                 print("[trainer] step over")
+            experience.to_device("cpu")
+            if is_warmup:
+                data.append(experience)
             pbar.set_postfix(metrics)
 
-    def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
+    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
         self._on_fit_start()
-        for episode in range(num_episodes):
-            self._on_episode_start(episode)
-            for timestep in tqdm(range(max_timesteps // update_timesteps),
-                                 desc=f'Episode [{episode+1}/{num_episodes}]',
-                                 disable=not is_rank_0()):
-                self._learn()
-                self._update_remote_makers()
-            self._on_episode_end(episode)
+        for _ in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
+            self._learn(update_steps, train_epochs)
+            self._update_remote_makers()
         self._on_fit_end()
         self._on_finish()
 
@@ -108,6 +119,13 @@ def buffer_append(self, experience: Experience):
             print(f"[trainer] receiving exp.")
         self.detached_replay_buffer.append(experience)
 
+    @ray.method(concurrency_group="buffer_append")
+    def buffer_extend(self, items: List[BufferItem]):
+        # called by ExperienceMakerHolder
+        if self._debug:
+            print(f"[trainer] receiving exp.")
+        self.detached_replay_buffer.extend(items)
+
     @ray.method(concurrency_group="buffer_sample")
     def _buffer_sample(self):
         return self.detached_replay_buffer.sample()
diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
index 056942b83360..b0630cd0b5ae 100644
--- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
@@ -60,10 +60,8 @@ def __init__(
         env_info: Dict[str, str] = None,
         train_batch_size: int = 8,
         buffer_limit: int = 0,
-        buffer_cpu_offload: bool = True,
         eps_clip: float = 0.2,
         value_clip: float = 0.4,
-        max_epochs: int = 10,
         dataloader_pin_memory: bool = True,
         callbacks: List[Callback] = [],
         eval_performance: bool = False,
@@ -101,8 +99,6 @@ def __init__(
         super().__init__(experience_maker_holder_name_list,
                          train_batch_size=train_batch_size,
                          buffer_limit=buffer_limit,
-                         buffer_cpu_offload=buffer_cpu_offload,
-                         max_epochs=max_epochs,
                          dataloader_pin_memory=dataloader_pin_memory,
                          callbacks=callbacks,
                          debug=debug)
@@ -144,7 +140,6 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
         self.critic.train()
 
-        experience.to_device(torch.cuda.current_device())
         num_actions = experience.action_mask.size(1)
         action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
         actor_loss = self.actor_loss_fn(action_log_probs,
diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py
index c85d36bab360..ebeb58137370 100644
--- a/applications/Chat/coati/ray/src/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/src/experience_maker_holder.py
@@ -3,7 +3,7 @@
 import tracemalloc
 from copy import deepcopy
 from threading import Lock
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import ray
 import torch
@@ -17,16 +17,9 @@
 from coati.trainer.strategies.sampler import DistributedSampler
 from ray.exceptions import GetTimeoutError
 from torch import Tensor
+from tqdm import tqdm
 
-from .utils import (
-    get_actor_from_args,
-    get_critic_from_args,
-    get_model_numel,
-    get_reward_model_from_args,
-    get_strategy_from_args,
-    is_rank_0,
-    set_dist_env,
-)
+from .utils import get_model_numel, is_rank_0, set_dist_env
 
 
 @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
@@ -35,7 +28,6 @@ class ExperienceMakerHolder:
     Args:
         detached_trainer_name_list: str list to get ray actor handles
         strategy:
-        experience_batch_size: batch size of generated experience
         kl_coef: the coefficient of kl divergence loss
         sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
     '''
@@ -48,8 +40,7 @@ def __init__(
         model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
         env_info: Dict[str, str] = None,
         sync_models_from_trainers: bool = False,
-        experience_batch_size: int = 8,
-        send_grain_size: int = 4,
+        buffer_cpu_offload: bool = True,
         kl_coef: float = 0.1,
         callbacks: List[Callback] = [],
         eval_performance: bool = False,
@@ -62,7 +53,7 @@ def __init__(
         for name in detached_trainer_name_list:
             self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
         self.strategy = strategy_fn()
-        self.experience_batch_size = experience_batch_size
+        self.buffer_cpu_offload = buffer_cpu_offload
         self.kl_coef = kl_coef
         # init models
         with self.strategy.model_init_context():
@@ -80,7 +71,6 @@ def __init__(
         actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
         self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
         self.callbacks = callbacks
-        self.send_grain_size = send_grain_size
 
         self._model_visit_lock = Lock()
 
@@ -89,6 +79,8 @@ def __init__(
         self._debug = debug
         self.target_auto_balance = False
 
+        self._target_idx = 0
+
         if self._debug and not self._is_fully_initialized:
             print('[maker] Waiting for INIT')
 
@@ -114,6 +106,7 @@ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experien
         else:
             raise ValueError(f'Unsupported input type "{type(inputs)}"')
 
+    # TODO(ver217): remove this method
     @ray.method(concurrency_group="experience_io")
     def _send_experience(self, experience):
         if not self.target_auto_balance:
@@ -148,33 +141,52 @@ def _send_experience(self, experience):
                 print(f"[maker] sending exp to {chosen_trainer}")
             chosen_trainer.buffer_append.remote(experience)
 
-    def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
-        self._get_ready()
-        sampler = self.strategy.setup_sampler(dataset)
-        for _ in range(times):
-            rand_prompts = sampler.sample(self.experience_batch_size)
-            if tokenizer is not None:
-                inputs = tokenizer(rand_prompts)
-            else:
-                inputs = rand_prompts
-            self._model_visit_lock.acquire()
+    @ray.method(concurrency_group="experience_io")
+    def _send_items(self, experience: Experience) -> None:
+        items = split_experience_batch(experience)
+        items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
+        for item in items:
+            items_per_trainer[self._target_idx].append(item)
+            self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
+        for i, target_trainer in enumerate(self.target_trainer_list):
+            if len(items_per_trainer[i]) > 0:
+                target_trainer.buffer_extend.remote(items_per_trainer[i])
+
+    def _inference_step(self, batch) -> None:
+        with self._model_visit_lock:
             self._on_make_experience_start()
-            experience = self._make_experience(inputs=inputs)
+            experience = self._make_experience(batch)
             self._on_make_experience_end(experience)
-            self._model_visit_lock.release()
-            # split experience for smoother handover
-            items = split_experience_batch(experience)
-            temp_buffer = []
-            for item in items:
-                temp_buffer.append(item)
-                if len(temp_buffer) >= self.send_grain_size:
-                    experience_fragment = make_experience_batch(temp_buffer)
-                    self._send_experience(experience=experience_fragment)
-                    temp_buffer = []
-            # remain
-            if len(temp_buffer) > 0:
-                experience_fragment = make_experience_batch(temp_buffer)
-                self._send_experience(experience=experience_fragment)
+        if self.buffer_cpu_offload:
+            experience.to_device('cpu')
+        self._send_items(experience)
+
+    def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
+        """Working loop of the experience maker.
+
+        Args:
+            dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
+            num_epochs (int, optional): Iterate the dataloader for this number of epochs. Defaults to 1.
+            num_steps (int, optional): Iterate the dataloader for this number of steps. If this value > 0, num_epochs will be ignored. Defaults to 0.
+        """
+        self._get_ready()
+        dataloader = dataloader_fn()
+        if num_steps > 0:
+            # ignore num epochs
+            it = iter(dataloader)
+            for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
+                try:
+                    batch = next(it)
+                except StopIteration:
+                    it = iter(dataloader)
+                    batch = next(it)
+                self._inference_step(batch)
+        else:
+            with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
+                for _ in range(num_epochs):
+                    for batch in dataloader:
+                        self._inference_step(batch)
+                        pbar.update()
         self._on_finish()
 
     @ray.method(concurrency_group="model_io")