From 81eeef7ec37543d5303876815334f3a9178fdb42 Mon Sep 17 00:00:00 2001 From: csric Date: Fri, 21 Apr 2023 17:25:00 +0800 Subject: [PATCH] split experience to send --- .../coati/ray/src/experience_maker_holder.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/applications/Chat/coati/ray/src/experience_maker_holder.py b/applications/Chat/coati/ray/src/experience_maker_holder.py index b1acdbb5494d..93624c2be921 100644 --- a/applications/Chat/coati/ray/src/experience_maker_holder.py +++ b/applications/Chat/coati/ray/src/experience_maker_holder.py @@ -10,6 +10,7 @@ import torch.nn as nn from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker from coati.models.base import Actor, Critic, RewardModel +from coati.replay_buffer.utils import split_experience_batch, make_experience_batch, BufferItem from coati.trainer.callbacks import Callback from coati.trainer.callbacks.performance_evaluator import ExperienceMakerPerformanceEvaluator from coati.trainer.strategies import Strategy @@ -47,6 +48,7 @@ def __init__(self, callbacks: List[Callback] = [], eval_performance: bool = False, debug: bool = False, + send_grain_size: int = 4, **generate_kwargs): # set environment variables if env_info: @@ -63,6 +65,7 @@ def __init__(self, self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef) self.callbacks = callbacks self.eval_performance = eval_performance + self.send_grain_size = send_grain_size self._model_visit_lock = Lock() self._initial_model_initialized = False @@ -165,7 +168,19 @@ def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None experience = self._make_experience(inputs=inputs) self._on_make_experience_end(experience) self._model_visit_lock.release() - self._send_experience(experience=experience) + # split experience for smoother handover + items = split_experience_batch(experience) + temp_buffer = [] + for item in items: + temp_buffer.append(item) + if len(temp_buffer) >= self.send_grain_size: + experience_fragment = make_experience_batch(temp_buffer) + self._send_experience(experience=experience_fragment) + temp_buffer = [] + # remain + if len(temp_buffer) > 0: + experience_fragment = make_experience_batch(temp_buffer) + self._send_experience(experience=experience_fragment) self._on_finish() @ray.method(concurrency_group="model_io")