Merged
23 changes: 3 additions & 20 deletions applications/Chat/coati/ray/detached_replay_buffer.py
@@ -26,20 +26,12 @@ class DetachedReplayBuffer:
         cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
     '''
 
-    def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit: int = 0) -> None:
+    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
         self.sample_batch_size = sample_batch_size
         self.limit = limit
         self.items = Queue(self.limit, actor_options={"num_cpus": 1})
         self.batch_collector: List[BufferItem] = []
-        '''
-        Workers in the same tp group share this buffer and need the same sample for one step.
-        Therefore a held_sample should be returned tp_world_size times before it can be dropped.
-        worker_state records whether a worker got the held_sample.
-        '''
-        self.tp_world_size = tp_world_size
-        self.worker_state = [False] * self.tp_world_size
-        self.held_sample = None
-        self._worker_state_lock = Lock()
+
 
     @torch.no_grad()
     def append(self, experience: Experience) -> None:
@@ -70,16 +62,7 @@ def clear(self) -> None:
 
     @torch.no_grad()
     def sample(self, worker_rank=0, to_device="cpu") -> Experience:
-        self._worker_state_lock.acquire()
-        if not any(self.worker_state):
-            self.held_sample = self._sample_and_erase()
-        self.worker_state[worker_rank] = True
-        if all(self.worker_state):
-            self.worker_state = [False] * self.tp_world_size
-            ret = self.held_sample
-        else:
-            ret = copy.deepcopy(self.held_sample)
-        self._worker_state_lock.release()
+        ret = self._sample_and_erase()
         ret.to_device(to_device)
         return ret
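Note on the change above: each sample() call now simply drains a fresh batch, so the per-tp-group lock, the worker_state bookkeeping, and the deepcopy of a held_sample all become unnecessary. A minimal sketch of the new contract, using a plain queue and a toy experience type as stand-ins for coati's Ray-backed ones:

from dataclasses import dataclass
from queue import Queue


@dataclass
class Experience:
    data: list

    def to_device(self, device: str) -> "Experience":
        return self  # placeholder for the real tensor move


class ToyDetachedReplayBuffer:
    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
        self.sample_batch_size = sample_batch_size
        self.items: Queue = Queue(maxsize=limit)

    def append(self, experience: Experience) -> None:
        self.items.put(experience)

    def sample(self, to_device: str = "cpu") -> Experience:
        # sample-and-erase: one consumer per batch, no shared held_sample
        ret = self.items.get()
        ret.to_device(to_device)
        return ret


buf = ToyDetachedReplayBuffer(sample_batch_size=2)
buf.append(Experience(data=[1, 2]))
print(buf.sample().data)  # [1, 2]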
10 changes: 6 additions & 4 deletions applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -110,6 +110,8 @@ def __init__(
     @torch.no_grad()
     def _update_remote_makers(self, fully_update: bool = False, **config):
         # TODO: balance duties
+        if not fully_update:
+            config['requires_grad_only'] = True
         self.update_target_holder_list()
         # mark start, ensure order
         tasks = []
@@ -197,9 +199,9 @@ def _get_unwrapped_critic(self):
         return self.critic
 
     def _get_model_state_dict_shard(self, model: torch.nn.Module, **config):
-        try:
-            self.strategy.merge_lora_weight(model)
-        except AttributeError:
-            pass
+        # try:
+        #     self.strategy.merge_lora_weight(model)
+        # except AttributeError:
+        #     pass
         for state_dict in self.strategy.get_model_state_dict_shard(model, **config):
             yield state_dict_to(state_dict)
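The two hunks above work together: an incremental (non-full) sync sets requires_grad_only in the config, and the strategy's get_model_state_dict_shard honors it by shipping only trainable parameters (see the NaiveStrategy change below). A runnable sketch of that plumbing, where both functions are simplified stand-ins rather than the real trainer and strategy methods:

from collections import OrderedDict

import torch.nn as nn


def get_model_state_dict_shard(model: nn.Module, **config):
    # simplified stand-in for the strategy: filter to trainable params on request
    if config.get('requires_grad_only'):
        yield OrderedDict((name, param.detach())
                          for name, param in model.named_parameters()
                          if param.requires_grad)
    else:
        yield model.state_dict()


def update_remote_makers(model: nn.Module, fully_update: bool = False, **config):
    if not fully_update:
        config['requires_grad_only'] = True  # incremental sync: trainable params only
    for shard in get_model_state_dict_shard(model, **config):
        print(sorted(shard))  # stand-in for pushing the shard to maker holders


model = nn.Linear(4, 2)
model.bias.requires_grad_(False)                # freeze the bias to show the filter
update_remote_makers(model)                     # -> ['weight']
update_remote_makers(model, fully_update=True)  # -> ['bias', 'weight']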
1 change: 1 addition & 0 deletions applications/Chat/coati/ray/experience_maker_holder.py
@@ -192,6 +192,7 @@ def update_experience_maker(self,
         if new_critic_state_dict is not None:
             self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
 
+
         # the lock must be released only after both the actor and the critic have been updated
         if chunk_end:
             self._model_visit_lock.release()
1 change: 1 addition & 0 deletions applications/Chat/coati/ray/utils.py
@@ -1,5 +1,6 @@
 import os
 from typing import Any, Callable, Dict, List, Optional
+from collections import OrderedDict
 
 import torch
 import torch.distributed as dist
11 changes: 6 additions & 5 deletions applications/Chat/coati/trainer/strategies/colossalai.py
@@ -214,9 +214,10 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
         if self.stage != 3:
             yield from super().get_model_state_dict_shard(model, **config)
         else:
-            unwrapped_model = self._unwrap_model(model)
-            for module in unwrapped_model.modules():
-                if isinstance(module, LoraLinear):
-                    module.merge_weights = True
-                    module.eval()
+            # unwrapped_model = self._unwrap_model(model)
+            # for module in unwrapped_model.modules():
+            #     if isinstance(module, LoraLinear):
+            #         module.merge_weights = True
+            #         module.eval()
+            model: ZeroDDP = model
             yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
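For context on what the commented-out block did: setting merge_weights and calling eval() folds the LoRA delta into the base weight in place, which is presumably skipped here because stage-3 (ZeroDDP) parameters are sharded and mutating them directly is unsafe. A toy illustration; ToyLoraLinear is a hypothetical stand-in for coati's LoraLinear, assuming loralib-style merge-on-eval semantics:

import torch
import torch.nn as nn


class ToyLoraLinear(nn.Linear):
    """Hypothetical stand-in for LoraLinear with merge-on-eval semantics."""

    def __init__(self, in_features: int, out_features: int, r: int = 4):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.merge_weights = False
        self.merged = False

    def eval(self):
        if self.merge_weights and not self.merged:
            # fold the low-rank delta into the base weight, in place
            self.weight.data += self.lora_B @ self.lora_A
            self.merged = True
        return super().eval()


layer = ToyLoraLinear(4, 4)
layer.merge_weights = True
layer.eval()  # layer.weight now carries the merged LoRA delta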
35 changes: 31 additions & 4 deletions applications/Chat/coati/trainer/strategies/naive.py
@@ -1,6 +1,7 @@
 import os
-from typing import Any, Optional
-
+import sys
+from typing import Any, Optional, Dict
+from collections import OrderedDict
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -14,6 +15,14 @@
 
 from .base import Strategy
 
+# TODO: move this to a util.py (moving it to ray.util introduces a circular import)
+def get_grad_required_state_dict(model: nn.Module):
+    state_dict = OrderedDict()
+    for name, parameter in model.named_parameters():
+        if parameter.requires_grad:
+            state_dict[name] = parameter.detach()
+    return state_dict
+
 
 class NaiveStrategy(Strategy):
     """
@@ -81,8 +90,26 @@ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
 
     def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
-        state_dict = model.state_dict()
-        yield state_dict
+        if 'requires_grad_only' in config and config['requires_grad_only'] == True:
+            state_dict = get_grad_required_state_dict(model)
+        else:
+            state_dict = model.state_dict()
+
+        if 'shard_size' in config:
+            shard_size = config['shard_size']
+            accumulate_size = 0
+            state_dict_shard = OrderedDict()
+            for name, param in state_dict.items():
+                state_dict_shard[name] = param
+                accumulate_size += param.numel() * param.element_size()
+                if accumulate_size >= shard_size:
+                    accumulate_size = 0
+                    yield state_dict_shard
+                    state_dict_shard = OrderedDict()
+            if accumulate_size > 0:
+                yield state_dict_shard
+        else:
+            yield state_dict
 
     def merge_lora_weight(self, model: nn.Module):
         unwrapped_model = self._unwrap_model(model)
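A standalone sketch of the chunking rule added above: parameters accumulate into a shard until their byte size (numel times element_size) reaches shard_size, then the shard is yielded and a new one started, with any remainder yielded at the end. shard_state_dict is a hypothetical helper extracted here for illustration, not a function in the PR:

from collections import OrderedDict

import torch.nn as nn


def shard_state_dict(state_dict, shard_size: int):
    accumulate_size = 0
    shard = OrderedDict()
    for name, param in state_dict.items():
        shard[name] = param
        accumulate_size += param.numel() * param.element_size()
        if accumulate_size >= shard_size:
            accumulate_size = 0
            yield shard
            shard = OrderedDict()
    if accumulate_size > 0:
        yield shard  # remainder under the byte budget


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for i, shard in enumerate(shard_state_dict(model.state_dict(), shard_size=256)):
    nbytes = sum(p.numel() * p.element_size() for p in shard.values())
    print(f"shard {i}: {list(shard)} ({nbytes} bytes)")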
166 changes: 0 additions & 166 deletions applications/Chat/examples/ray/1m1t.py

This file was deleted.

23 changes: 0 additions & 23 deletions applications/Chat/examples/ray/1m1t.sh

This file was deleted.
