28 changes: 20 additions & 8 deletions applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -13,6 +13,7 @@

from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
+from .lora_constructor import LoRAConstructor
from .utils import (
    get_actor_from_args,
    get_critic_from_args,
@@ -67,6 +68,7 @@ def __init__(
            callbacks: List[TrainerCallback] = [],
            eval_performance: bool = False,
            debug: bool = False,
+            update_lora_weights: bool = False,
    ) -> None:
        # set environment variables
        if env_info:
@@ -106,6 +108,8 @@ def __init__(
        if self._debug:
            print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')

+        self._update_lora_weights = update_lora_weights
+
    @ray.method(concurrency_group="model_io")
    @torch.no_grad()
    def _update_remote_makers(self, fully_update: bool = False, **config):
@@ -121,16 +125,18 @@ def _update_remote_makers(self, fully_update: bool = False, **config):
        # sending loop
        tasks = []

-        for state_dict_shard in self._get_model_state_dict_shard(self.actor, **config):
+        for state_dict_shard in self._get_model_state_dict_shard(self.actor, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
+                                                                  new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
                                                                  fully_update=fully_update))
        # sending loop
-        for state_dict_shard in self._get_model_state_dict_shard(self.critic, **config):
+        for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
            for target_holder in self.target_holder_list:
                tasks.append(
                    target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
+                                                                  new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
                                                                  fully_update=fully_update))
        ray.get(tasks)
        # mark end
@@ -177,10 +183,16 @@ def strategy_save_actor_optim(self, path: str, only_rank0: bool = False) -> None
    def strategy_save_critic_optim(self, path: str, only_rank0: bool = False) -> None:
        self.strategy.save_optimizer(self.critic_optim, path, only_rank0)

-    def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
-        # try:
-        #     self.strategy.merge_lora_weight(model)
-        # except AttributeError:
-        #     pass
+    def _get_model_state_dict_shard(self, model: torch.nn.Module, fully_update: bool = False, **config):
        for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
-            yield state_dict_to(state_dict)
+            if not self._update_lora_weights or fully_update:
+                yield state_dict_to(state_dict)
+            else:
+                state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(state_dict)
+                yield state_dict_to(state_dict_lora)
+
+    def _get_model_lora_config_dict(self, model: torch.nn.Module):
+        if not self._update_lora_weights:
+            return None
+        unwrapped_model = self.strategy.unwrap_model(model)
+        return LoRAConstructor.extract_lora_config(unwrapped_model)
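
In LoRA mode (update_lora_weights set and fully_update off), each shard the trainer ships is first reduced to its lora_A/lora_B tensors. A minimal sketch of that filtering step, with hypothetical layer names, shapes, and rank:

from collections import OrderedDict

import torch

from coati.ray.lora_constructor import LoRAConstructor

# A toy shard: one layer's full weight plus its LoRA factors (names and shapes made up).
shard = OrderedDict([
    ('transformer.h.0.attn.weight', torch.randn(64, 64)),
    ('transformer.h.0.attn.lora_A', torch.randn(4, 64)),
    ('transformer.h.0.attn.lora_B', torch.randn(64, 4)),
])

# In LoRA mode, _get_model_state_dict_shard yields only the LoRA entries.
lora_only, _ = LoRAConstructor.filter_state_dict_lora(shard)
print(list(lora_only))    # ['transformer.h.0.attn.lora_A', 'transformer.h.0.attn.lora_B']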
32 changes: 27 additions & 5 deletions applications/Chat/coati/ray/experience_maker_holder.py
@@ -19,8 +19,13 @@
from tqdm import tqdm

from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
-from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env
-
+from .utils import (get_model_numel,
+                    get_rank,
+                    get_world_size,
+                    is_rank_0,
+                    set_dist_env,
+                    state_dict_to)
+from .lora_constructor import LoRAConstructor

@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
@@ -45,6 +50,7 @@ def __init__(
            callbacks: List[MakerCallback] = [],
            eval_performance: bool = False,
            debug: bool = False,
+            update_lora_weights: bool = False,
            **generate_kwargs):
        # set environment variables
        if env_info:
@@ -77,6 +83,11 @@ def __init__(
        self._is_fully_initialized = not sync_models_from_trainers

        self._debug = debug
+        self._update_lora_weights = update_lora_weights
+        if self._update_lora_weights:
+            self.actor_lora_constructor = LoRAConstructor()
+            self.critic_lora_constructor = LoRAConstructor()
+
        self.target_auto_balance = False

        self._target_idx = 0
@@ -166,7 +177,9 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1
    @ray.method(concurrency_group="model_io")
    def update_experience_maker(self,
                                new_actor_state_dict: Dict[str, Any] = None,
+                                new_actor_lora_config_dict: Dict[str, Any] = None,
                                new_critic_state_dict: Dict[str, Any] = None,
+                                new_critic_lora_config_dict: Dict[str, Any] = None,
                                fully_update: bool = False,
                                chunk_start: bool = None,
                                chunk_end: bool = None):
@@ -188,10 +201,19 @@ def update_experience_maker(self,

        with torch.no_grad():
            if new_actor_state_dict is not None:
-                self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
+                if not self._update_lora_weights or fully_update:
+                    self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
+                else:
+                    new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
+                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(new_actor_state_dict, new_actor_lora_config_dict)
+                    self.actor_lora_constructor.load_state_dict_increase(self.experience_maker.actor.model, state_dict_increase)
            if new_critic_state_dict is not None:
-                self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
-
+                if not self._update_lora_weights or fully_update:
+                    self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
+                else:
+                    new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
+                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(new_critic_state_dict, new_critic_lora_config_dict)
+                    self.critic_lora_constructor.load_state_dict_increase(self.experience_maker.critic, state_dict_increase)

        # the lock must be released only after both the actor and the critic have been updated
        if chunk_end:
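At tensor level, the LoRA branch above reduces to one matrix product per layer. A self-contained sketch of what reconstruct_increase and load_state_dict_increase do for a single layer (shapes, rank, and alpha are made up):

import torch

d_in, d_out, r, lora_alpha = 64, 64, 4, 32

lora_A = torch.randn(r, d_in)        # received from the trainer
lora_B = torch.randn(d_out, r)       # received from the trainer
weight = torch.randn(d_out, d_in)    # the maker's current weight for this layer

# reconstruct_increase: expand the low-rank factors into a full-size increment
scaling = lora_alpha / r
weight_increase = (lora_B @ lora_A) * scaling

# load_state_dict_increase: fold the increment into the current weight
weight = weight + weight_increase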
122 changes: 122 additions & 0 deletions applications/Chat/coati/ray/lora_constructor.py
@@ -0,0 +1,122 @@
from typing import Any, Dict
from collections import OrderedDict
from dataclasses import dataclass

import torch
import torch.nn as nn
from coati.models.lora import LoraLinear


@dataclass
class LoRAConfig:
    r: int = 0
    lora_alpha: int = 1
    lora_dropout: float = 0
    fan_in_fan_out: bool = False


class LoRAConstructor:
    '''
    Tools for reconstructing a model update from a remote model's LoRA weights.
    (Transferring only the LoRA weights costs much less!)
    Usage:
        Step 1 (Sender): filter_state_dict_lora()
        Step 2 (Sender, Optional): extract_lora_config()
        Step 3 (Sender): send state_dict_lora and lora_config_dict
        Step 4 (Receiver): reconstruct_increase()
        Step 5 (Receiver): load_state_dict_increase()
    '''

    def __init__(self):
        self.lora_config_dict = None

    def register_lora_config(self, lora_config_dict: Dict[str, Any]):
        self.lora_config_dict = lora_config_dict

    def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
        '''
        xxx.lora_A, xxx.lora_B -->> xxx.weight
        Warning: the resulting xxx.weight is the weight increment, not the full weight.
        '''
        if lora_config_dict is not None:
            self.register_lora_config(lora_config_dict)

        state_dict_increase = OrderedDict()
        config_iter = iter(self.lora_config_dict.items())
        lora_A, lora_B, layer_prefix = None, None, None
        for k, v in state_dict_lora.items():
            if k.rpartition('.')[-1] == 'lora_A':
                lora_A = v
                layer_prefix = k.rpartition('.')[0]
            elif k.rpartition('.')[-1] == 'lora_B':
                assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
                layer_prefix_2, config = next(config_iter)
                assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
                lora_B = v
                weight_data_increase = self._compute(lora_A, lora_B, config)
                state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
                lora_A, lora_B, layer_prefix = None, None, None
            else:
                raise ValueError(f'unexpected key: {k}')
        return state_dict_increase

    def _compute(self, lora_A, lora_B, config=LoRAConfig()):

        def T(w):
            return w.T if config.fan_in_fan_out else w

        if config.r > 0:
            scaling = config.lora_alpha / config.r
            weight_data_increase = T(lora_B @ lora_A) * scaling
            return weight_data_increase
        return 0

    def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
        '''
        The final reconstruction step: add each increment to the model's current weight.
        '''
        # naive approach: read the current weight, add the increment and reload
        model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)

    @staticmethod
    def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
        '''
        If keep_non_lora is True, also return the non-LoRA part of the state dict.
        '''
        state_dict_lora = OrderedDict()
        state_dict_non_lora = OrderedDict()
        for k, v in state_dict.items():
            if 'lora_A' in k or 'lora_B' in k:
                state_dict_lora[k] = v
            elif keep_non_lora:
                state_dict_non_lora[k] = v
        if keep_non_lora:
            return state_dict_lora, state_dict_non_lora
        else:
            return state_dict_lora, None

    @staticmethod
    def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
        '''
        Extract the LoRA config of every LoraLinear module.
        Returns an OrderedDict mapping module name -> LoRAConfig.
        '''
        lora_config_dict = OrderedDict()

        for name, child in model.named_modules():
            if isinstance(child, LoraLinear):
                lora_config_dict[name] = LoRAConfig(r=child.r,
                                                    lora_alpha=child.lora_alpha,
                                                    lora_dropout=child.lora_dropout,
                                                    fan_in_fan_out=child.fan_in_fan_out)

        return lora_config_dict
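
Putting the docstring's five steps together: a runnable sketch of one sender/receiver round trip at the state-dict level. The layer name fc, the shapes, and the hand-built config dict are illustrative; a real sender would obtain the config via extract_lora_config on a model containing LoraLinear modules.

from collections import OrderedDict

import torch
import torch.nn as nn

from coati.ray.lora_constructor import LoRAConfig, LoRAConstructor

d, r, alpha = 8, 2, 16

# Receiver side: a single plain linear layer standing in for the maker's model.
model = nn.Module()
model.fc = nn.Linear(d, d, bias=False)

# Step 1 (sender): keep only the LoRA tensors of a full state dict.
full_state_dict = OrderedDict([
    ('fc.weight', torch.randn(d, d)),
    ('fc.lora_A', torch.randn(r, d)),
    ('fc.lora_B', torch.randn(d, r)),
])
state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(full_state_dict)

# Step 2 (sender): per-layer LoRA config, hand-built here instead of extract_lora_config.
lora_config_dict = OrderedDict([('fc', LoRAConfig(r=r, lora_alpha=alpha))])

# Steps 4-5 (receiver): rebuild the increments and fold them into the current weights.
constructor = LoRAConstructor()
state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config_dict)
constructor.load_state_dict_increase(model, state_dict_increase)

The payload savings motivating all of this: for a d_out x d_in layer, the sender ships r * (d_in + d_out) numbers instead of d_out * d_in; at rank 8, a 4096 x 4096 weight shrinks from 16,777,216 to 65,536 values, roughly 256x less traffic.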
27 changes: 0 additions & 27 deletions applications/Chat/coati/ray/utils.py
@@ -150,30 +150,3 @@ def state_dict_to(state_dict: Dict[str, Any],
    for k, v in state_dict.items():
        new_state_dict[k] = v.to(dtype=dtype, device=device)
    return new_state_dict
-
-
-def state_dict_filter_lora(state_dict: Dict[str, Any], keep_non_lora = False):
-    '''
-    if keep_non_lora, also return non_lora state_dict
-    '''
-    state_dict_lora = OrderedDict()
-    state_dict_non_lora = OrderedDict()
-    for k, v in state_dict:
-        if 'lora_A' in k or 'lora_B' in k:
-            state_dict_lora[k] = v
-        elif keep_non_lora:
-            state_dict_non_lora[k] = v
-    if keep_non_lora:
-        return state_dict_lora, state_dict_non_lora
-    else:
-        return state_dict_lora
-
-
-def state_dict_lora_reconstruct(state_dict_lora: Dict[str, Any]):
-    '''
-    xxx.lora_A, xxx.lora_B -->> xxx.weight
-    TODO
-    '''
-    state_dict_reconstruct = OrderedDict()
2 changes: 2 additions & 0 deletions applications/Chat/examples/ray/1mmt_prompt.py
@@ -73,6 +73,7 @@ def trainer_model_fn():
            buffer_limit=16,
            eval_performance=True,
            debug=args.debug,
+            update_lora_weights=not (args.lora_rank == 0),
        ) for i, env_info_trainer in enumerate(env_info_trainers)
    ]

@@ -100,6 +101,7 @@ def model_fn():
        experience_batch_size=args.experience_batch_size,
        kl_coef=0.1,
        debug=args.debug,
+        update_lora_weights=not (args.lora_rank == 0),
        # sync_models_from_trainers=True,
        # generation kwargs:
        max_length=512,
5 changes: 4 additions & 1 deletion applications/Chat/examples/ray/mmmt_prompt.py
@@ -20,7 +20,6 @@
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights

-
def get_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
@@ -86,6 +85,7 @@ def model_fn():
        env_info=env_info_maker,
        kl_coef=0.1,
        debug=args.debug,
+        update_lora_weights=not (args.lora_rank == 0),
        # sync_models_from_trainers=True,
        # generation kwargs:
        max_length=512,
@@ -119,6 +119,7 @@ def trainer_model_fn():
            buffer_limit=16,
            eval_performance=True,
            debug=args.debug,
+            update_lora_weights=not (args.lora_rank == 0),
        )
        for i, env_info_trainer in enumerate(env_info_trainers)
    ]
@@ -156,6 +157,7 @@ def tokenize_fn(texts):
    for trainer_ref in trainer_refs:
        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))

+
    ray.get(wait_tasks)

@@ -187,5 +189,6 @@ def tokenize_fn(texts):
    parser.add_argument('--quant_group_size', type=int, default=128)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
+
    ray.init(namespace=os.environ["RAY_NAMESPACE"])
    main(args)