diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 7f6eb73262fa..3d30208c05f4 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -61,7 +61,13 @@ def T(w):
         if self.merge_weights and self.merged:
             # Make sure that the weights are not merged
             if self.r > 0:
-                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                    # csric: temporary fix
+                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                    self.reset_parameters()
+                else:
+                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
             self.merged = False
 
     def eval(self):
diff --git a/applications/Chat/coati/ray/example/1m1t_quantize.py b/applications/Chat/coati/ray/example/1m1t_quantize.py
new file mode 100644
index 000000000000..12a60fd65d8b
--- /dev/null
+++ b/applications/Chat/coati/ray/example/1m1t_quantize.py
@@ -0,0 +1,147 @@
+import argparse
+import pandas as pd
+import torch
+import ray
+import os
+import socket
+
+from coati.ray.src.experience_maker_holder import ExperienceMakerHolder
+from coati.ray.src.detached_trainer_ppo import DetachedPPOTrainer
+
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+
+def get_free_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(('', 0))
+        return s.getsockname()[1]
+
+
+def get_local_ip():
+    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+        s.connect(('8.8.8.8', 80))
+        return s.getsockname()[0]
+
+
+def main(args):
+    master_addr = str(get_local_ip())
+    # trainer_env_info
+    trainer_port = str(get_free_port())
+    env_info_trainer = {'local_rank' : '0',
+                        'rank' : '0',
+                        'world_size' : '1',
+                        'master_port' : trainer_port,
+                        'master_addr' : master_addr}
+
+    # maker_env_info
+    maker_port = str(get_free_port())
+    env_info_maker = {'local_rank' : '0',
+                      'rank' : '0',
+                      'world_size' : '1',
+                      'master_port' : maker_port,
+                      'master_addr' : master_addr}
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'bloom':
+        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+
+    # configure Trainer
+    trainer_ref = DetachedPPOTrainer.options(name="trainer1", num_gpus=1, max_concurrency=2).remote(
+        experience_maker_holder_name_list=["maker1"],
+        strategy=args.trainer_strategy,
+        model=args.model,
+        env_info = env_info_trainer,
+        pretrained=args.pretrain,
+        lora_rank=args.lora_rank,
+        train_batch_size=args.train_batch_size,
+        buffer_limit=16,
+        experience_batch_size=args.experience_batch_size,
+        max_epochs=args.max_epochs,
+        #kwargs:
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        debug=args.debug,
+    )
+
+    # configure Experience Maker
+    experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
+        detached_trainer_name_list=["trainer1"],
+        strategy=args.maker_strategy,
+        env_info = env_info_maker,
+        experience_batch_size=args.experience_batch_size,
+        kl_coef=0.1,
+        #kwargs:
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        debug=args.debug,
+    )
+
+    # a 'jump wire' to set quantized initial_model and reward_model
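+    # A possible wiring, sketched here only as a comment: the maker exposes
+    # initialize_experience_maker_local(), whose initial_model_func / reward_model_func
+    # callbacks may build quantized models locally. The two build_quantized_* constructors
+    # below are hypothetical placeholders, not functions provided by coati.
+    #
+    # ray.get(experience_holder_ref.initialize_experience_maker_local.remote(
+    #     initial_model_func=build_quantized_initial_model,
+    #     reward_model_func=build_quantized_reward_model))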
+
+    # the trainer sends its actor and critic to the experience maker holders.
+    # ray.get(trainer_ref.initialize_remote_makers.remote())
+
+    # configure sampler
+    dataset = pd.read_csv(args.prompt_path)['prompt']
+
+    def tokenize_fn(texts):
+        # MUST pad to max length so that inputs on all ranks have the same length;
+        # otherwise generation may hang under gemini, since ranks would run different numbers of generation steps
+        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+        return {k: v.cuda() for k, v in batch.items()}
+
+    trainer_done_ref = trainer_ref.fit.remote(num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps)
+    num_exp_per_maker = args.num_episodes * args.max_timesteps // args.update_timesteps * args.max_epochs + 3    # +3 for fault tolerance
+    maker_done_ref = experience_holder_ref.workingloop.remote(dataset, tokenize_fn, times=num_exp_per_maker)
+
+    ray.get([trainer_done_ref, maker_done_ref])
+
+    # save model checkpoint after fitting
+    trainer_ref.strategy_save_actor.remote(args.save_path, only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    if args.need_optim_ckpt:
+        trainer_ref.strategy_save_actor_optim.remote('actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                                                     only_rank0=False)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('prompt_path')
+    parser.add_argument('--trainer_strategy',
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        default='naive')
+    parser.add_argument('--maker_strategy',
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        default='naive')
+    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta'])
+    parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
+    parser.add_argument('--num_episodes', type=int, default=10)
+    parser.add_argument('--max_timesteps', type=int, default=10)
+    parser.add_argument('--update_timesteps', type=int, default=10)
+    parser.add_argument('--max_epochs', type=int, default=5)
+    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--experience_batch_size', type=int, default=8)
+    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+
+    parser.add_argument('--debug', action='store_true')
+    args = parser.parse_args()
+    ray.init(namespace=os.environ["RAY_NAMESPACE"])
+    main(args)
diff --git a/applications/Chat/coati/ray/src/detached_trainer_base.py b/applications/Chat/coati/ray/src/detached_trainer_base.py
index f1ed1ec71499..f5e52e8a3b3a 100644
--- a/applications/Chat/coati/ray/src/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/src/detached_trainer_base.py
@@ -24,6 +24,7 @@ class DetachedTrainer(ABC):
         data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
+
     '''
 
     def __init__(self,
@@ -45,6 +46,11 @@ def __init__(self,
         self.generate_kwargs = generate_kwargs
         self.target_holder_name_list = experience_maker_holder_name_list
         self.target_holder_list = []
+
+        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            self._debug = True
+        else:
+            self._debug = False
 
     def update_target_holder_list(self, experience_maker_holder_name_list):
         self.target_holder_name_list = experience_maker_holder_name_list
@@ -63,13 +69,13 @@ def training_step(self, experience: Experience) -> Dict[str, Any]:
     def _learn(self):
         pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
         for _ in pbar:
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            if self._debug:
                 print("[trainer] sampling exp")
             experience = self._buffer_sample()
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            if self._debug:
                 print("[trainer] training step")
             metrics = self.training_step(experience)
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            if self._debug:
                 print("[trainer] step over")
             pbar.set_postfix(metrics)
 
@@ -88,15 +94,14 @@ def fit(self, num_episodes: int = 50000, max_timesteps: int = 500, update_timest
     @ray.method(concurrency_group="buffer_length")
     def buffer_get_length(self):
         # called by ExperienceMakerHolder
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+        if self._debug:
             print("[trainer] telling length")
         return self.detached_replay_buffer.get_length()
 
     @ray.method(concurrency_group="buffer_append")
     def buffer_append(self, experience: Experience):
         # called by ExperienceMakerHolder
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
-            # print(f"[trainer] receiving exp. Current buffer length: {self.detached_replay_buffer.get_length()}")
+        if self._debug:
             print(f"[trainer] receiving exp.")
         self.detached_replay_buffer.append(experience)
 
diff --git a/applications/Chat/coati/ray/src/detached_trainer_ppo.py b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
index 838e82d07f4a..071f0ddab2b9 100644
--- a/applications/Chat/coati/ray/src/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/src/detached_trainer_ppo.py
@@ -14,11 +14,13 @@
 
 import ray
 
-from .utils import is_rank_0, get_cuda_actor_critic_from_args, get_strategy_from_args, set_dist_env
+from .utils import is_rank_0, get_actor_from_args, get_critic_from_args, get_strategy_from_args, set_dist_env, \
+    state_dict_to
+
 from .detached_trainer_base import DetachedTrainer
 
 
-@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append":1, "buffer_sample":1,"model_io": 1, "compute": 1})
+@ray.remote(concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1})
 class DetachedPPOTrainer(DetachedTrainer):
     '''
         Detached Trainer for PPO algorithm
@@ -44,9 +46,12 @@ def __init__(self,
                  experience_maker_holder_name_list: List[str],
                  strategy: str,
                  model: str,
-                 env_info: Dict[str, str] = None,
                  pretrained: str = None,
                  lora_rank: int = 0,
+                 cr_model: str = None,    # if not None, use the cr_* settings below for the critic
+                 cr_pretrained: str = None,
+                 cr_lora_rank: int = 0,
+                 env_info: Dict[str, str] = None,
                  train_batch_size: int = 8,
                  buffer_limit: int = 0,
                  buffer_cpu_offload: bool = True,
@@ -63,24 +68,32 @@ def __init__(self,
         # configure strategy
         self.strategy = get_strategy_from_args(strategy)
         # configure models, loss and optimizers
+        if cr_model is None:
+            cr_model = model
+            cr_pretrained = pretrained
+            cr_lora_rank = lora_rank
+
         with self.strategy.model_init_context():
-            self.actor, self.critic = get_cuda_actor_critic_from_args(model, pretrained, lora_rank)
+            self.actor = get_actor_from_args(model, pretrained, lora_rank)
+            self.critic = get_critic_from_args(cr_model, cr_pretrained, cr_lora_rank)
 
         if strategy != 'colossalai_gemini':
-            self.actor.to(torch.float16).to(torch.cuda.current_device())
-            self.critic.to(torch.float16).to(torch.cuda.current_device())
+            self.actor.to(torch.cuda.current_device()) #.to(torch.float16)
+            self.critic.to(torch.cuda.current_device()) #.to(torch.float16)
+
         if strategy.startswith('colossalai'):
-            self.actor_optim = HybridAdam(self.actor.parameters(), lr=5e-6)
-            self.critic_optim = HybridAdam(self.critic.parameters(), lr=5e-6)
+            self.actor_optim = HybridAdam(self.actor.parameters(), lr=1e-7)
+            self.critic_optim = HybridAdam(self.critic.parameters(), lr=1e-7)
         else:
-            self.actor_optim = Adam(self.actor.parameters(), lr=5e-6)
-            self.critic_optim = Adam(self.critic.parameters(), lr=5e-6)
+            self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
+            self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
 
         (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
             self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
 
-        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
+        # configure trainer
+        generate_kwargs = _set_default_generate_kwargs(self.strategy, generate_kwargs, self.actor)
 
         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)
@@ -94,25 +107,69 @@ def __init__(self,
                          callbacks=callbacks,
                          **generate_kwargs)
 
+        # for remote maker initialization
+        self._model_str = model
+        self._cr_model_str = cr_model
+        self._pretrained = pretrained
+        self._cr_pretrained = cr_pretrained
+
     @ray.method(concurrency_group="model_io")
-    def _update_remote_makers(self):
+    def _update_remote_makers(self, **config):
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-        for target_holder in self.target_holder_list:
-            # TODO: reduce malloc
-            with torch.no_grad():
-                ray.get(target_holder.update_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
-
+        with torch.no_grad():
+            # actor:
+            # mark start
+            for target_holder in self.target_holder_list:
+                target_holder.update_experience_maker.remote(chunk_start=True)
+            # sending loop
+            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor), **config):
+                for target_holder in self.target_holder_list:
+                    target_holder.update_experience_maker.remote(new_actor_state_dict = state_dict_shard)
+            # mark end
+            for target_holder in self.target_holder_list:
+                target_holder.update_experience_maker.remote(chunk_end=True)
+            # critic
+            # mark start
+            for target_holder in self.target_holder_list:
+                target_holder.update_experience_maker.remote(chunk_start=True)
+            # sending loop
+            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
+                for target_holder in self.target_holder_list:
+                    target_holder.update_experience_maker.remote(new_critic_state_dict = state_dict_shard)
+            # mark end
+            for target_holder in self.target_holder_list:
+                target_holder.update_experience_maker.remote(chunk_end=True)
+
    @ray.method(concurrency_group="model_io")
-    def initialize_remote_makers(self):
+    def initialize_remote_makers(self, **config):
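+        # Chunked push protocol (used by both initialize_remote_makers and _update_remote_makers):
+        # every registered maker first gets a call with chunk_start=True (for initialization it also
+        # carries the model name and pretrained path), then one call per state-dict shard produced by
+        # _get_model_state_dict_shard (LoRA weights merged where the strategy supports it, tensors
+        # converted to fp16 on CPU by state_dict_to), and finally a call with chunk_end=True.
+        # The .remote() calls are not awaited, so the trainer does not block on the makers.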
         # TODO: balance duties
         if is_rank_0():
             self.update_target_holder_list(self.target_holder_name_list)
-        for target_holder in self.target_holder_list:
-            # TODO: reduce malloc
-            with torch.no_grad():
-                ray.get(target_holder.initialize_experience_maker.remote(self._get_unwrapped_actor(), self._get_unwrapped_critic()))
+        with torch.no_grad():
+            # actor / initial_model:
+            # mark start
+            for target_holder in self.target_holder_list:
+                target_holder.initialize_experience_maker.remote(actor_model=self._model_str,actor_pretrained=self._pretrained,chunk_start=True)
+            # sending loop
+            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_actor(self.actor), **config):
+                for target_holder in self.target_holder_list:
+                    target_holder.initialize_experience_maker.remote(actor_state_dict=state_dict_shard)
+            # mark end
+            for target_holder in self.target_holder_list:
+                target_holder.initialize_experience_maker.remote(actor_model=self._model_str, chunk_end=True)
+            # critic / reward_model:
+            # mark start
+            for target_holder in self.target_holder_list:
+                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str,critic_pretrained=self._cr_pretrained,chunk_start=True)
+            # sending loop
+            for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
+                for target_holder in self.target_holder_list:
+                    target_holder.initialize_experience_maker.remote(critic_state_dict=state_dict_shard)
+            # mark end
+            for target_holder in self.target_holder_list:
+                target_holder.initialize_experience_maker.remote(critic_model=self._cr_model_str, chunk_end=True)
 
     @ray.method(concurrency_group="compute")
     def training_step(self, experience: Experience) -> Dict[str, float]:
@@ -177,6 +234,14 @@ def _get_unwrapped_critic(self):
         elif isinstance(self.strategy, NaiveStrategy):
             return self.critic
 
+    def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
+        try:
+            self.strategy.merge_lora_weight(model)
+        except AttributeError:
+            pass
+        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
+            yield state_dict_to(state_dict)
+
 
 def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
     origin_model = strategy._unwrap_actor(actor)
@@ -189,4 +254,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
     new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
 
     return new_kwargs
-
\ No newline at end of file
diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py
index 94e4a3d537a5..67b89a68119a 100644
--- a/applications/Chat/coati/ray/src/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/src/experience_maker_holder.py
@@ -13,16 +13,17 @@
 from threading import Lock
 import time
 import os
+import tracemalloc
 
-
-from .utils import is_rank_0, get_strategy_from_args, set_dist_env
+from .utils import is_rank_0, get_strategy_from_args, set_dist_env, get_actor_from_args, get_critic_from_args, \
+    get_reward_model_from_args
 
 
 @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
 class ExperienceMakerHolder:
     '''
    Args:
-        detached_trainer_name_list: str list to get ray actor handleskkk
+        detached_trainer_name_list: str list to get ray actor handles
        strategy:
        experience_batch_size: batch size of generated experience
        kl_coef: the coefficient of kl divergence loss
@@ -46,18 +47,39 @@ def __init__(self,
         self.experience_batch_size = experience_batch_size
         self.kl_coef = kl_coef
         self.generate_kwargs = generate_kwargs
-        # Need a trainer to give an actor and a critic via initialize_experience_maker(...)
         actor, critic, reward_model, initial_model = None, None, None, None
         self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
+
         self._model_visit_lock = Lock()
-        self.fully_initialized = False
+        self._initial_model_initialized = False
+        self._reward_model_initialized = False
+        self._actor_initialized = False
+        self._critic_initialized = False
+        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+            self._debug = True
+        else:
+            self._debug = False
+        self.target_auto_balance = False
+
+        if self._debug:
+            print('[maker] Waiting for INIT')
 
     def _get_ready(self):
-        while not self.fully_initialized:
+        while not self._fully_initialized():
             time.sleep(1.0)
 
+    def _fully_initialized(self):
+        if not self._initial_model_initialized:
+            return False
+        if not self._reward_model_initialized:
+            return False
+        if not self._actor_initialized:
+            return False
+        if not self._critic_initialized:
+            return False
+        return True
+
     def update_target_trainer_list(self, detached_trainer_name_list):
         self.target_trainer_list = []
         for name in detached_trainer_name_list:
@@ -66,7 +88,6 @@ def update_target_trainer_list(self, detached_trainer_name_list):
     # copy from ../trainer/base.py
     @ray.method(concurrency_group="compute")
     def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
-        self._get_ready()
         if isinstance(inputs, Tensor):
             return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
         elif isinstance(inputs, dict):
@@ -76,40 +97,37 @@ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experien
 
     @ray.method(concurrency_group="experience_io")
     def _send_experience(self, experience):
-        '''
-        ignore it
-
-        # choose a trainer that has the least experience batch in its detached_replay_buffer
-        chosen_trainer = None
-        min_length = None
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
-            print("[maker] choosing target trainer")
-        while chosen_trainer is None:
-            for target_trainer in self.target_trainer_list:
-                try:
-                    temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
-                    if min_length is None:
-                        min_length = temp_length
-                        chosen_trainer = target_trainer
-                    else:
-                        if temp_length < min_length:
+        if not self.target_auto_balance:
+            # choose the trainer in a round-robin (polling) manner
+            if not hasattr(self, "_target_idx"):
+                self._target_idx = 0
+            chosen_trainer = self.target_trainer_list[self._target_idx]
+            if self._debug:
+                print(f"[maker] sending exp to {chosen_trainer}")
+            chosen_trainer.buffer_append.remote(experience)
+            self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
+        else:
+            # choose a trainer that has the least experience batch in its detached_replay_buffer
+            chosen_trainer = None
+            min_length = None
+            if self._debug:
+                print("[maker] choosing target trainer")
+            while chosen_trainer is None:
+                for target_trainer in self.target_trainer_list:
+                    try:
+                        temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
+                        if min_length is None:
+                            min_length = temp_length
+                            chosen_trainer = target_trainer
-                except GetTimeoutError:
-                    pass
-
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
-            print(f"[maker] sending exp to {chosen_trainer}")
-        chosen_trainer.buffer_append.remote(experience)
-        '''
-        #
-        if not hasattr(self, "_target_idx"):
-            self._target_idx = 0
-        chosen_trainer = self.target_trainer_list[self._target_idx]
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
-            print(f"[maker] sending exp to {chosen_trainer}")
-        chosen_trainer.buffer_append.remote(experience)
-        self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
+                            else:
+                                if temp_length < min_length:
+                                    min_length = temp_length
+                                    chosen_trainer = target_trainer
+                    except GetTimeoutError:
+                        pass
+            if self._debug:
+                print(f"[maker] sending exp to {chosen_trainer}")
+            chosen_trainer.buffer_append.remote(experience)
 
     def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
         self._get_ready()
@@ -126,47 +144,123 @@ def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None
             self._send_experience(experience=experience)
 
     @ray.method(concurrency_group="model_io")
-    def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
+    def initialize_experience_maker(self,
+                                    actor_model: str = None,
+                                    actor_pretrained: str = None,
+                                    actor_state_dict: Dict[str, Any] = None,
+                                    critic_model: str = None,
+                                    critic_pretrained: str = None,
+                                    critic_state_dict: Dict[str, Any] = None,
+                                    chunk_start: bool = None,
+                                    chunk_end: bool = None):
         '''
-        called by trainer. Only once.
+        called by trainer
+        chunk_start: set True on the first call, before any state_dict shards are sent
+        chunk_end: set True on the last call, after all state_dict shards have been sent
+
+        TODO: integrate load_state_dict with the model-sharding strategy
         '''
-        # TODO: reduce malloc
-        if self.fully_initialized:
+        if self._fully_initialized():
             return
-        if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
-            print('[maker] INIT')
+
+        if chunk_start:
+            if self._debug:
+                print('[maker] INIT')
+            with torch.no_grad():
+                # (csric) any better way to get model structure?
+                with self.strategy.model_init_context():
+                    if not self._actor_initialized and actor_model is not None:
+                        self.experience_maker.actor = get_actor_from_args(actor_model, actor_pretrained).half().requires_grad_(False)
+                    if not self._critic_initialized and critic_model is not None:
+                        self.experience_maker.critic = get_critic_from_args(critic_model, critic_pretrained).half().requires_grad_(False)
+                    if not self._initial_model_initialized and actor_model is not None:
+                        self.experience_maker.initial_model = get_actor_from_args(actor_model, actor_pretrained).half().requires_grad_(False)
+                    if not self._reward_model_initialized and critic_model is not None:
+                        self.experience_maker.reward_model = get_reward_model_from_args(critic_model, critic_pretrained).half().requires_grad_(False)
+
         with torch.no_grad():
-            with self.strategy.model_init_context():
-                actor = init_actor
-                critic = init_critic
-                initial_model = deepcopy(actor)
-                reward_model = RewardModel(deepcopy(critic.model),
-                                           deepcopy(critic.value_head)).to(torch.cuda.current_device())
-            if self.strategy_str != 'colossalai_gemini':
-                actor.to(torch.float16).to(torch.cuda.current_device())
-                critic.to(torch.float16).to(torch.cuda.current_device())
-                initial_model.to(torch.float16).to(torch.cuda.current_device())
-                reward_model.to(torch.float16).to(torch.cuda.current_device())
-
-            self.experience_maker.actor = self.strategy.prepare(actor)
-            self.experience_maker.critic = self.strategy.prepare(critic)
-            self.experience_maker.initial_model = self.strategy.prepare(initial_model)
-            self.experience_maker.reward_model = self.strategy.prepare(reward_model)
-        self.fully_initialized = True
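+            # each call carries only one shard of the full state_dict (see
+            # _get_model_state_dict_shard on the trainer side), so the partial
+            # loads below use strict=False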
+            if not self._actor_initialized and actor_state_dict is not None:
+                self.experience_maker.actor.model.load_state_dict(actor_state_dict, strict=False)
+            if not self._critic_initialized and critic_state_dict is not None:
+                self.experience_maker.critic.load_state_dict(critic_state_dict, strict=False)
+            if not self._initial_model_initialized and actor_state_dict is not None:
+                self.experience_maker.initial_model.model.load_state_dict(actor_state_dict, strict=False)
+            if not self._reward_model_initialized and critic_state_dict is not None:
+                self.experience_maker.reward_model.load_state_dict(critic_state_dict, strict=False)
+
+
+        if chunk_end:
+            with torch.no_grad():
+                if actor_model is not None:
+                    if not self._actor_initialized:
+                        self.experience_maker.actor = self.strategy.prepare(self.experience_maker.actor.to(torch.cuda.current_device()))
+                    if not self._initial_model_initialized:
+                        self.experience_maker.initial_model = self.strategy.prepare(self.experience_maker.initial_model.to(torch.cuda.current_device()))
+                    self._actor_initialized = True
+                    self._initial_model_initialized = True
+                if critic_model is not None:
+                    if not self._critic_initialized:
+                        self.experience_maker.critic = self.strategy.prepare(self.experience_maker.critic.to(torch.cuda.current_device()))
+                    if not self._reward_model_initialized:
+                        self.experience_maker.reward_model = self.strategy.prepare(self.experience_maker.reward_model.to(torch.cuda.current_device()))
+                    self._critic_initialized = True
+                    self._reward_model_initialized = True
+
+
+    def initialize_experience_maker_local(self,
+                                          initial_model_func=None,
+                                          reward_model_func=None,
+                                          actor_func=None,
+                                          critic_func=None):
+        '''
+        Use function calls to construct the models here, because some strategies require env_info.
+        Models initialized here will be IGNORED by initialize_experience_maker.
+        initial_model and reward_model can use their own strategy rather than self.strategy, e.g. quantization.
+        '''
+
+        if actor_func is not None:
+            self.experience_maker.actor = actor_func()
+            self._actor_initialized = True
+        if critic_func is not None:
+            self.experience_maker.critic = critic_func()
+            self._critic_initialized = True
+        if initial_model_func is not None:
+            self.experience_maker.initial_model = initial_model_func()
+            self._initial_model_initialized = True
+        if reward_model_func is not None:
+            self.experience_maker.reward_model = reward_model_func()
+            self._reward_model_initialized = True
 
     @ray.method(concurrency_group="model_io")
-    def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
+    def update_experience_maker(self,
+                                new_actor_state_dict: Dict[str, Any] = None,
+                                new_critic_state_dict: Dict[str, Any] = None,
+                                chunk_start: bool = None,
+                                chunk_end: bool = None):
         '''
         called by trainer
+        chunk_start: set True on the first call, before any state_dict shards are sent
+        chunk_end: set True on the last call, after all state_dict shards have been sent
+
+        TODO: integrate load_state_dict with the model-sharding strategy
         '''
-        # TODO: reduce malloc
-        self._model_visit_lock.acquire()
-        with torch.no_grad():
-            if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
+        _watch_memory = True
+        if chunk_start:
+            if self._debug:
                 print("[maker] UPDATE ")
-            if self.strategy_str != 'colossalai_gemini':
-                new_actor.to(torch.float16).to(torch.cuda.current_device())
-                new_critic.to(torch.float16).to(torch.cuda.current_device())
-            self.experience_maker.actor = self.strategy.prepare(new_actor)
-            self.experience_maker.critic = self.strategy.prepare(new_critic)
-        self._model_visit_lock.release()
+            if _watch_memory:
+                tracemalloc.start()
+            self._model_visit_lock.acquire()
+
+        with torch.no_grad():
+            if new_actor_state_dict is not None:
+                self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
+            if new_critic_state_dict is not None:
+                self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
+
+        if chunk_end:
+            self._model_visit_lock.release()
+            if _watch_memory:
+                current, peak = tracemalloc.get_traced_memory()
+                print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
+                tracemalloc.stop()
diff --git a/applications/Chat/coati/ray/src/utils.py b/applications/Chat/coati/ray/src/utils.py
index c750879b6d18..827c2b8c6dc9 100644
--- a/applications/Chat/coati/ray/src/utils.py
+++ b/applications/Chat/coati/ray/src/utils.py
@@ -1,30 +1,64 @@
 import torch.distributed as dist
 from typing import Any, Callable, Dict, List, Optional
 
-from coati.models.bloom import BLOOMActor, BLOOMCritic
-from coati.models.gpt import GPTActor, GPTCritic
-from coati.models.opt import OPTActor, OPTCritic
+from coati.models.bloom import BLOOMActor, BLOOMCritic, BLOOMRM
+from coati.models.gpt import GPTActor, GPTCritic, GPTRM
+from coati.models.opt import OPTActor, OPTCritic, OPTRM
+from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic
+from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
+
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 import torch
 import os
 
+
 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
 
 
-def get_cuda_actor_critic_from_args(model: str, pretrained: str = None, lora_rank=0):
+def get_actor_from_args(model: str, pretrained: str = None, lora_rank = 0):
     if model == 'gpt2':
-        actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
-        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+        actor = GPTActor(pretrained=pretrained, lora_rank=lora_rank)
     elif model == 'bloom':
-        actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
-        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+        actor = BLOOMActor(pretrained=pretrained, lora_rank=lora_rank)
     elif model == 'opt':
-        actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
-        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank).to(torch.cuda.current_device())
+        actor = OPTActor(pretrained=pretrained, lora_rank=lora_rank)
+    elif model == 'llama':
+        actor = LlamaActor(pretrained=pretrained, lora_rank=lora_rank)
+    elif model == 'roberta':
+        actor = RoBERTaActor(pretrained=pretrained, lora_rank=lora_rank)
     else:
-        raise ValueError(f'Unsupported model "{model}"')
-    return actor, critic
+        raise ValueError(f'Unsupported actor model "{model}"')
"{model}"') + return actor +def get_critic_from_args(model: str, pretrained: str = None, lora_rank = 0): + if model == 'gpt2': + critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, use_action_mask=True) + elif model == 'bloom': + critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, use_action_mask=True) + elif model == 'opt': + critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, use_action_mask=True) + elif model == 'llama': + critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, use_action_mask=True) + elif model == 'roberta': + critic = RoBERTaCritic(pretrained=pretrained, lora_rank=lora_rank, use_action_mask=True) + else: + raise ValueError(f'Unsupported reward model "{model}"') + return critic + +def get_reward_model_from_args(model: str, pretrained: str = None): + if model == 'gpt2': + reward_model = GPTRM(pretrained=pretrained) + elif model == 'bloom': + reward_model = BLOOMRM(pretrained=pretrained) + elif model == 'opt': + reward_model = OPTRM(pretrained=pretrained) + elif model == 'llama': + reward_model = LlamaRM(pretrained=pretrained) + elif model == 'roberta': + reward_model = RoBERTaRM(pretrained=pretrained) + else: + raise ValueError(f'Unsupported reward model "{model}"') + return reward_model def get_strategy_from_args(strategy: str): if strategy == 'naive': @@ -40,9 +74,40 @@ def get_strategy_from_args(strategy: str): return strategy_ +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer +from coati.utils import prepare_llama_tokenizer_and_embedding + +def get_tokenizer_from_args(model: str, **kwargs): + if model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + elif model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + elif model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif model == 'llama': + pretrain_path = kwargs["pretrain"] + tokenizer = AutoTokenizer.from_pretrained(pretrain_path) + elif model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + else: + raise ValueError(f'Unsupported model "{model}"') + + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + def set_dist_env(env_info: Dict[str, str]): os.environ["RANK"] = env_info['rank'] os.environ["LOCAL_RANK"] = env_info['local_rank'] os.environ["WORLD_SIZE"] = env_info['world_size'] os.environ['MASTER_PORT'] = env_info['master_port'] os.environ['MASTER_ADDR'] = env_info['master_addr'] + + +def state_dict_to(state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device('cpu')): + ''' + keep state_dict intact + ''' + new_state_dict = {} + for k, v in state_dict.items(): + new_state_dict[k] = v.to(dtype = dtype, device = device) + return new_state_dict \ No newline at end of file diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index 7d25138561ea..e0232fbc64ad 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -112,6 +112,11 @@ def _unwrap_actor(actor: Actor) -> nn.Module: """ return Strategy._unwrap_model(actor) + @staticmethod + def _unwrap_critic(critic: Critic) -> nn.Module: + return Strategy._unwrap_model(critic) + + @abstractmethod def save_model(self, model: nn.Module, diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -5,7 +5,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import LM, Actor, RewardModel
+from coati.models.base import LM, Actor, RewardModel, Critic
 from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from transformers.modeling_utils import PreTrainedModel
@@ -159,12 +159,12 @@ def _unwrap_actor(actor: Actor) -> nn.Module:
             return model.module
         return model
 
-    def _unwrap_model(self, model: Union[nn.Module, ZeroDDP]) -> nn.Module:
-        if isinstance(model, ZeroDDP) and self.stage == 3:
-            logger.info(f"model type: {type(model)}, get static torch model")
-            model = get_static_torch_model(model)
-            logger.info(f"unwrapped_model type: {type(model)}")
+    @staticmethod
+    def _unwrap_critic(critic: Critic) -> nn.Module:
+        return Strategy._unwrap_critic(critic)
+
+    def _unwrap_model(self, model: Union[nn.Module, ZeroDDP]) -> nn.Module:
         return super()._unwrap_model(model)
 
     def save_model(self,
@@ -210,3 +210,14 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
             raise RuntimeError(
                 f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
         torch.save(optimizer.state_dict(), path)
+
+    def get_model_state_dict_shard(self, model: nn.Module, **config):
+        if self.stage != 3:
+            yield from super().get_model_state_dict_shard(model, **config)
+        else:
+            unwrapped_model = self._unwrap_model(model)
+            for module in unwrapped_model.modules():
+                if isinstance(module, LoraLinear):
+                    module.merge_weights = True
+                    module.eval()
+            yield from model.state_dict_shard(max_shard_size=1024)
\ No newline at end of file
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index 8a8c4b3c2f4e..a0fd3fa27a58 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -7,7 +7,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import LM, Actor, RewardModel
+from coati.models.base import LM, Actor, RewardModel, Critic
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -74,6 +74,11 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False
     def _unwrap_actor(actor: Actor) -> nn.Module:
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module
+
+    @staticmethod
+    def _unwrap_critic(critic: Critic) -> nn.Module:
+        model: DDP = Strategy._unwrap_critic(critic)
+        return model.module
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         if only_rank0 and dist.get_rank() != 0:
@@ -109,3 +114,4 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
 
     def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
+
diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py
index bb47e5ab2688..a22be1181fb8 100644
--- a/applications/Chat/coati/trainer/strategies/naive.py
+++ b/applications/Chat/coati/trainer/strategies/naive.py
@@ -72,3 +72,15 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
     def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
         state_dict = torch.load(path, map_location=map_location)
         optimizer.load_state_dict(state_dict)
+
+    def get_model_state_dict_shard(self, model: nn.Module, **config):
+        # TODO: implement sharding on naive strategy
+        state_dict = model.state_dict()
+        yield state_dict
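+        # One possible way to honour the TODO above (sketch only, not used here;
+        # assumes max_shard_size is a size limit in MB, mirroring the colossalai strategy):
+        #   shard, size = {}, 0
+        #   for name, tensor in model.state_dict().items():
+        #       shard[name] = tensor
+        #       size += tensor.numel() * tensor.element_size()
+        #       if size >= config.get('max_shard_size', 1024) * 1024 ** 2:
+        #           yield shard
+        #           shard, size = {}, 0
+        #   if shard:
+        #       yield shard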
+
+    def merge_lora_weight(self, model: nn.Module):
+        unwrapped_model = self._unwrap_model(model)
+        for module in unwrapped_model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
\ No newline at end of file