Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bb43927
prompt example
Apr 23, 2023
884a645
prompt load csv data
Apr 24, 2023
603bd7e
remove legacy try
Apr 24, 2023
fe4cecb
Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo
Apr 24, 2023
99fbd40
maker models require_grad set to False
Apr 24, 2023
08108d2
working on zero redundancy update
Apr 25, 2023
bded51f
Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo
Apr 25, 2023
3266516
mmmt_prompt example; naive strategy requires_grad state_dict & shardi…
Apr 25, 2023
332fd3c
remove legacy examples
Apr 25, 2023
9b2b77e
remove legacy examples
Apr 25, 2023
18f5814
Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo
Apr 25, 2023
13fa038
remove replay buffer tp state. bad design
Apr 26, 2023
9a6be66
Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo
Apr 26, 2023
3cab3fd
Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo
Apr 26, 2023
cf5c2d1
opt benchmark
Apr 27, 2023
94783c5
better script
Apr 27, 2023
bfe6a69
nothing
Apr 27, 2023
1c98e93
[chat] strategy refactor unwrap model
ver217 Apr 27, 2023
f5a0821
[chat] strategy refactor save model
ver217 Apr 27, 2023
47b47eb
[chat] add docstr
ver217 Apr 27, 2023
1e56034
[chat] refactor trainer save model
ver217 Apr 27, 2023
cfce710
[chat] fix strategy typing
ver217 Apr 27, 2023
0c25123
[chat] refactor trainer save model
ver217 Apr 27, 2023
53a3b90
[chat] update readme
ver217 Apr 27, 2023
b762e15
[chat] fix unit test
ver217 Apr 27, 2023
eafca77
working on lora reconstruction
Apr 27, 2023
40bec4e
Merge remote-tracking branch 'ver217/refactor/chat-actor' into detach…
Apr 27, 2023
ed2dd61
state_dict sending adapts to new unwrap function
Apr 27, 2023
9bb05a2
Merge remote-tracking branch 'upstream/main' into detached_ppo_refactor
Apr 27, 2023
b995717
Merge remote-tracking branch 'ver217/dev/chat-ray' into detached_ppo_…
Apr 27, 2023
b79af56
remove comments
Apr 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions applications/Chat/coati/models/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def sample(model: nn.Module,
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if update_model_kwargs_fn is not None:
model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
Expand Down
27 changes: 3 additions & 24 deletions applications/Chat/coati/ray/detached_trainer_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,14 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
ray.get(tasks)
# sending loop
tasks = []
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):

