diff --git a/applications/Chat/coati/ray/example/1mmt_dummy.py b/applications/Chat/coati/ray/example/1mmt_dummy.py
index 68b666663b12..fdb742406b26 100644
--- a/applications/Chat/coati/ray/example/1mmt_dummy.py
+++ b/applications/Chat/coati/ray/example/1mmt_dummy.py
@@ -1,15 +1,18 @@
 import argparse
 import os
 import socket
-from copy import deepcopy
 from functools import partial

 import ray
 import torch
-from coati.models.base import RewardModel
 from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
 from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
-from coati.ray.src.utils import get_actor_from_args, get_critic_from_args, get_reward_model_from_args
+from coati.ray.src.utils import (
+    get_actor_from_args,
+    get_critic_from_args,
+    get_reward_model_from_args,
+    get_strategy_from_args,
+)
 from transformers import AutoTokenizer, BloomTokenizerFast
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -81,39 +84,44 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     tokenizer.pad_token = tokenizer.eos_token

+    def trainer_model_fn():
+        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
+        return actor, critic
+
     # configure Trainer
     trainer_refs = [
         DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
             experience_maker_holder_name_list=["maker1"],
-            strategy=args.trainer_strategy,
-            model=args.model,
+            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+            model_fn=trainer_model_fn,
             env_info=env_info_trainer,
-            pretrained=args.pretrain,
-            lora_rank=args.lora_rank,
             train_batch_size=args.train_batch_size,
             buffer_limit=16,
-            experience_batch_size=args.experience_batch_size,
             max_epochs=args.max_epochs,
-            # kwargs:
-            max_length=512,
-            do_sample=True,
-            temperature=1.0,
-            top_k=50,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
             eval_performance=True,
             debug=args.debug,
         ) for i, env_info_trainer in enumerate(env_info_trainers)
     ]

+    def model_fn():
+        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        critic = get_critic_from_args(args.model, args.pretrain).half().cuda()
+        reward_model = get_reward_model_from_args(args.model, args.pretrain).half().cuda()
+        initial_model = get_actor_from_args(args.model, args.pretrain).half().cuda()
+        return actor, critic, reward_model, initial_model
+
     # configure Experience Maker
     experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
         detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
-        strategy=args.maker_strategy,
+        strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+        model_fn=model_fn,
         env_info=env_info_maker,
         experience_batch_size=args.experience_batch_size,
         kl_coef=0.1,
-        # kwargs:
+        debug=args.debug,
+        # sync_models_from_trainers=True,
+        # generation kwargs:
         max_length=512,
         do_sample=True,
         temperature=1.0,
@@ -122,32 +130,22 @@ def main(args):
         eos_token_id=tokenizer.eos_token_id,
         eval_performance=True,
         use_cache=True,
-        debug=args.debug,
     )

-    def init_inference_model(fn, model_name, pretrained):
-        model = fn(model_name, pretrained)
-        return model.half().cuda()
-
-    # init maker locally
-    ray.get(
-        experience_holder_ref.initialize_experience_maker_local.remote(
-            initial_model_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain),
-            reward_model_func=partial(init_inference_model, get_reward_model_from_args, args.model, args.pretrain),
-            actor_func=partial(init_inference_model, get_actor_from_args, args.model, args.pretrain),
-            critic_func=partial(init_inference_model, get_critic_from_args, args.model, args.pretrain),
-        ))
-
     # configure sampler
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400))

     def tokenize_fn(texts):
-        # print(texts)
         input_ids = torch.stack(texts).cuda()
-        # print(input_ids.shape)
         attn_mask = torch.ones_like(input_ids)
         return {'input_ids': input_ids, 'attention_mask': attn_mask}

+    # uncomment this function if sync_models_from_trainers is True
+    # ray.get([
+    #     trainer_ref.sync_models_to_remote_makers.remote()
+    #     for trainer_ref in trainer_refs
+    # ])
+
     wait_tasks = []

     for trainer_ref in trainer_refs:
diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py
index 3558f58017a6..86b60582a614 100644
--- a/applications/Chat/coati/ray/src/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/src/detached_trainer_base.py
@@ -21,7 +21,6 @@ class DetachedTrainer(ABC):
     Args:
         detached_strategy (DetachedStrategy): the strategy to use for training
         detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
-        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
         max_epochs (int, defaults to 1): the number of epochs of training process
         data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
