From 10a9cf912e3e2b78e83afba10761ca75a3e8575d Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 21 Apr 2023 17:23:48 +0800 Subject: [PATCH 1/5] [chat] refactor experience maker holder --- .../Chat/coati/ray/example/1mmt_dummy.py | 32 +++--- .../coati/ray/src/experience_maker_holder.py | 103 +++++++----------- 2 files changed, 54 insertions(+), 81 deletions(-) diff --git a/applications/Chat/coati/ray/example/1mmt_dummy.py b/applications/Chat/coati/ray/example/1mmt_dummy.py index 68b666663b12..64071cd73a6d 100644 --- a/applications/Chat/coati/ray/example/1mmt_dummy.py +++ b/applications/Chat/coati/ray/example/1mmt_dummy.py @@ -1,15 +1,18 @@ import argparse import os import socket -from copy import deepcopy from functools import partial import ray import torch -from coati.models.base import RewardModel from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.src.experience_maker_holder import ExperienceMakerHolder -from coati.ray.src.utils import get_actor_from_args, get_critic_from_args, get_reward_model_from_args +from coati.ray.src.utils import ( + get_actor_from_args, + get_critic_from_args, + get_reward_model_from_args, + get_strategy_from_args, +) from transformers import AutoTokenizer, BloomTokenizerFast from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer @@ -106,10 +109,18 @@ def main(args): ) for i, env_info_trainer in enumerate(env_info_trainers) ] + def model_fn(): + actor = get_actor_from_args(args.model, args.pretrain).half().cuda() + critic = get_critic_from_args(args.model, args.pretrain).half().cuda() + reward_model = get_reward_model_from_args(args.model, args.pretrain).half().cuda() + initial_model = get_actor_from_args(args.model, args.pretrain).half().cuda() + return actor, critic, reward_model, initial_model + # configure Experience Maker experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], - strategy=args.maker_strategy, + strategy_fn=partial(get_strategy_from_args, args.maker_strategy), + model_fn=model_fn, env_info=env_info_maker, experience_batch_size=args.experience_batch_size, kl_coef=0.1, @@ -125,19 +136,6 @@ def main(args): debug=args.debug, ) - def init_inference_model(fn, model_name, pretrained): - model = fn(model_name, pretrained) - return model.half().cuda() - - # init maker locally - ray.get( - experience_holder_ref.initialize_experience_maker_local.remote( - initial_model_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain), - reward_model_func=partial(init_inference_model, get_reward_model_from_args, args.model, args.pretrain), - actor_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain), - critic_func=partial(init_inference_model, get_critic_from_args, args.model, args.pretrain), - )) - # configure sampler random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400)) diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py index b1acdbb5494d..b4d234c9df0c 100644 --- a/applications/Chat/coati/ray/src/experience_maker_holder.py +++ b/applications/Chat/coati/ray/src/experience_maker_holder.py @@ -3,7 +3,7 @@ import tracemalloc from copy import deepcopy from threading import Lock -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import ray import torch @@ -36,39 +36,52 @@ class ExperienceMakerHolder: strategy: experience_batch_size: batch size of generated experience kl_coef: the coefficient of kl divergence loss + sync_models_from_trainers: whether to sync models from trainers. If True, you must call update_experience_maker() in trainers to sync models. ''' - def __init__(self, - detached_trainer_name_list: List[str], - strategy: str, - env_info: Dict[str, str] = None, - experience_batch_size: int = 8, - kl_coef: float = 0.1, - callbacks: List[Callback] = [], - eval_performance: bool = False, - debug: bool = False, - **generate_kwargs): + def __init__( + self, + detached_trainer_name_list: List[str], + strategy_fn: Callable[[], Strategy], + # a function returns (actor, critic, reward_model, initial_model) + model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], + env_info: Dict[str, str] = None, + sync_models_from_trainers: bool = False, + experience_batch_size: int = 8, + kl_coef: float = 0.1, + callbacks: List[Callback] = [], + eval_performance: bool = False, + debug: bool = False, + **generate_kwargs): # set environment variables if env_info: set_dist_env(env_info=env_info) self.target_trainer_list = [] for name in detached_trainer_name_list: self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) - self.strategy_str = strategy - self.strategy = get_strategy_from_args(strategy) + self.strategy = strategy_fn() self.experience_batch_size = experience_batch_size self.kl_coef = kl_coef - self.generate_kwargs = generate_kwargs - actor, critic, reward_model, initial_model = None, None, None, None + # init models + with self.strategy.model_init_context(): + actor, critic, reward_model, initial_model = model_fn() + self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor) + if eval_performance: + actor_numel = get_model_numel(actor) + critic_numel = get_model_numel(critic) + initial_model_numel = get_model_numel(initial_model) + reward_model_numel = get_model_numel(reward_model) + evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, + reward_model_numel) + callbacks = callbacks + [evaluator] + + actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model) self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef) self.callbacks = callbacks - self.eval_performance = eval_performance self._model_visit_lock = Lock() - self._initial_model_initialized = False - self._reward_model_initialized = False - self._actor_initialized = False - self._critic_initialized = False + + self._is_fully_initialized = not sync_models_from_trainers self._debug = debug self.target_auto_balance = False @@ -79,28 +92,9 @@ def __init__(self, def _get_ready(self): while not self._fully_initialized(): time.sleep(1.0) - # setup performance evaluator - if self.eval_performance: - actor_numel = get_model_numel(self.experience_maker.actor) - critic_numel = get_model_numel(self.experience_maker.critic) - initial_model_numel = get_model_numel(self.experience_maker.initial_model) - reward_model_numel = get_model_numel(self.experience_maker.reward_model) - evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, - reward_model_numel) - self.callbacks.append(evaluator) - - self.generate_kwargs = _set_default_generate_kwargs(self.generate_kwargs, self.experience_maker.actor) def _fully_initialized(self): - if not self._initial_model_initialized: - return False - if not self._reward_model_initialized: - return False - if not self._actor_initialized: - return False - if not self._critic_initialized: - return False - return True + return self._is_fully_initialized def update_target_trainer_list(self, detached_trainer_name_list): self.target_trainer_list = [] @@ -168,6 +162,7 @@ def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None self._send_experience(experience=experience) self._on_finish() + # TODO(ver217): remove this function @ray.method(concurrency_group="model_io") def initialize_experience_maker(self, actor_model: str = None, @@ -238,40 +233,18 @@ def initialize_experience_maker(self, self._critic_initialized = True self._reward_model_initialized = True - def initialize_experience_maker_local(self, - initial_model_func=None, - reward_model_func=None, - actor_func=None, - critic_func=None): - ''' - Use function call to construct the model here, because some strategy requieres env_info - The model initialized here will be IGNORED in initialize_experience_maker. - initial_model and reward_model can have their own strategy rather than self.strategy. For example, Quantization. - ''' - - if actor_func is not None: - self.experience_maker.actor = actor_func() - self._actor_initialized = True - if critic_func is not None: - self.experience_maker.critic = critic_func() - self._critic_initialized = True - if initial_model_func is not None: - self.experience_maker.initial_model = initial_model_func() - self._initial_model_initialized = True - if reward_model_func is not None: - self.experience_maker.reward_model = reward_model_func() - self._reward_model_initialized = True - @ray.method(concurrency_group="model_io") def update_experience_maker(self, new_actor_state_dict: Dict[str, Any] = None, new_critic_state_dict: Dict[str, Any] = None, + fully_update: bool = False, chunk_start: bool = None, chunk_end: bool = None): ''' called by trainer chunk_start: Set True at the first call. Before sending state_dict calls chunk_end: Set True at the last call. After sending state_dict calls. + fully_update: Set True if you want to sync models when initializing TODO: load_state_dict integrate with model-sharding strategy ''' @@ -295,6 +268,8 @@ def update_experience_maker(self, current, peak = tracemalloc.get_traced_memory() print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") tracemalloc.stop() + if fully_update: + self._is_fully_initialized = True def _on_make_experience_start(self) -> None: for callback in self.callbacks: From b74a4fc92f0982a1a9218c130db137767d9ac3ad Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 21 Apr 2023 17:36:52 +0800 Subject: [PATCH 2/5] [chat] refactor model init --- .../Chat/coati/ray/example/1mmt_dummy.py | 8 -- .../coati/ray/src/detached_trainer_base.py | 13 ++-- .../coati/ray/src/detached_trainer_ppo.py | 78 ++++++++----------- .../coati/ray/src/experience_maker_holder.py | 2 +- 4 files changed, 39 insertions(+), 62 deletions(-) diff --git a/applications/Chat/coati/ray/example/1mmt_dummy.py b/applications/Chat/coati/ray/example/1mmt_dummy.py index 64071cd73a6d..414937abb1be 100644 --- a/applications/Chat/coati/ray/example/1mmt_dummy.py +++ b/applications/Chat/coati/ray/example/1mmt_dummy.py @@ -95,15 +95,7 @@ def main(args): lora_rank=args.lora_rank, train_batch_size=args.train_batch_size, buffer_limit=16, - experience_batch_size=args.experience_batch_size, max_epochs=args.max_epochs, - # kwargs: - max_length=512, - do_sample=True, - temperature=1.0, - top_k=50, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, eval_performance=True, debug=args.debug, ) for i, env_info_trainer in enumerate(env_info_trainers) diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py index 3558f58017a6..eb50941567f7 100644 --- a/applications/Chat/coati/ray/src/detached_trainer_base.py +++ b/applications/Chat/coati/ray/src/detached_trainer_base.py @@ -21,7 +21,6 @@ class DetachedTrainer(ABC): Args: detached_strategy (DetachedStrategy): the strategy to use for training detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training - experience_batch_size (int, defaults to 8): the batch size to use for experience generation max_epochs (int, defaults to 1): the number of epochs of training process data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader callbacks (List[Callback], defaults to []): the callbacks to call during training process @@ -34,21 +33,17 @@ def __init__(self, train_batch_size: int = 8, buffer_limit: int = 0, buffer_cpu_offload: bool = True, - experience_batch_size: int = 8, max_epochs: int = 1, dataloader_pin_memory: bool = True, callbacks: List[Callback] = [], - debug: bool = False, - **generate_kwargs) -> None: + debug: bool = False) -> None: super().__init__() self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit, cpu_offload=buffer_cpu_offload) - self.experience_batch_size = experience_batch_size self.max_epochs = max_epochs self.dataloader_pin_memory = dataloader_pin_memory self.callbacks = callbacks - self.generate_kwargs = generate_kwargs self.target_holder_name_list = experience_maker_holder_name_list self.target_holder_list = [] @@ -61,9 +56,13 @@ def update_target_holder_list(self, experience_maker_holder_name_list): self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"])) @abstractmethod - def _update_remote_makers(self): + def _update_remote_makers(self, fully_update: bool = False, **kwargs): pass + @ray.method(concurrency_group="model_io") + def sync_models_to_remote_makers(self, **kwargs): + self._update_remote_makers(fully_update=True, **kwargs) + @abstractmethod def training_step(self, experience: Experience) -> Dict[str, Any]: pass diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py index 2850f1cf1d37..19638efe2547 100644 --- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py @@ -54,28 +54,27 @@ class DetachedPPOTrainer(DetachedTrainer): ''' def __init__( - self, - experience_maker_holder_name_list: List[str], - strategy: str, - model: str, - pretrained: str = None, - lora_rank: int = 0, - cr_model: str = None, # if not None, use below cr settings for critic - cr_pretrained: str = None, - cr_lora_rank: int = 0, - env_info: Dict[str, str] = None, - train_batch_size: int = 8, - buffer_limit: int = 0, - buffer_cpu_offload: bool = True, - eps_clip: float = 0.2, - value_clip: float = 0.4, - experience_batch_size: int = 8, - max_epochs: int = 10, - dataloader_pin_memory: bool = True, - callbacks: List[Callback] = [], - eval_performance: bool = False, - debug: bool = False, - **generate_kwargs) -> None: + self, + experience_maker_holder_name_list: List[str], + strategy: str, + model: str, + pretrained: str = None, + lora_rank: int = 0, + cr_model: str = None, # if not None, use below cr settings for critic + cr_pretrained: str = None, + cr_lora_rank: int = 0, + env_info: Dict[str, str] = None, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + value_clip: float = 0.4, + max_epochs: int = 10, + dataloader_pin_memory: bool = True, + callbacks: List[Callback] = [], + eval_performance: bool = False, + debug: bool = False, + ) -> None: # set environment variables if env_info: set_dist_env(env_info=env_info) @@ -112,7 +111,6 @@ def __init__( self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) # configure trainer - generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor) self.actor_loss_fn = PolicyLoss(eps_clip) self.critic_loss_fn = ValueLoss(value_clip) @@ -120,12 +118,10 @@ def __init__( train_batch_size=train_batch_size, buffer_limit=buffer_limit, buffer_cpu_offload=buffer_cpu_offload, - experience_batch_size=experience_batch_size, max_epochs=max_epochs, dataloader_pin_memory=dataloader_pin_memory, callbacks=callbacks, - debug=debug, - **generate_kwargs) + debug=debug) # for remote maker initialization self._model_str = model @@ -135,7 +131,7 @@ def __init__( @ray.method(concurrency_group="model_io") @torch.no_grad() - def _update_remote_makers(self, **config): + def _update_remote_makers(self, fully_update: bool = False, **config): # TODO: balance duties if is_rank_0(): self.update_target_holder_list(self.target_holder_name_list) @@ -143,31 +139,34 @@ def _update_remote_makers(self, **config): if is_rank_0(): # mark start for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_start=True) + target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update) # sending loop for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config): if is_rank_0(): for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard) + target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard, + fully_update=fully_update) if is_rank_0(): # mark end for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_end=True) + target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update) # critic if is_rank_0(): # mark start for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_start=True) + target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update) # sending loop for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config): if is_rank_0(): for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard) + target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard, + fully_update=fully_update) if is_rank_0(): # mark end for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_end=True) + target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update) + # TODO(ver217): remove this function @ray.method(concurrency_group="model_io") def initialize_remote_makers(self, **config): # TODO: balance duties @@ -273,16 +272,3 @@ def _get_model_state_dict_shard(self, model: torch.nn.Module, **config): pass for state_dict in self.strategy.get_model_state_dict_shard(model, **config): yield state_dict_to(state_dict) - - -def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: - origin_model = strategy._unwrap_actor(actor) - new_kwargs = {**generate_kwargs} - # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation - - if 'update_model_kwargs_fn' not in generate_kwargs: - new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn - - return new_kwargs diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py index b4d234c9df0c..0ee50704cfee 100644 --- a/applications/Chat/coati/ray/src/experience_maker_holder.py +++ b/applications/Chat/coati/ray/src/experience_maker_holder.py @@ -36,7 +36,7 @@ class ExperienceMakerHolder: strategy: experience_batch_size: batch size of generated experience kl_coef: the coefficient of kl divergence loss - sync_models_from_trainers: whether to sync models from trainers. If True, you must call update_experience_maker() in trainers to sync models. + sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models. ''' def __init__( From eca708b73033500cdd3ee961a97d6dbdb9eee149 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 21 Apr 2023 17:44:36 +0800 Subject: [PATCH 3/5] [chat] refactor trainer args --- .../Chat/coati/ray/example/1mmt_dummy.py | 11 +++--- .../coati/ray/src/detached_trainer_ppo.py | 34 ++++--------------- .../coati/trainer/strategies/colossalai.py | 7 ++-- 3 files changed, 16 insertions(+), 36 deletions(-) diff --git a/applications/Chat/coati/ray/example/1mmt_dummy.py b/applications/Chat/coati/ray/example/1mmt_dummy.py index 414937abb1be..9ec31a7db9a4 100644 --- a/applications/Chat/coati/ray/example/1mmt_dummy.py +++ b/applications/Chat/coati/ray/example/1mmt_dummy.py @@ -84,15 +84,18 @@ def main(args): tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer.pad_token = tokenizer.eos_token + def trainer_model_fn(): + actor = get_actor_from_args(args.model, args.pretrain).half().cuda() + critic = get_critic_from_args(args.model, args.pretrain).half().cuda() + return actor, critic + # configure Trainer trainer_refs = [ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( experience_maker_holder_name_list=["maker1"], - strategy=args.trainer_strategy, - model=args.model, + strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), + model_fn=trainer_model_fn, env_info=env_info_trainer, - pretrained=args.pretrain, - lora_rank=args.lora_rank, train_batch_size=args.train_batch_size, buffer_limit=16, max_epochs=args.max_epochs, diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py index 19638efe2547..c70b749c67e5 100644 --- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py @@ -1,10 +1,9 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import ray import torch from coati.experience_maker import Experience, NaiveExperienceMaker from coati.models.base import Actor, Critic -from coati.models.generation_utils import update_model_kwargs_fn from coati.models.loss import PolicyLoss, ValueLoss from coati.trainer.callbacks import Callback from coati.trainer.callbacks.performance_evaluator import TrainerPerformaceEvaluator @@ -56,13 +55,8 @@ class DetachedPPOTrainer(DetachedTrainer): def __init__( self, experience_maker_holder_name_list: List[str], - strategy: str, - model: str, - pretrained: str = None, - lora_rank: int = 0, - cr_model: str = None, # if not None, use below cr settings for critic - cr_pretrained: str = None, - cr_lora_rank: int = 0, + strategy_fn: Callable[[], Strategy], + model_fn: Callable[[], Tuple[Actor, Critic]], env_info: Dict[str, str] = None, train_batch_size: int = 8, buffer_limit: int = 0, @@ -79,16 +73,10 @@ def __init__( if env_info: set_dist_env(env_info=env_info) # configure strategy - self.strategy = get_strategy_from_args(strategy) + self.strategy = strategy_fn() # configure models, loss and optimizers - if cr_model is None: - cr_model = model - cr_pretrained = pretrained - cr_lora_rank = lora_rank - with self.strategy.model_init_context(): - self.actor = get_actor_from_args(model, pretrained, lora_rank) - self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank) + self.actor, self.critic = model_fn() if eval_performance: actor_numel = get_model_numel(self.actor) @@ -96,11 +84,7 @@ def __init__( evaluator = TrainerPerformaceEvaluator(actor_numel, critic_numel) callbacks = callbacks + [evaluator] - if strategy != 'colossalai_gemini': - self.actor.to(torch.cuda.current_device()) # .to(torch.float16) - self.critic.to(torch.cuda.current_device()) # .to(torch.float16) - - if strategy.startswith('colossalai'): + if isinstance(self.strategy, ColossalAIStrategy): self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7) self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7) else: @@ -123,12 +107,6 @@ def __init__( callbacks=callbacks, debug=debug) - # for remote maker initialization - self._model_str = model - self._cr_model_str = cr_model - self._pretrained = pretrained - self._cr_pretrained = cr_pretrained - @ray.method(concurrency_group="model_io") @torch.no_grad() def _update_remote_makers(self, fully_update: bool = False, **config): diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 5a6021c5013f..b809c010247b 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -5,7 +5,7 @@ import torch.distributed as dist import torch.nn as nn import torch.optim as optim -from coati.models.base import LM, Actor, RewardModel, Critic +from coati.models.base import LM, Actor, Critic, RewardModel from coati.models.lora import LoraLinear from torch.optim import Optimizer from transformers.modeling_utils import PreTrainedModel @@ -139,7 +139,7 @@ def setup_model(self, model: nn.Module) -> nn.Module: model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config) if self.stage != 3 and self.precision == 'fp16': - model = model.half() + model = model.half().cuda() return model def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer: @@ -163,7 +163,6 @@ def _unwrap_actor(actor: Actor) -> nn.Module: def _unwrap_critic(critic: Critic) -> nn.Module: return Strategy._unwrap_critic(critic) - def _unwrap_model(self, model: Union[nn.Module, ZeroDDP]) -> nn.Module: return super()._unwrap_model(model) @@ -220,4 +219,4 @@ def get_model_state_dict_shard(self, model: nn.Module, **config): if isinstance(module, LoraLinear): module.merge_weights = True module.eval() - yield from model.state_dict_shard(max_shard_size=1024) \ No newline at end of file + yield from model.state_dict_shard(max_shard_size=1024) From d5ceab459372ee6a3bf856607159ae386fcfc5e4 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 21 Apr 2023 18:24:47 +0800 Subject: [PATCH 4/5] [chat] refactor model init --- .../Chat/coati/ray/example/1mmt_dummy.py | 13 +++- .../coati/ray/src/detached_trainer_base.py | 1 - .../coati/ray/src/detached_trainer_ppo.py | 65 ++++------------ .../coati/ray/src/experience_maker_holder.py | 74 +------------------ 4 files changed, 24 insertions(+), 129 deletions(-) diff --git a/applications/Chat/coati/ray/example/1mmt_dummy.py b/applications/Chat/coati/ray/example/1mmt_dummy.py index 9ec31a7db9a4..fdb742406b26 100644 --- a/applications/Chat/coati/ray/example/1mmt_dummy.py +++ b/applications/Chat/coati/ray/example/1mmt_dummy.py @@ -119,7 +119,9 @@ def model_fn(): env_info=env_info_maker, experience_batch_size=args.experience_batch_size, kl_coef=0.1, - # kwargs: + debug=args.debug, + # sync_models_from_trainers=True, + # generation kwargs: max_length=512, do_sample=True, temperature=1.0, @@ -128,19 +130,22 @@ def model_fn(): eos_token_id=tokenizer.eos_token_id, eval_performance=True, use_cache=True, - debug=args.debug, ) # configure sampler random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400)) def tokenize_fn(texts): - # print(texts) input_ids = torch.stack(texts).cuda() - # print(input_ids.shape) attn_mask = torch.ones_like(input_ids) return {'input_ids': input_ids, 'attention_mask': attn_mask} + # uncomment this function if sync_models_from_trainers is True + # ray.get([ + # trainer_ref.sync_models_to_remote_makers.remote() + # for trainer_ref in trainer_refs + # ]) + wait_tasks = [] for trainer_ref in trainer_refs: diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py index eb50941567f7..86b60582a614 100644 --- a/applications/Chat/coati/ray/src/detached_trainer_base.py +++ b/applications/Chat/coati/ray/src/detached_trainer_base.py @@ -59,7 +59,6 @@ def update_target_holder_list(self, experience_maker_holder_name_list): def _update_remote_makers(self, fully_update: bool = False, **kwargs): pass - @ray.method(concurrency_group="model_io") def sync_models_to_remote_makers(self, **kwargs): self._update_remote_makers(fully_update=True, **kwargs) diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py index c70b749c67e5..f500bd5e4824 100644 --- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py @@ -115,71 +115,32 @@ def _update_remote_makers(self, fully_update: bool = False, **config): self.update_target_holder_list(self.target_holder_name_list) # actor: if is_rank_0(): - # mark start + # mark start, ensure order + tasks = [] for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update) + tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update)) + ray.get(tasks) # sending loop + tasks = [] for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config): if is_rank_0(): for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard, - fully_update=fully_update) - if is_rank_0(): - # mark end - for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update) - # critic - if is_rank_0(): - # mark start - for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update) - # sending loop + tasks.append( + target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard, + fully_update=fully_update)) + # sending loop for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config): if is_rank_0(): for target_holder in self.target_holder_list: - target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard, - fully_update=fully_update) + tasks.append( + target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard, + fully_update=fully_update)) + ray.get(tasks) if is_rank_0(): # mark end for target_holder in self.target_holder_list: target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update) - # TODO(ver217): remove this function - @ray.method(concurrency_group="model_io") - def initialize_remote_makers(self, **config): - # TODO: balance duties - if is_rank_0(): - self.update_target_holder_list(self.target_holder_name_list) - with torch.no_grad(): - # actor / initial_model: - # mark start - for target_holder in self.target_holder_list: - target_holder.initialize_experience_maker.remote(actor_model=self._model_str, - actor_pretrained=self._pretrained, - chunk_start=True) - # sending loop - for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor), - **config): - for target_holder in self.target_holder_list: - target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard) - # mark end - for target_holder in self.target_holder_list: - target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True) - # critic / reward_model: - # mark start - for target_holder in self.target_holder_list: - target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, - critic_pretrained=self._cr_pretrained, - chunk_start=True) - # sending loop - for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), - **config): - for target_holder in self.target_holder_list: - target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard) - # mark end - for target_holder in self.target_holder_list: - target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True) - @ray.method(concurrency_group="compute") def training_step(self, experience: Experience) -> Dict[str, float]: self.actor.train() diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py index 0ee50704cfee..d9208d61b11c 100644 --- a/applications/Chat/coati/ray/src/experience_maker_holder.py +++ b/applications/Chat/coati/ray/src/experience_maker_holder.py @@ -86,7 +86,7 @@ def __init__( self._debug = debug self.target_auto_balance = False - if self._debug: + if self._debug and not self._is_fully_initialized: print('[maker] Waiting for INIT') def _get_ready(self): @@ -162,77 +162,6 @@ def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None self._send_experience(experience=experience) self._on_finish() - # TODO(ver217): remove this function - @ray.method(concurrency_group="model_io") - def initialize_experience_maker(self, - actor_model: str = None, - actor_pretrained: str = None, - actor_state_dict: Dict[str, Any] = None, - critic_model: str = None, - critic_pretrained: str = None, - critic_state_dict: Dict[str, Any] = None, - chunk_start: bool = None, - chunk_end: bool = None): - ''' - called by trainer - chunk_start: Set True at the first call. Before sending state_dict calls - chunk_end: Set True at the last call. After sending state_dict calls. - - TODO: load_state_dict integrate with model-sharding strategy - ''' - if self._fully_initialized(): - return - - if chunk_start: - if self._debug: - print('[maker] INIT') - with torch.no_grad(): - # (csric) any better way to get model structure? - with self.strategy.model_init_context(): - if not self._actor_initialized and actor_model is not None: - self.experience_maker.actor = get_actor_from_args(actor_model, - actor_pretrained).half().requires_grad_(False) - if not self._critic_initialized and critic_model is not None: - self.experience_maker.critic = get_critic_from_args( - critic_model, critic_pretrained).half().requires_grad_(False) - if not self._initial_model_initialized and actor_model is not None: - self.experience_maker.initial_model = get_actor_from_args( - actor_model, actor_pretrained).half().requires_grad_(False) - if not self._reward_model_initialized and critic_model is not None: - self.experience_maker.reward_model = get_reward_model_from_args( - critic_model, critic_pretrained).half().requires_grad_(False) - - with torch.no_grad(): - if not self._actor_initialized and actor_state_dict is not None: - self.experience_maker.actor.model.load_state_dict(actor_state_dict, strict=False) - if not self._critic_initialized and critic_state_dict is not None: - self.experience_maker.critic.load_state_dict(critic_state_dict, strict=False) - if not self._initial_model_initialized and actor_state_dict is not None: - self.experience_maker.initial_model.model.load_state_dict(actor_state_dict, strict=False) - if not self._reward_model_initialized and critic_state_dict is not None: - self.experience_maker.reward_model.load_state_dict(critic_state_dict, strict=False) - - if chunk_end: - with torch.no_grad(): - if actor_model is not None: - if not self._actor_initialized: - self.experience_maker.actor = self.strategy.prepare( - self.experience_maker.actor.to(torch.cuda.current_device())) - if not self._initial_model_initialized: - self.experience_maker.initial_model = self.strategy.prepare( - self.experience_maker.initial_model.to(torch.cuda.current_device())) - self._actor_initialized = True - self._initial_model_initialized = True - if critic_model is not None: - if not self._critic_initialized: - self.experience_maker.critic = self.strategy.prepare( - self.experience_maker.critic.to(torch.cuda.current_device())) - if not self._reward_model_initialized: - self.experience_maker.reward_model = self.strategy.prepare( - self.experience_maker.reward_model.to(torch.cuda.current_device())) - self._critic_initialized = True - self._reward_model_initialized = True - @ray.method(concurrency_group="model_io") def update_experience_maker(self, new_actor_state_dict: Dict[str, Any] = None, @@ -262,6 +191,7 @@ def update_experience_maker(self, if new_critic_state_dict is not None: self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) + # the lock must be released after both actor and critic being updated if chunk_end: self._model_visit_lock.release() if _watch_memory: From 9fa0de1212b25e4f1c473ef3af52e5b4caea18bb Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 21 Apr 2023 19:24:22 +0800 Subject: [PATCH 5/5] [chat] refactor trainer --- applications/Chat/coati/ray/src/detached_trainer_ppo.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py index f500bd5e4824..056942b83360 100644 --- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py @@ -113,8 +113,6 @@ def _update_remote_makers(self, fully_update: bool = False, **config): # TODO: balance duties if is_rank_0(): self.update_target_holder_list(self.target_holder_name_list) - # actor: - if is_rank_0(): # mark start, ensure order tasks = [] for target_holder in self.target_holder_list: