From 73a25e84b6c7dd3178c45a51e7d22147cc1b6676 Mon Sep 17 00:00:00 2001 From: csric Date: Tue, 9 May 2023 16:23:16 +0800 Subject: [PATCH 1/3] lora support prototype --- .../Chat/coati/ray/detached_trainer_ppo.py | 28 +++- .../Chat/coati/ray/experience_maker_holder.py | 23 ++- .../Chat/coati/ray/lora_constructor.py | 138 ++++++++++++++++++ applications/Chat/coati/ray/utils.py | 27 ---- 4 files changed, 177 insertions(+), 39 deletions(-) create mode 100644 applications/Chat/coati/ray/lora_constructor.py diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py index d3dfc6e93a46..5d84b58ca0e3 100644 --- a/applications/Chat/coati/ray/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/detached_trainer_ppo.py @@ -13,6 +13,7 @@ from .callbacks import TrainerCallback, TrainerPerformanceEvaluator from .detached_trainer_base import DetachedTrainer +from .lora_constructor import LoRAConstructor from .utils import ( get_actor_from_args, get_critic_from_args, @@ -67,6 +68,7 @@ def __init__( callbacks: List[TrainerCallback] = [], eval_performance: bool = False, debug: bool = False, + update_lora_weights: bool = False, ) -> None: # set environment variables if env_info: @@ -106,6 +108,8 @@ def __init__( if self._debug: print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}') + self._update_lora_weights = update_lora_weights + @ray.method(concurrency_group="model_io") @torch.no_grad() def _update_remote_makers(self, fully_update: bool = False, **config): @@ -121,16 +125,18 @@ def _update_remote_makers(self, fully_update: bool = False, **config): # sending loop tasks = [] - for state_dict_shard in self._get_model_state_dict_shard(self.actor, **config): + for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update = fully_update, **config): for target_holder in self.target_holder_list: tasks.append( target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard, + new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor), fully_update=fully_update)) # sending loop - for state_dict_shard in self._get_model_state_dict_shard(self.critic, **config): + for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update = fully_update, **config): for target_holder in self.target_holder_list: tasks.append( target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard, + new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic), fully_update=fully_update)) ray.get(tasks) # mark end @@ -177,10 +183,16 @@ def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None: self.strategy.save_optimizer(self.critic_optim, path, only_rank0) - def _get_model_state_dict_shard(self, model: torch.nn.Module, **config): - # try: - # self.strategy.merge_lora_weight(model) - # except AttributeError: - # pass + def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update = False, **config): for state_dict in self.strategy.get_model_state_dict_shard(model, **config): - yield state_dict_to(state_dict) + if not self._update_lora_weights or fully_update: + yield state_dict_to(state_dict) + else: + state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict) + yield state_dict_to(state_dict_lora) + + def _get_model_lora_config_dict(self, model: torch.nn.Module): + if not self._update_lora_weights: + return None + unwrapped_model = self.strategy.unwrap_model(model) + return LoRAConstructor.extract_lora_config(unwrapped_model) \ No newline at end of file diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py index 573771ad6258..52deb5180b60 100644 --- a/applications/Chat/coati/ray/experience_maker_holder.py +++ b/applications/Chat/coati/ray/experience_maker_holder.py @@ -20,7 +20,7 @@ from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env - +from .lora_constructor import LoRAConstructor @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) class ExperienceMakerHolder: @@ -45,6 +45,7 @@ def __init__( callbacks: List[MakerCallback] = [], eval_performance: bool = False, debug: bool = False, + update_lora_weights: bool = False, **generate_kwargs): # set environment variables if env_info: @@ -77,6 +78,11 @@ def __init__( self._is_fully_initialized = not sync_models_from_trainers self._debug = debug + self._update_lora_weights = update_lora_weights + if self._update_lora_weights: + self.actor_lora_constructor = LoRAConstructor() + self.critic_lora_constructor = LoRAConstructor() + self.target_auto_balance = False self._target_idx = 0 @@ -166,7 +172,9 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 @ray.method(concurrency_group="model_io") def update_experience_maker(self, new_actor_state_dict: Dict[str, Any] = None, + new_actor_lora_config_dict: Dict[str, Any] = None, new_critic_state_dict: Dict[str, Any] = None, + new_critic_lora_config_dict: Dict[str, Any] = None, fully_update: bool = False, chunk_start: bool = None, chunk_end: bool = None): @@ -188,10 +196,17 @@ def update_experience_maker(self, with torch.no_grad(): if new_actor_state_dict is not None: - self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False) + if not self._update_lora_weights or fully_update: + self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False) + else: + state_dict_increasae = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict) + self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increasae) if new_critic_state_dict is not None: - self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) - + if not self._update_lora_weights or fully_update: + self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) + else: + state_dict_increasae = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict) + self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increasae) # the lock must be released after both actor and critic being updated if chunk_end: diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py new file mode 100644 index 000000000000..bc42b5547df6 --- /dev/null +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -0,0 +1,138 @@ +from typing import Any, Callable, Dict, List, Optional +from collections import OrderedDict +from dataclasses import dataclass + +import torch +import torch.nn as nn +from loralib.layers import LoRALayer +from coati.models.lora import LoraLinear + + +@dataclass +class LoRAConfig: + r: int = 0 + lora_alpha: int = 1 + lora_dropout: float = 0 + fan_in_fan_out: bool = False + + +class LoRAConstructor: + ''' + Tools for reconstructing a model from a remote LoRA model. + (Transfering only LoRA data costs much less!) + Usage: + Step 1 (Sender): + filter_state_dict_lora() + + Step 2 (Sender, Optional): + extract_lora_config() + + Step 3 (Sender): + send state_dict_lora and lora_config_dict + + Step 4 (Receiver): + reconstruct_increase() + + Step 5 (Receiver): + load_state_dict_increase() + + ''' + + def __init__(self): + self.lora_config_dict = None + + def register_lora_config(self, lora_config_dict: Dict[str, Any]): + self.lora_config_dict = lora_config_dict + + def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any] = None): + ''' + xxx.lora_A, xxx.lora_B -->> xxx.weight_increase + ''' + if lora_config_dict is not None: + self.register_lora_config(lora_config_dict) + + state_dict_increasae = OrderedDict() + if self.lora_config_dict is None: + # default configuration + lora_A, lora_B, layer_prefix = None, None, None + for k, v in state_dict_lora.items(): + if k.rpartition('.')[-1] == 'lora_A': + lora_A = v + layer_prefix = k.rpartition('.')[0] + elif k.rpartition('.')[-1] == 'lora_B': + assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" + lora_B = v + weight_data_increase = self._compute(lora_A, lora_B) + state_dict_increasae[layer_prefix + '.weight_increase'] = weight_data_increase + lora_A, lora_B, layer_prefix = None, None, None + else: + raise ValueError('unexpected key') + else: + # per layer configuration + config_iter = iter(self.lora_config_dict.items()) + lora_A, lora_B, layer_prefix = None, None, None + for k, v in state_dict_lora.items(): + if k.rpartition('.')[-1] == 'lora_A': + lora_A = v + layer_prefix = k.rpartition('.')[0] + elif k.rpartition('.')[-1] == 'lora_B': + assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" + layer_prefix_2, config = next(config_iter) + assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" + lora_B = v + weight_data_increase = self._compute(lora_A, lora_B, config) + state_dict_increasae[layer_prefix + '.weight_increase'] = weight_data_increase + lora_A, lora_B, layer_prefix = None, None, None + else: + raise ValueError('unexpected key') + return state_dict_increasae + + def _compute(self, lora_A, lora_B, config=LoRAConfig()): + def T(w): + return w.T if config.fan_in_fan_out else w + if config.r > 0: + scaling = config.lora_alpha / config.r + weight_data_increase = T(lora_B @ lora_A) * scaling + return weight_data_increase + return 0 + + def load_state_dict_increase(self, model: nn.Module, state_dict_increasae: Dict[str, Any]): + ''' + The final reconstruction step + ''' + # naive approach + model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increasae.items()}, strict=False) + + @staticmethod + def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): + ''' + if keep_non_lora, also return non_lora state_dict + ''' + state_dict_lora = OrderedDict() + state_dict_non_lora = OrderedDict() + for k, v in state_dict: + if 'lora_A' in k or 'lora_B' in k: + state_dict_lora[k] = v + elif keep_non_lora: + state_dict_non_lora[k] = v + if keep_non_lora: + return state_dict_lora, state_dict_non_lora + else: + return state_dict_lora, None + + @staticmethod + def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: + ''' + extract LoraLinear model. + return OrderedDict(): name -> LoRAConfig + ''' + lora_config_dict = OrderedDict() + + for name, child in model.named_modules: + if isinstance(child, LoraLinear): + lora_config_dict[name] = LoRAConfig(r=child.r, + lora_alpha=child.lora_alpha, + lora_dropout=child.lora_dropout, + fan_in_fan_out=child.fan_in_fan_out) + + return lora_config_dict diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index 7e36a6b08589..4361ee236771 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -150,30 +150,3 @@ def state_dict_to(state_dict: Dict[str, Any], for k, v in state_dict.items(): new_state_dict[k] = v.to(dtype=dtype, device=device) return new_state_dict - - -def state_dict_filter_lora(state_dict: Dict[str, Any], keep_non_lora = False): - ''' - if keep_non_lora, also return non_lora state_dict - ''' - state_dict_lora = OrderedDict() - state_dict_non_lora = OrderedDict() - for k, v in state_dict: - if 'lora_A' in k or 'lora_B' in k: - state_dict_lora[k] = v - elif keep_non_lora: - state_dict_non_lora[k] = v - if keep_non_lora: - return state_dict_lora, state_dict_non_lora - else: - return state_dict_lora - - -def state_dict_lora_reconstruct(state_dict_lora: Dict[str, Any]): - ''' - xxx.lora_A, xxx.lora_B -->> xxx.weight - TODO - ''' - state_dict_reconstruct = OrderedDict() - - \ No newline at end of file From a6cd2f8914f886fb83e0ba675cab7160e61dc023 Mon Sep 17 00:00:00 2001 From: csric Date: Tue, 9 May 2023 17:09:40 +0800 Subject: [PATCH 2/3] lora support --- .../Chat/coati/ray/experience_maker_holder.py | 9 ++++++++- applications/Chat/coati/ray/lora_constructor.py | 11 ++++++----- applications/Chat/examples/ray/mmmt_prompt.py | 5 ++++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py index 52deb5180b60..8551ef1eacef 100644 --- a/applications/Chat/coati/ray/experience_maker_holder.py +++ b/applications/Chat/coati/ray/experience_maker_holder.py @@ -19,7 +19,12 @@ from tqdm import tqdm from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback -from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env +from .utils import (get_model_numel, + get_rank, + get_world_size, + is_rank_0, + set_dist_env, + state_dict_to) from .lora_constructor import LoRAConstructor @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) @@ -199,12 +204,14 @@ def update_experience_maker(self, if not self._update_lora_weights or fully_update: self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False) else: + new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device()) state_dict_increasae = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict) self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increasae) if new_critic_state_dict is not None: if not self._update_lora_weights or fully_update: self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) else: + new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device()) state_dict_increasae = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict) self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increasae) diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py index bc42b5547df6..5d5f0e662d26 100644 --- a/applications/Chat/coati/ray/lora_constructor.py +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -46,7 +46,8 @@ def register_lora_config(self, lora_config_dict: Dict[str, Any]): def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any] = None): ''' - xxx.lora_A, xxx.lora_B -->> xxx.weight_increase + xxx.lora_A, xxx.lora_B -->> xxx.weight + Warning: the xxx.weight here is the increment actually. ''' if lora_config_dict is not None: self.register_lora_config(lora_config_dict) @@ -63,7 +64,7 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" lora_B = v weight_data_increase = self._compute(lora_A, lora_B) - state_dict_increasae[layer_prefix + '.weight_increase'] = weight_data_increase + state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase lora_A, lora_B, layer_prefix = None, None, None else: raise ValueError('unexpected key') @@ -81,7 +82,7 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" lora_B = v weight_data_increase = self._compute(lora_A, lora_B, config) - state_dict_increasae[layer_prefix + '.weight_increase'] = weight_data_increase + state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase lora_A, lora_B, layer_prefix = None, None, None else: raise ValueError('unexpected key') @@ -110,7 +111,7 @@ def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): ''' state_dict_lora = OrderedDict() state_dict_non_lora = OrderedDict() - for k, v in state_dict: + for k, v in state_dict.items(): if 'lora_A' in k or 'lora_B' in k: state_dict_lora[k] = v elif keep_non_lora: @@ -128,7 +129,7 @@ def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: ''' lora_config_dict = OrderedDict() - for name, child in model.named_modules: + for name, child in model.named_modules(): if isinstance(child, LoraLinear): lora_config_dict[name] = LoRAConfig(r=child.r, lora_alpha=child.lora_alpha, diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py index 6f43d8950758..941e080ebb5e 100644 --- a/applications/Chat/examples/ray/mmmt_prompt.py +++ b/applications/Chat/examples/ray/mmmt_prompt.py @@ -20,7 +20,6 @@ from transformers import AutoConfig, AutoTokenizer from transformers.modeling_utils import no_init_weights - def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(('', 0)) @@ -86,6 +85,7 @@ def model_fn(): env_info=env_info_maker, kl_coef=0.1, debug=args.debug, + update_lora_weights = not (args.lora_rank == 0), # sync_models_from_trainers=True, # generation kwargs: max_length=512, @@ -119,6 +119,7 @@ def trainer_model_fn(): buffer_limit=16, eval_performance=True, debug=args.debug, + update_lora_weights = not (args.lora_rank == 0), ) for i, env_info_trainer in enumerate(env_info_trainers) ] @@ -156,6 +157,7 @@ def tokenize_fn(texts): for trainer_ref in trainer_refs: wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) + ray.get(wait_tasks) @@ -187,5 +189,6 @@ def tokenize_fn(texts): parser.add_argument('--quant_group_size', type=int, default=128) parser.add_argument('--debug', action='store_true') args = parser.parse_args() + ray.init(namespace=os.environ["RAY_NAMESPACE"]) main(args) From f0022dd57a1acb88c14dfb362cfa51827b5786f0 Mon Sep 17 00:00:00 2001 From: csric Date: Tue, 9 May 2023 19:05:35 +0800 Subject: [PATCH 3/3] 1mmt lora & remove useless code --- .../Chat/coati/ray/lora_constructor.py | 51 +++++++------------ applications/Chat/examples/ray/1mmt_prompt.py | 2 + 2 files changed, 19 insertions(+), 34 deletions(-) diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py index 5d5f0e662d26..599a58248728 100644 --- a/applications/Chat/coati/ray/lora_constructor.py +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -44,7 +44,7 @@ def __init__(self): def register_lora_config(self, lora_config_dict: Dict[str, Any]): self.lora_config_dict = lora_config_dict - def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any] = None): + def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]): ''' xxx.lora_A, xxx.lora_B -->> xxx.weight Warning: the xxx.weight here is the increment actually. @@ -53,39 +53,22 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict self.register_lora_config(lora_config_dict) state_dict_increasae = OrderedDict() - if self.lora_config_dict is None: - # default configuration - lora_A, lora_B, layer_prefix = None, None, None - for k, v in state_dict_lora.items(): - if k.rpartition('.')[-1] == 'lora_A': - lora_A = v - layer_prefix = k.rpartition('.')[0] - elif k.rpartition('.')[-1] == 'lora_B': - assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" - lora_B = v - weight_data_increase = self._compute(lora_A, lora_B) - state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase - lora_A, lora_B, layer_prefix = None, None, None - else: - raise ValueError('unexpected key') - else: - # per layer configuration - config_iter = iter(self.lora_config_dict.items()) - lora_A, lora_B, layer_prefix = None, None, None - for k, v in state_dict_lora.items(): - if k.rpartition('.')[-1] == 'lora_A': - lora_A = v - layer_prefix = k.rpartition('.')[0] - elif k.rpartition('.')[-1] == 'lora_B': - assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" - layer_prefix_2, config = next(config_iter) - assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" - lora_B = v - weight_data_increase = self._compute(lora_A, lora_B, config) - state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase - lora_A, lora_B, layer_prefix = None, None, None - else: - raise ValueError('unexpected key') + config_iter = iter(self.lora_config_dict.items()) + lora_A, lora_B, layer_prefix = None, None, None + for k, v in state_dict_lora.items(): + if k.rpartition('.')[-1] == 'lora_A': + lora_A = v + layer_prefix = k.rpartition('.')[0] + elif k.rpartition('.')[-1] == 'lora_B': + assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" + layer_prefix_2, config = next(config_iter) + assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" + lora_B = v + weight_data_increase = self._compute(lora_A, lora_B, config) + state_dict_increasae[layer_prefix + '.weight'] = weight_data_increase + lora_A, lora_B, layer_prefix = None, None, None + else: + raise ValueError('unexpected key') return state_dict_increasae def _compute(self, lora_A, lora_B, config=LoRAConfig()): diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py index bd7224aae749..06e522962f0e 100644 --- a/applications/Chat/examples/ray/1mmt_prompt.py +++ b/applications/Chat/examples/ray/1mmt_prompt.py @@ -73,6 +73,7 @@ def trainer_model_fn(): buffer_limit=16, eval_performance=True, debug=args.debug, + update_lora_weights = not (args.lora_rank == 0), ) for i, env_info_trainer in enumerate(env_info_trainers) ] @@ -100,6 +101,7 @@ def model_fn(): experience_batch_size=args.experience_batch_size, kl_coef=0.1, debug=args.debug, + update_lora_weights = not (args.lora_rank == 0), # sync_models_from_trainers=True, # generation kwargs: max_length=512,