diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index 961f2aec677d..f57c9458a271 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -77,6 +77,7 @@ def sample(model: nn.Module,
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
         if update_model_kwargs_fn is not None:
             model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
+        # if an eos_token was generated in a sequence, mark that sequence as finished
         if eos_token_id is not None:
             unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
 
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py
index d30158019d65..d3dfc6e93a46 100644
--- a/applications/Chat/coati/ray/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -120,13 +120,14 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
         ray.get(tasks)
         # sending loop
         tasks = []
-        for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
+
+        for state_dict_shard in self._get_model_state_dict_shard(self.actor, **config):
             for target_holder in self.target_holder_list:
                 tasks.append(
                     target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
                                                                  fully_update=fully_update))
         # sending loop
-        for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
+        for state_dict_shard in self._get_model_state_dict_shard(self.critic, **config):
             for target_holder in self.target_holder_list:
                 tasks.append(
                     target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
@@ -176,28 +177,6 @@ def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None
     def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
         self.strategy.save_optimizer(self.critic_optim, path, only_rank0)
 
-    def _get_unwrapped_actor(self):
-        if False:
-            pass
-        elif isinstance(self.strategy, ColossalAIStrategy):
-            ret = Actor(self.strategy._unwrap_model(self.actor))
-            return ret
-        elif isinstance(self.strategy, DDPStrategy):
-            return Actor(self.strategy._unwrap_actor(self.actor))
-        elif isinstance(self.strategy, NaiveStrategy):
-            return self.actor
-
-    def _get_unwrapped_critic(self):
-        if False:
-            pass
-        elif isinstance(self.strategy, ColossalAIStrategy):
-            ret = self.strategy._unwrap_model(self.critic)
-            return ret
-        elif isinstance(self.strategy, DDPStrategy):
-            return self.critic.module
-        elif isinstance(self.strategy, NaiveStrategy):
-            return self.critic
-
     def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
         # try:
         #     self.strategy.merge_lora_weight(model)
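The removed _get_unwrapped_actor/_get_unwrapped_critic helpers dispatched on the strategy type from inside the trainer; after this change the trainer passes the wrapped model through and each strategy unwraps it itself. A minimal sketch of the resulting pattern (class and method names from this PR, bodies illustrative only):

    # sketch: polymorphic unwrapping instead of isinstance() chains in the trainer
    import torch.nn as nn

    class Strategy:
        def unwrap_model(self, model: nn.Module) -> nn.Module:
            # default: return the model as-is; subclasses peel off their wrappers
            return model

    class DDPStrategy(Strategy):
        def unwrap_model(self, model: nn.Module) -> nn.Module:
            # DDP keeps the user module under .module
            return model.module

    # trainer-side code no longer needs isinstance(self.strategy, ...) branches:
    #     base = self.strategy.unwrap_model(self.actor)

diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
index 6cd7c564cc92..48f33e70c632 100644
--- a/applications/Chat/coati/ray/utils.py
+++ b/applications/Chat/coati/ray/utils.py
@@ -120,18 +120,6 @@ def set_dist_env(env_info: Dict[str, str]):
     os.environ['MASTER_ADDR'] = env_info['master_addr']
 
 
-def state_dict_to(state_dict: Dict[str, Any],
-                  dtype: torch.dtype = torch.float16,
-                  device: torch.device = torch.device('cpu')):
-    '''
-    keep state_dict intact
-    '''
-    new_state_dict = {}
-    for k, v in state_dict.items():
-        new_state_dict[k] = v.to(dtype=dtype, device=device)
-    return new_state_dict
-
-
 def get_model_numel(model: nn.Module) -> int:
     numel = sum(p.numel() for p in model.parameters())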
     return numel
@@ -150,3 +138,47 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
             # a receiver may have more than one sender
             target_receivers.append(sender_idx % num_receivers)
     return target_receivers
+
+
+def state_dict_to(state_dict: Dict[str, Any],
+                  dtype: torch.dtype = torch.float16,
+                  device: torch.device = torch.device('cpu')):
+    '''
+    Returns a converted copy; the input state_dict is kept intact.
+    '''
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        new_state_dict[k] = v.to(dtype=dtype, device=device)
+    return new_state_dict
+
+
+def state_dict_filter_lora(state_dict: Dict[str, Any], keep_non_lora: bool = False):
+    '''
+    If keep_non_lora, also return the non-LoRA state_dict.
+    '''
+    state_dict_lora = OrderedDict()
+    state_dict_non_lora = OrderedDict()
+    for k, v in state_dict.items():
+        if 'lora_A' in k or 'lora_B' in k:
+            state_dict_lora[k] = v
+        elif keep_non_lora:
+            state_dict_non_lora[k] = v
+    if keep_non_lora:
+        return state_dict_lora, state_dict_non_lora
+    else:
+        return state_dict_lora
+
+
+def state_dict_lora_reconstruct(state_dict_lora: Dict[str, Any]):
+    '''
+    xxx.lora_A, xxx.lora_B -->> xxx.weight
+    Note: only the low-rank product lora_B @ lora_A is recovered here; the
+    LoRA scaling factor and the frozen base weight must be applied by the caller.
+    '''
+    state_dict_reconstruct = OrderedDict()
+    for k, v in state_dict_lora.items():
+        if 'lora_A' in k:
+            prefix = k.rsplit('lora_A', 1)[0]
+            # pair each lora_A with its matching lora_B and rebuild the weight delta
+            state_dict_reconstruct[prefix + 'weight'] = state_dict_lora[prefix + 'lora_B'] @ v
+    return state_dict_reconstruct
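state_dict_filter_lora splits a checkpoint by key name so that only the small LoRA tensors need to travel on frequent syncs. A usage sketch, assuming a LoRA-wrapped module whose adapter parameters are named xxx.lora_A/xxx.lora_B as above (make_lora_payload is a hypothetical helper, not part of this PR):

    # sketch: ship only the (small) LoRA tensors, downcast to fp16 on CPU
    from coati.ray.utils import state_dict_filter_lora, state_dict_to

    def make_lora_payload(model):
        lora_sd, non_lora_sd = state_dict_filter_lora(model.state_dict(), keep_non_lora=True)
        # the full non-LoRA weights only need to travel on a fully_update sync
        return state_dict_to(lora_sd), non_lora_sd

diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index c10bc2d185a8..bd30422022ae 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -104,10 +104,6 @@ def unwrap_model(model: nn.Module) -> nn.Module:
         """
         return get_base_model(model)
 
-    @staticmethod
-    def _unwrap_critic(critic: Critic) -> nn.Module:
-        return Strategy._unwrap_model(critic)
-
     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         pass
@@ -134,3 +130,7 @@ def save_pretrained(self,
                         only_rank0: bool = True,
                         tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         pass
+
+    @abstractmethod
+    def get_model_state_dict_shard(self, model: nn.Module, **config):
+        pass
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index acee42d7b5b3..88268b677eb2 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -171,18 +171,6 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
                 f'Optimizer states are sharded when using ColossalAIStrategy.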
Only rank0 is not supported.')
         torch.save(optimizer.state_dict(), path)
 
-    def get_model_state_dict_shard(self, model: nn.Module, **config):
-        if self.stage != 3:
-            yield from super().get_model_state_dict_shard(model, **config)
-        else:
-            # unwrapped_model = self._unwrap_model(model)
-            # for module in unwrapped_model.modules():
-            #     if isinstance(module, LoraLinear):
-            #         module.merge_weights = True
-            #         module.eval()
-            model: ZeroDDP = model
-            yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
-
     def unwrap_model(self, model: nn.Module) -> nn.Module:
         base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
@@ -198,3 +186,15 @@ def save_pretrained(self,
         if self.stage == 3:
             raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
         super().save_pretrained(model, path, only_rank0, tokenizer)
+
+    def get_model_state_dict_shard(self, model: nn.Module, **config):
+        if self.stage != 3:
+            yield from super().get_model_state_dict_shard(model, **config)
+        else:
+            # unwrapped_model = self._unwrap_model(model)
+            # for module in unwrapped_model.modules():
+            #     if isinstance(module, LoraLinear):
+            #         module.merge_weights = True
+            #         module.eval()
+            base_model: ZeroDDP = get_base_model(model)
+            yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py
index 5d3da8ee3478..972deebeaa0d 100644
--- a/applications/Chat/coati/trainer/strategies/naive.py
+++ b/applications/Chat/coati/trainer/strategies/naive.py
@@ -9,6 +9,8 @@
 import torch.optim as optim
 from coati.models.base import get_base_model
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import RewardModel
+from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.modeling_utils import PreTrainedModel
@@ -71,8 +73,20 @@ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = No
         state_dict = torch.load(path, map_location=map_location)
         optimizer.load_state_dict(state_dict)
 
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        unwrapped_model = self.unwrap_model(model)
+        assert isinstance(unwrapped_model, PreTrainedModel)
+        unwrapped_model.save_pretrained(path)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(path)
+
     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
+        model = self.unwrap_model(model)
         if 'requires_grad_only' in config and config['requires_grad_only'] == True:
             state_dict = get_grad_required_state_dict(model)
         else:
@@ -111,14 +125,3 @@ def _try_init_dist(self, force: bool = False) -> None:
         except Exception as e:
             if force:
                 raise e
-
-    def save_pretrained(self,
-                        model: nn.Module,
-                        path: str,
-                        only_rank0: bool = True,
-                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        unwrapped_model = self.unwrap_model(model)
-        assert isinstance(unwrapped_model, PreTrainedModel)
-        unwrapped_model.save_pretrained(path)
-        if tokenizer is not None:
-            tokenizer.save_pretrained(path)
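With get_model_state_dict_shard promoted to the Strategy contract, callers can stream weight shards without knowing which strategy wraps the model. Roughly, in the spirit of _update_remote_makers above (strategy, actor and target_holder are assumed to be set up as elsewhere in this PR):

    # sketch: stream shards to a remote experience maker, one Ray call per shard
    for shard in strategy.get_model_state_dict_shard(actor, requires_grad_only=True):
        target_holder.update_experience_maker.remote(new_actor_state_dict=shard)

diff --git a/applications/Chat/examples/ray/.gitignore b/applications/Chat/examples/ray/.gitignore
new file mode 100644
index 000000000000..4cf8dd15619e
--- /dev/null
+++ b/applications/Chat/examples/ray/.gitignore
@@ -0,0 +1 @@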
+logs/*
diff --git a/applications/Chat/examples/ray/benchmark.sh b/applications/Chat/examples/ray/benchmark.sh
new file mode 100644
index 000000000000..3852684007b7
--- /dev/null
+++ b/applications/Chat/examples/ray/benchmark.sh
@@ -0,0 +1,40 @@
+
+PROMPT_PATH=/home/lccsr/data3/awesome-chatgpt-prompts/prompts.csv
+
+num_trainers=4
+num_makers=4
+
+mkdir -p logs
+
+# "facebook/opt-2.7b"
+for pretrain in "facebook/opt-1.3b" "facebook/opt-6.7b" "facebook/opt-13b"
+do
+    for experience_batch_size in 16 32 64
+    do
+        for train_batch_size in 16 32 64
+        do
+            for update_steps in 8 32 128
+            do
+                # set experience_steps large enough to trigger two maker updates
+                experience_steps=$((2*num_trainers*train_batch_size*update_steps/num_makers/experience_batch_size))
+
+                config_string=${num_trainers}_${num_makers}_pretrain_${pretrain##*/}_experience_batch_size_${experience_batch_size}_train_batch_size_${train_batch_size}_update_steps_${update_steps}_experience_steps_${experience_steps}
+                echo "running: ${config_string}"
+
+                nohup python mmmt_prompt.py \
+                    --prompt_path $PROMPT_PATH \
+                    --trainer_strategy colossalai_gemini --maker_strategy naive \
+                    --model 'opt' \
+                    --pretrain $pretrain \
+                    --critic_pretrain "facebook/opt-350m" \
+                    --num_trainers $num_trainers \
+                    --num_makers $num_makers \
+                    --experience_steps $experience_steps \
+                    --experience_batch_size $experience_batch_size \
+                    --update_steps $update_steps \
+                    --train_batch_size $train_batch_size \
+                    --debug > logs/output_${config_string}.txt 2>&1
+            done
+        done
+    done
+done
diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py
index d2398d451c7b..6f43d8950758 100644
--- a/applications/Chat/examples/ray/mmmt_prompt.py
+++ b/applications/Chat/examples/ray/mmmt_prompt.py
@@ -101,8 +101,8 @@ def model_fn():
     ]
 
     def trainer_model_fn():
-        actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
-        critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
+        actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
+        critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
         return actor, critic
 
     # configure Trainer
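For reference, the experience_steps expression in benchmark.sh sizes experience production so the makers can be updated twice per configuration; a quick check of the arithmetic, using the smallest loop values from the script purely for illustration:

    # experience_steps = 2 * num_trainers * train_batch_size * update_steps
    #                    / (num_makers * experience_batch_size)
    num_trainers, num_makers = 4, 4
    train_batch_size, update_steps, experience_batch_size = 16, 8, 16
    experience_steps = (2 * num_trainers * train_batch_size * update_steps) \
        // (num_makers * experience_batch_size)
    assert experience_steps == 16   # two maker-update rounds' worth of experience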