@@ -34,21 +33,17 @@ def __init__(self,
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
                  buffer_cpu_offload: bool = True,
-                 experience_batch_size: int = 8,
                  max_epochs: int = 1,
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
-                 debug: bool = False,
-                 **generate_kwargs) -> None:
+                 debug: bool = False) -> None:
         super().__init__()
         self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size,
                                                            limit=buffer_limit,
                                                            cpu_offload=buffer_cpu_offload)
-        self.experience_batch_size = experience_batch_size
         self.max_epochs = max_epochs
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks
-        self.generate_kwargs = generate_kwargs
         self.target_holder_name_list = experience_maker_holder_name_list
         self.target_holder_list = []

@@ -61,9 +56,12 @@ def update_target_holder_list(self, experience_maker_holder_name_list):
             self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

     @abstractmethod
-    def _update_remote_makers(self):
+    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
         pass

+    def sync_models_to_remote_makers(self, **kwargs):
+        self._update_remote_makers(fully_update=True, **kwargs)
+
     @abstractmethod
     def training_step(self, experience: Experience) -> Dict[str, Any]:
         pass
diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
index 2850f1cf1d37..056942b83360 100644
--- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
@@ -1,10 +1,9 @@
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import ray
 import torch
 from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic
-from coati.models.generation_utils import update_model_kwargs_fn
 from coati.models.loss import PolicyLoss, ValueLoss
 from coati.trainer.callbacks import Callback
 from coati.trainer.callbacks.performance_evaluator import TrainerPerformaceEvaluator