for state_dict_shard in self._get_model_state_dict_shard(self.actor, **config):
for target_holder in self.target_holder_list:
tasks.append(
target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
fully_update=fully_update))
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
for state_dict_shard in self._get_model_state_dict_shard(self.critic, **config):
for target_holder in self.target_holder_list:
tasks.append(
target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
Expand Down Expand Up @@ -176,28 +177,6 @@ def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None
def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

def _get_unwrapped_actor(self):
    """Return the actor with any strategy-specific wrapper removed.

    Per-strategy behavior (order matters: ColossalAIStrategy is checked
    before DDPStrategy, matching the original dispatch):
      * ColossalAIStrategy: strip the wrapper via ``_unwrap_model`` and
        re-wrap the bare module in an ``Actor``.
      * DDPStrategy: strip via ``_unwrap_actor`` and re-wrap in ``Actor``.
      * NaiveStrategy: the actor is not wrapped; return it as-is.
    """
    # Removed dead code: the original began with an unreachable
    # ``if False: pass`` branch.
    if isinstance(self.strategy, ColossalAIStrategy):
        return Actor(self.strategy._unwrap_model(self.actor))
    if isinstance(self.strategy, DDPStrategy):
        return Actor(self.strategy._unwrap_actor(self.actor))
    if isinstance(self.strategy, NaiveStrategy):
        return self.actor
    # NOTE(review): unknown strategies fall through to an implicit None,
    # preserving the original behavior — consider raising instead.

def _get_unwrapped_critic(self):
    """Return the critic with any strategy-specific wrapper removed.

    Per-strategy behavior (ColossalAIStrategy checked before DDPStrategy,
    matching the original dispatch):
      * ColossalAIStrategy: strip the wrapper via ``_unwrap_model``.
      * DDPStrategy: the bare module lives in ``.module``.
      * NaiveStrategy: the critic is not wrapped; return it as-is.
    """
    # Removed dead code: the original began with an unreachable
    # ``if False: pass`` branch.
    if isinstance(self.strategy, ColossalAIStrategy):
        return self.strategy._unwrap_model(self.critic)
    if isinstance(self.strategy, DDPStrategy):
        return self.critic.module
    if isinstance(self.strategy, NaiveStrategy):
        return self.critic
    # NOTE(review): unknown strategies fall through to an implicit None,
    # preserving the original behavior — consider raising instead.

def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
# try:
# self.strategy.merge_lora_weight(model)
Expand Down
50 changes: 38 additions & 12 deletions applications/Chat/coati/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,6 @@ def set_dist_env(env_info: Dict[str, str]):
os.environ['MASTER_ADDR'] = env_info['master_addr']


def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
keep state_dict intact
'''
new_state_dict = {}
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
return new_state_dict


def get_model_numel(model: nn.Module) -> int:
numel = sum(p.numel() for p in model.parameters())
return numel
Expand All @@ -150,3 +138,41 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
# a receiver may have more than one sender
target_receivers.append(sender_idx % num_receivers)
return target_receivers


def state_dict_to(state_dict: Dict[str, Any],
                  dtype: torch.dtype = torch.float16,
                  device: torch.device = torch.device('cpu')):
    '''
    Return a copy of ``state_dict`` whose tensors are converted to
    ``dtype`` on ``device``. The input mapping is left intact; key
    order is preserved in the returned ``OrderedDict``.
    '''
    return OrderedDict(
        (name, tensor.to(dtype=dtype, device=device))
        for name, tensor in state_dict.items())


def state_dict_filter_lora(state_dict: Dict[str, Any], keep_non_lora = False):
    '''
    Split ``state_dict`` into LoRA entries (keys containing ``lora_A`` or
    ``lora_B``) and, optionally, the remaining non-LoRA entries.

    Args:
        state_dict: model state dict to filter.
        keep_non_lora: if True, also collect the non-LoRA entries and
            return both mappings.

    Returns:
        ``state_dict_lora`` when ``keep_non_lora`` is False, otherwise the
        tuple ``(state_dict_lora, state_dict_non_lora)``.
    '''
    state_dict_lora = OrderedDict()
    state_dict_non_lora = OrderedDict()
    # Bug fix: iterating a dict directly yields keys only, so the original
    # ``for k, v in state_dict`` raised on unpacking; .items() is required.
    for k, v in state_dict.items():
        if 'lora_A' in k or 'lora_B' in k:
            state_dict_lora[k] = v
        elif keep_non_lora:
            state_dict_non_lora[k] = v
    if keep_non_lora:
        return state_dict_lora, state_dict_non_lora
    else:
        return state_dict_lora


def state_dict_lora_reconstruct(state_dict_lora: Dict[str, Any]):
    '''
    xxx.lora_A, xxx.lora_B -->> xxx.weight

    NOTE(review): unfinished stub — it allocates an empty OrderedDict but
    never populates or returns it, so callers currently receive None.
    Reconstructing ``xxx.weight`` from a LoRA (A, B) pair presumably also
    needs the base weight and the LoRA scaling factor, which are not in
    ``state_dict_lora`` — confirm the intended contract before implementing.
    '''
    # TODO: merge each lora_A/lora_B pair into the corresponding .weight
    # entry and return the result.
    state_dict_reconstruct = OrderedDict()


8 changes: 4 additions & 4 deletions applications/Chat/coati/trainer/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ def unwrap_model(model: nn.Module) -> nn.Module:
"""
return get_base_model(model)

@staticmethod
def _unwrap_critic(critic: Critic) -> nn.Module:
return Strategy._unwrap_model(critic)

@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
pass
Expand All @@ -134,3 +130,7 @@ def save_pretrained(self,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
pass

@abstractmethod
def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
24 changes: 12 additions & 12 deletions applications/Chat/coati/trainer/strategies/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,6 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
torch.save(optimizer.state_dict(), path)

def get_model_state_dict_shard(self, model: nn.Module, **config):
if self.stage != 3:
yield from super().get_model_state_dict_shard(model, **config)
else:
# unwrapped_model = self._unwrap_model(model)
# for module in unwrapped_model.modules():
# if isinstance(module, LoraLinear):
# module.merge_weights = True
# module.eval()
model: ZeroDDP = model
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
if self.stage == 3:
Expand All @@ -198,3 +186,15 @@ def save_pretrained(self,
if self.stage == 3:
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
super().save_pretrained(model, path, only_rank0, tokenizer)

def get_model_state_dict_shard(self, model: nn.Module, **config):
    """Yield the model's state dict in shards for streaming transfer.

    For ZeRO stages other than 3, defer to the parent strategy's generic
    sharding. For stage 3 the parameters are already partitioned across
    ranks, so shards come from ``ZeroDDP.state_dict_shard`` on the base
    model directly.

    Args:
        model: the (possibly wrapped) model to shard.
        **config: forwarded to the parent implementation (stage != 3 only).

    Yields:
        State-dict shards as produced by the parent strategy or ZeroDDP.
    """
    if self.stage != 3:
        yield from super().get_model_state_dict_shard(model, **config)
    else:
        # NOTE(review): leftover LoRA weight-merge experiment, kept for
        # reference — remove once LoRA reconstruction is settled.
        # unwrapped_model = self._unwrap_model(model)
        # for module in unwrapped_model.modules():
        #     if isinstance(module, LoraLinear):
        #         module.merge_weights = True
        #         module.eval()
        base_model: ZeroDDP = get_base_model(model)
        # max_shard_size=1024 — unit not shown here (presumably MB); TODO
        # confirm against ZeroDDP.state_dict_shard. only_rank_0=False so
        # every rank yields its own shards.
        yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
26 changes: 15 additions & 11 deletions applications/Chat/coati/trainer/strategies/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import torch.optim as optim
from coati.models.base import get_base_model
from coati.replay_buffer import ReplayBuffer
from coati.models.base import RewardModel
from coati.models.lora import LoraLinear
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -71,8 +74,20 @@ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = No
state_dict = torch.load(path, map_location=map_location)
optimizer.load_state_dict(state_dict)

def save_pretrained(self,
                    model: nn.Module,
                    path: str,
                    only_rank0: bool = True,
                    tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    """Save the underlying HF model (and the tokenizer, if given) to ``path``
    in the ``transformers`` pretrained format.

    The model must unwrap to a ``PreTrainedModel``; anything else fails the
    assertion below.
    """
    base = self.unwrap_model(model)
    assert isinstance(base, PreTrainedModel)
    base.save_pretrained(path)
    if tokenizer is None:
        return
    tokenizer.save_pretrained(path)

def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
state_dict = get_grad_required_state_dict(model)
else:
Expand Down Expand Up @@ -111,14 +126,3 @@ def _try_init_dist(self, force: bool = False) -> None:
except Exception as e:
if force:
raise e

def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
1 change: 1 addition & 0 deletions applications/Chat/examples/ray/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
logs/*
39 changes: 39 additions & 0 deletions applications/Chat/examples/ray/benchmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

# Benchmark sweep for the detached-PPO ray example: runs mmmt_prompt.py over
# a grid of OPT model sizes and batch/step settings, one nohup'd run per
# configuration, with stdout+stderr captured under logs/.
PROMPT_PATH=/home/lccsr/data3/awesome-chatgpt-prompts/prompts.csv

# Fixed topology for every run: 4 trainer workers, 4 experience makers.
num_trainers=4
num_makers=4

# "facebook/opt-2.7b"
for pretrain in "facebook/opt-1.3b" "facebook/opt-6.7b" "facebook/opt-13b"
do

for experience_batch_size in 16 32 64
do
for train_batch_size in 16 32 64
do
for update_steps in 8 32 128
do
# set a big enough experience_steps for twice maker-update
experience_steps=$((2*num_trainers*train_batch_size*update_steps/num_makers/experience_batch_size))

# Unique tag encoding the full configuration; used for the log file name.
config_string=${num_trainers}_${num_makers}_pretrain_${pretrain##*/}_experience_batch_size_${experience_batch_size}_train_batch_size_${train_batch_size}_update_steps_${update_steps}_experience_steps_${experience_steps}
echo running: ${config_string}

# NOTE(review): runs sequentially (no trailing &, despite nohup); assumes
# logs/ already exists — see the sibling .gitignore entry.
nohup python mmmt_prompt.py \
--prompt_path $PROMPT_PATH \
--trainer_strategy colossalai_gemini --maker_strategy naive \
--model 'opt' \
--pretrain $pretrain \
--critic_pretrain "facebook/opt-350m" \
--num_trainers $num_trainers \
--num_makers $num_makers \
--experience_steps $experience_steps \
--experience_batch_size $experience_batch_size \
--update_steps $update_steps \
--train_batch_size $train_batch_size \
--debug > logs/output_${config_string}.txt 2>&1
done
done
done
done
4 changes: 2 additions & 2 deletions applications/Chat/examples/ray/mmmt_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def model_fn():
]

def trainer_model_fn():
actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
return actor, critic

# configure Trainer
Expand Down