@@ -54,42 +53,30 @@ class DetachedPPOTrainer(DetachedTrainer):
     '''

     def __init__(
-            self,
-            experience_maker_holder_name_list: List[str],
-            strategy: str,
-            model: str,
-            pretrained: str = None,
-            lora_rank: int = 0,
-            cr_model: str = None,    # if not None, use below cr settings for critic
-            cr_pretrained: str = None,
-            cr_lora_rank: int = 0,
-            env_info: Dict[str, str] = None,
-            train_batch_size: int = 8,
-            buffer_limit: int = 0,
-            buffer_cpu_offload: bool = True,
-            eps_clip: float = 0.2,
-            value_clip: float = 0.4,
-            experience_batch_size: int = 8,
-            max_epochs: int = 10,
-            dataloader_pin_memory: bool = True,
-            callbacks: List[Callback] = [],
-            eval_performance: bool = False,
-            debug: bool = False,
-            **generate_kwargs) -> None:
+        self,
+        experience_maker_holder_name_list: List[str],
+        strategy_fn: Callable[[], Strategy],
+        model_fn: Callable[[], Tuple[Actor, Critic]],
+        env_info: Dict[str, str] = None,
+        train_batch_size: int = 8,
+        buffer_limit: int = 0,
+        buffer_cpu_offload: bool = True,
+        eps_clip: float = 0.2,
+        value_clip: float = 0.4,
+        max_epochs: int = 10,
+        dataloader_pin_memory: bool = True,
+        callbacks: List[Callback] = [],
+        eval_performance: bool = False,
+        debug: bool = False,
+    ) -> None:
         # set environment variables
         if env_info:
             set_dist_env(env_info=env_info)

         # configure strategy
-        self.strategy = get_strategy_from_args(strategy)
+        self.strategy = strategy_fn()
         # configure models, loss and optimizers
-        if cr_model is None:
-            cr_model = model
-            cr_pretrained = pretrained
-            cr_lora_rank = lora_rank
-        with self.strategy.model_init_context():
-            self.actor = get_actor_from_args(model, pretrained, lora_rank)
-            self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank)
+        self.actor, self.critic = model_fn()

         if eval_performance:
             actor_numel = get_model_numel(self.actor)
@@ -97,11 +84,7 @@ def __init__(
             evaluator = TrainerPerformaceEvaluator(actor_numel, critic_numel)
             callbacks = callbacks + [evaluator]

-        if strategy != 'colossalai_gemini':
-            self.actor.to(torch.cuda.current_device())    # .to(torch.float16)
-            self.critic.to(torch.cuda.current_device())    # .to(torch.float16)
-
-        if strategy.startswith('colossalai'):
+        if isinstance(self.strategy, ColossalAIStrategy):
             self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
             self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
         else:
@@ -112,7 +95,6 @@ def __init__(
             self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))

         # configure trainer
-        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)

@@ -120,88 +102,42 @@ def __init__(
                          train_batch_size=train_batch_size,
                          buffer_limit=buffer_limit,
                          buffer_cpu_offload=buffer_cpu_offload,
-                         experience_batch_size=experience_batch_size,
                          max_epochs=max_epochs,
                          dataloader_pin_memory=dataloader_pin_memory,
                          callbacks=callbacks,
-                         debug=debug,
-                         **generate_kwargs)
-
-        # for remote maker initialization
-        self._model_str = model
-        self._cr_model_str = cr_model
-        self._pretrained = pretrained
-        self._cr_pretrained = cr_pretrained
+                         debug=debug)

     @ray.method(concurrency_group="model_io")
     @torch.no_grad()
-    def _update_remote_makers(self, **config):
+    def _update_remote_makers(self, fully_update: bool = False, **config):
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-        # actor:
-        if is_rank_0():
-            # mark start
+            # mark start, ensure order
+            tasks = []
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_start=True)
+                tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
+            ray.get(tasks)
         # sending loop
+        tasks = []
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
             if is_rank_0():
                 for target_holder in self.target_holder_list:
-                    target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard)
-        if is_rank_0():
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True)
-        # critic
-        if is_rank_0():
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_start=True)
-            # sending loop
+                    tasks.append(
+                        target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
+                                                                     fully_update=fully_update))
+        # sending loop
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
             if is_rank_0():
                 for target_holder in self.target_holder_list:
-                    target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard)
+                    tasks.append(
+                        target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
+                                                                     fully_update=fully_update))
+        ray.get(tasks)
         if is_rank_0():
             # mark end
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True)
-
-    @ray.method(concurrency_group="model_io")
-    def initialize_remote_makers(self, **config):
-        # TODO: balance duties
-        if is_rank_0():
-            self.update_target_holder_list(self.target_holder_name_list)
-        with torch.no_grad():
-            # actor / initial_model:
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(actor_model=self._model_str,
-                                                                  actor_pretrained=self._pretrained,
-                                                                  chunk_start=True)
-            # sending loop
-            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor),
-                                                                     **config):
-                for target_holder in self.target_holder_list:
-                    target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard)
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True)
-            # critic / reward_model:
-            # mark start
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str,
-                                                                  critic_pretrained=self._cr_pretrained,
-                                                                  chunk_start=True)
-            # sending loop
-            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic),
-                                                                     **config):
-                for target_holder in self.target_holder_list:
-                    target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard)
-            # mark end
-            for target_holder in self.target_holder_list:
-                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True)
+                target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)

     @ray.method(concurrency_group="compute")
     def training_step(self, experience: Experience) -> Dict[str, float]:
@@ -273,16 +209,3 @@ def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
             pass
         for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
             yield state_dict_to(state_dict)
-
-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy._unwrap_actor(actor)
-    new_kwargs = {**generate_kwargs}
-    # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-
-    if 'update_model_kwargs_fn' not in generate_kwargs:
-        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
-
-    return new_kwargs
diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py
index 93624c2be921..c85d36bab360 100644
--- a/applications/Chat/coati/ray/src/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/src/experience_maker_holder.py
@@ -3,14 +3,14 @@
 import tracemalloc
 from copy import deepcopy
 from threading import Lock
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import ray
 import torch
 import torch.nn as nn
 from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
 from coati.models.base import Actor, Critic, RewardModel
-from coati.replay_buffer.utils import split_experience_batch, make_experience_batch, BufferItem
+from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
 from coati.trainer.callbacks import Callback
 from coati.trainer.callbacks.performance_evaluator import ExperienceMakerPerformanceEvaluator
 from coati.trainer.strategies import Strategy
@@ -37,73 +37,67 @@ class ExperienceMakerHolder:
         strategy:
         experience_batch_size: batch size of generated experience
         kl_coef: the coefficient of kl divergence loss
+        sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
     '''

-    def __init__(self,
-                 detached_trainer_name_list: List[str],
-                 strategy: str,
-                 env_info: Dict[str, str] = None,
-                 experience_batch_size: int = 8,
-                 kl_coef: float = 0.1,
-                 callbacks: List[Callback] = [],
-                 eval_performance: bool = False,
-                 debug: bool = False,
-                 send_grain_size: int = 4,
-                 **generate_kwargs):
+    def __init__(
+            self,
+            detached_trainer_name_list: List[str],
+            strategy_fn: Callable[[], Strategy],
+            # a function returns (actor, critic, reward_model, initial_model)
+            model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
+            env_info: Dict[str, str] = None,
+            sync_models_from_trainers: bool = False,
+            experience_batch_size: int = 8,
+            send_grain_size: int = 4,
+            kl_coef: float = 0.1,
+            callbacks: List[Callback] = [],
+            eval_performance: bool = False,
+            debug: bool = False,
+            **generate_kwargs):
         # set environment variables
         if env_info:
             set_dist_env(env_info=env_info)
         self.target_trainer_list = []
         for name in detached_trainer_name_list:
             self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
-        self.strategy_str = strategy
-        self.strategy = get_strategy_from_args(strategy)
+        self.strategy = strategy_fn()
         self.experience_batch_size = experience_batch_size
         self.kl_coef = kl_coef
-        self.generate_kwargs = generate_kwargs
-        actor, critic, reward_model, initial_model = None, None, None, None
+        # init models
+        with self.strategy.model_init_context():
+            actor, critic, reward_model, initial_model = model_fn()
+        self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
+        if eval_performance:
+            actor_numel = get_model_numel(actor)
+            critic_numel = get_model_numel(critic)
+            initial_model_numel = get_model_numel(initial_model)
+            reward_model_numel = get_model_numel(reward_model)
+            evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
+                                                            reward_model_numel)
+            callbacks = callbacks + [evaluator]
+
+        actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
         self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
         self.callbacks = callbacks
-        self.eval_performance = eval_performance
         self.send_grain_size = send_grain_size
         self._model_visit_lock = Lock()
-        self._initial_model_initialized = False
-        self._reward_model_initialized = False
-        self._actor_initialized = False
-        self._critic_initialized = False
+
+        self._is_fully_initialized = not sync_models_from_trainers
         self._debug = debug
         self.target_auto_balance = False
-        if self._debug:
+        if self._debug and not self._is_fully_initialized:
             print('[maker] Waiting for INIT')

     def _get_ready(self):
         while not self._fully_initialized():
             time.sleep(1.0)
-        # setup performance evaluator
-        if self.eval_performance:
-            actor_numel = get_model_numel(self.experience_maker.actor)
-            critic_numel = get_model_numel(self.experience_maker.critic)
-            initial_model_numel = get_model_numel(self.experience_maker.initial_model)
-            reward_model_numel = get_model_numel(self.experience_maker.reward_model)
-            evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
-                                                            reward_model_numel)
-            self.callbacks.append(evaluator)
-
-        self.generate_kwargs = _set_default_generate_kwargs(self.generate_kwargs, self.experience_maker.actor)

     def _fully_initialized(self):
-        if not self._initial_model_initialized:
-            return False
-        if not self._reward_model_initialized:
-            return False
-        if not self._actor_initialized:
-            return False
-        if not self._critic_initialized:
-            return False
-        return True
+        return self._is_fully_initialized

     def update_target_trainer_list(self, detached_trainer_name_list):
         self.target_trainer_list = []
@@ -183,110 +177,18 @@ def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None
                 self._send_experience(experience=experience_fragment)
         self._on_finish()

-    @ray.method(concurrency_group="model_io")
-    def initialize_experience_maker(self,
-                                    actor_model: str = None,
-                                    actor_pretrained: str = None,
-                                    actor_state_dict: Dict[str, Any] = None,
-                                    critic_model: str = None,
-                                    critic_pretrained: str = None,
-                                    critic_state_dict: Dict[str, Any] = None,
-                                    chunk_start: bool = None,
-                                    chunk_end: bool = None):
-        '''
-            called by trainer
-            chunk_start: Set True at the first call. Before sending state_dict calls
-            chunk_end: Set True at the last call. After sending state_dict calls.
-
-            TODO: load_state_dict integrate with model-sharding strategy
-        '''
-        if self._fully_initialized():
-            return
-
-        if chunk_start:
-            if self._debug:
-                print('[maker] INIT')
-            with torch.no_grad():
-                # (csric) any better way to get model structure?
-                with self.strategy.model_init_context():
-                    if not self._actor_initialized and actor_model is not None:
-                        self.experience_maker.actor = get_actor_from_args(actor_model,
-                                                                          actor_pretrained).half().requires_grad_(False)
-                    if not self._critic_initialized and critic_model is not None:
-                        self.experience_maker.critic = get_critic_from_args(
-                            critic_model, critic_pretrained).half().requires_grad_(False)
-                    if not self._initial_model_initialized and actor_model is not None:
-                        self.experience_maker.initial_model = get_actor_from_args(
-                            actor_model, actor_pretrained).half().requires_grad_(False)
-                    if not self._reward_model_initialized and critic_model is not None:
-                        self.experience_maker.reward_model = get_reward_model_from_args(
-                            critic_model, critic_pretrained).half().requires_grad_(False)
-
-        with torch.no_grad():
-            if not self._actor_initialized and actor_state_dict is not None:
-                self.experience_maker.actor.model.load_state_dict(actor_state_dict, strict=False)
-            if not self._critic_initialized and critic_state_dict is not None:
-                self.experience_maker.critic.load_state_dict(critic_state_dict, strict=False)
-            if not self._initial_model_initialized and actor_state_dict is not None:
-                self.experience_maker.initial_model.model.load_state_dict(actor_state_dict, strict=False)
-            if not self._reward_model_initialized and critic_state_dict is not None:
-                self.experience_maker.reward_model.load_state_dict(critic_state_dict, strict=False)
-
-        if chunk_end:
-            with torch.no_grad():
-                if actor_model is not None:
-                    if not self._actor_initialized:
-                        self.experience_maker.actor = self.strategy.prepare(
-                            self.experience_maker.actor.to(torch.cuda.current_device()))
-                    if not self._initial_model_initialized:
-                        self.experience_maker.initial_model = self.strategy.prepare(
-                            self.experience_maker.initial_model.to(torch.cuda.current_device()))
-                    self._actor_initialized = True
-                    self._initial_model_initialized = True
-                if critic_model is not None:
-                    if not self._critic_initialized:
-                        self.experience_maker.critic = self.strategy.prepare(
-                            self.experience_maker.critic.to(torch.cuda.current_device()))
-                    if not self._reward_model_initialized:
-                        self.experience_maker.reward_model = self.strategy.prepare(
-                            self.experience_maker.reward_model.to(torch.cuda.current_device()))
-                    self._critic_initialized = True
-                    self._reward_model_initialized = True
-
-    def initialize_experience_maker_local(self,
-                                          initial_model_func=None,
-                                          reward_model_func=None,
-                                          actor_func=None,
-                                          critic_func=None):
-        '''
-            Use function call to construct the model here, because some strategy requieres env_info
-            The model initialized here will be IGNORED in initialize_experience_maker.
-            initial_model and reward_model can have their own strategy rather than self.strategy. For example, Quantization.
-        '''
-
-        if actor_func is not None:
-            self.experience_maker.actor = actor_func()
-            self._actor_initialized = True
-        if critic_func is not None:
-            self.experience_maker.critic = critic_func()
-            self._critic_initialized = True
-        if initial_model_func is not None:
-            self.experience_maker.initial_model = initial_model_func()
-            self._initial_model_initialized = True
-        if reward_model_func is not None:
-            self.experience_maker.reward_model = reward_model_func()
-            self._reward_model_initialized = True
-
     @ray.method(concurrency_group="model_io")
     def update_experience_maker(self,
                                 new_actor_state_dict: Dict[str, Any] = None,
                                 new_critic_state_dict: Dict[str, Any] = None,
+                                fully_update: bool = False,
                                 chunk_start: bool = None,
                                 chunk_end: bool = None):
         '''
             called by trainer
             chunk_start: Set True at the first call. Before sending state_dict calls
             chunk_end: Set True at the last call. After sending state_dict calls.
+            fully_update: Set True if you want to sync models when initializing

             TODO: load_state_dict integrate with model-sharding strategy
         '''
@@ -304,12 +206,15 @@ def update_experience_maker(self,
             if new_critic_state_dict is not None:
                 self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)

+        # the lock must be released after both actor and critic being updated
         if chunk_end:
             self._model_visit_lock.release()
             if _watch_memory:
                 current, peak = tracemalloc.get_traced_memory()
                 print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
                 tracemalloc.stop()
+            if fully_update:
+                self._is_fully_initialized = True

     def _on_make_experience_start(self) -> None:
         for callback in self.callbacks:
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index 5a6021c5013f..b809c010247b 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -5,7 +5,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import LM, Actor, RewardModel, Critic
+from coati.models.base import LM, Actor, Critic, RewardModel
 from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from transformers.modeling_utils import PreTrainedModel
@@ -139,7 +139,7 @@ def setup_model(self, model: nn.Module) -> nn.Module:
         model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)

         if self.stage != 3 and self.precision == 'fp16':
-            model = model.half()
+            model = model.half().cuda()
         return model

     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
@@ -163,7 +163,6 @@ def _unwrap_actor(actor: Actor) -> nn.Module:
     def _unwrap_critic(critic: Critic) -> nn.Module:
         return Strategy._unwrap_critic(critic)

-
     def _unwrap_model(self, model: Union[nn.Module, ZeroDDP]) -> nn.Module:
         return super()._unwrap_model(model)

@@ -220,4 +219,4 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
-        yield from model.state_dict_shard(max_shard_size=1024)
\ No newline at end of file
+        yield from model.state_dict_shard(max_shard_size=1024)