From 729456957a708529d8c09b3697ad85b8824e2b12 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 31 Mar 2025 11:35:23 +0800
Subject: [PATCH 01/11] update help information

---
 applications/ColossalChat/rl_example.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 4a4a4c3404e9..5e7af5c192d9 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -10,13 +10,13 @@
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-    parser.add_argument("-g", "--num-generations", type=int, default=8)
-    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
-    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
-    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-    parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
-    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
-    parser.add_argument("-b", "--backend", type=str, default="transformers")
+    parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step.")
+    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8, help="Number of prompts to send from the producer to the consumer.")
+    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model.")
+    parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1, help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.")
+    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers, vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()

From 27d0108c9f8aab90288ef0d841ad27fb032fd6bf Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 31 Mar 2025 11:44:10 +0800
Subject: [PATCH 02/11] update style

---
 applications/ColossalChat/rl_example.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 5e7af5c192d9..7ff9bd20d2d4 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -11,10 +11,26 @@
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
-    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step.")
-    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8, help="Number of prompts to send from the producer to the consumer.")
-    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model.")
-    parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1, help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.")
+    parser.add_argument(
+        "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step."
+    )
+    parser.add_argument(
+        "-imbs",
+        "--inference-microbatch-size",
+        type=int,
+        default=8,
+        help="Number of prompts to send from the producer to the consumer.",
+    )
+    parser.add_argument(
+        "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model."
+    )
+    parser.add_argument(
+        "-tMbs",
+        "--train-minibatch-size",
+        type=int,
+        default=1,
+        help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
+    )
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers, vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])

From 9d6ede98dcc9470f7db55af03060d5f72c4d29c4 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 31 Mar 2025 13:14:41 +0800
Subject: [PATCH 03/11] fix

---
 applications/ColossalChat/rl_example.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 7ff9bd20d2d4..bb7b848267e7 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -32,7 +32,7 @@
         help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
     )
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
-    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers, vllm"])
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers","vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()

From 6604654b2f3a37bea2a1dc314498f498aa1b6dde Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Mon, 31 Mar 2025 13:15:18 +0800
Subject: [PATCH 04/11] minor fix

---
 applications/ColossalChat/rl_example.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bb7b848267e7..bb719a13c405 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -32,7 +32,7 @@
         help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
     )
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
-    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers","vllm"])
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()

From d961a5f7251f377856fc2def74dc1de98269fe49 Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 1 Apr 2025 11:24:09 +0800
Subject: [PATCH 05/11] support PP training

---
 .../coati/distributed/consumer.py             |   1 -
 .../coati/distributed/grpo_consumer.py        | 166 +++++++++++-------
 applications/ColossalChat/rl_example.py       |  16 +-
 .../booster/plugin/hybrid_parallel_plugin.py  |   6 +-
 colossalai/shardformer/modeling/qwen2.py      |   1 +
 5 files changed, 121 insertions(+), 69 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 027acc2e7537..4e1cd1f3179a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -54,7 +54,6 @@ def __init__(
         self.model_config = model_config
         self.plugin_config = plugin_config
 
-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
         self.device = get_current_device()
 
         self.lr_scheduler = None

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 4174f96514b8..d05709febf52 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -96,7 +96,7 @@ def __init__(
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
             name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -168,72 +168,120 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 ).repeat_interleave(self.num_generations, dim=0)
             )
             mean_kl, mean_loss = [], []
-            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
-                input_ids_forward_micro_batch = data["input_ids"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                attention_mask_forward_micro_batch = data["attention_mask"][
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                action_mask_forward_micro_batch = action_mask[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                loss_mask_forward_micro_batch = (
-                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
-                    if loss_mask is not None
-                    else None
-                )
-                advantages_forward_micro_batch = advantages[
-                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                ]
-                policy_model_logits = self.policy_model(
-                    input_ids=input_ids_forward_micro_batch,
-                    attention_mask=attention_mask_forward_micro_batch,
-                ).logits
-                action_log_probs = calc_action_log_probs(
-                    policy_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
-                )
-
-                with torch.no_grad():
-                    reference_model_logits = self.reference_model(
-                        input_ids=input_ids_forward_micro_batch,
-                        attention_mask=attention_mask_forward_micro_batch,
-                    ).logits
-                    reference_action_log_probs = calc_action_log_probs(
-                        reference_model_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
-
-                per_token_kl = (
-                    torch.exp(reference_action_log_probs - action_log_probs)
-                    - (reference_action_log_probs - action_log_probs)
-                    - 1
-                )
-                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                    action_mask_forward_micro_batch, dim=-1
-                )
-
-                loss, skip_update, _ = self.policy_loss_fn(
-                    action_log_probs,
-                    old_action_log_probs,
-                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
-                    per_token_kl,
-                    action_mask_forward_micro_batch,
-                    loss_mask=loss_mask_forward_micro_batch,
-                )
-
-                if not skip_update:
-                    self.booster.backward(loss, self.optimizer)
-                loss = all_reduce_mean(loss, self.plugin)
-                kl = all_reduce_mean(kl.mean(), self.plugin)
-                # Calculate accumulated value.
-                mean_kl.append(kl.data)
-                mean_loss.append(loss.data)
+            if self.plugin.pp_size > 1:
+                # Support training with PP.
+                data_iter = iter([data])
+
+                with torch.no_grad():
+                    reference_model_outputs = self.booster.execute_pipeline(
+                        data_iter,
+                        self.reference_model,
+                        criterion=lambda outputs, inputs: outputs.logits.mean(),  # dummy criterion
+                        optimizer=None,
+                        return_loss=False,
+                        return_outputs=True,
+                    )
+
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    reference_model_logits = reference_model_outputs["outputs"]["logits"]
+                    reference_action_log_probs = calc_action_log_probs(
+                        reference_model_logits / self.generate_config["temperature"],
+                        data["input_ids"],
+                        num_action,
+                        self.plugin.shard_config,
+                    )
+                else:
+                    # Dummy reference logprobs for data iterator.
+                    reference_action_log_probs = torch.zeros(
+                        (old_action_log_probs.size(0), old_action_log_probs.size(1))
+                    )
+
+                data["reference_action_log_probs"] = reference_action_log_probs
+
+                data_iter = iter([data])
+
+                def _criterion(outputs, inputs):
+                    pass
+
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.policy_model,
+                    criterion=_criterion,
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+
+                if self.booster.plugin.stage_manager.is_last_stage():
+                    loss = all_reduce_mean(loss, self.plugin)
+                    mean_loss.append(loss.data)
+            else:
+                for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                    input_ids_forward_micro_batch = data["input_ids"][
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    attention_mask_forward_micro_batch = data["attention_mask"][
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    action_mask_forward_micro_batch = action_mask[
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    loss_mask_forward_micro_batch = (
+                        loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                        if loss_mask is not None
+                        else None
+                    )
+                    advantages_forward_micro_batch = advantages[
+                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                    ]
+                    policy_model_logits = self.policy_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                    action_log_probs = calc_action_log_probs(
+                        policy_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
+
+                    with torch.no_grad():
+                        reference_model_logits = self.reference_model(
+                            input_ids=input_ids_forward_micro_batch,
+                            attention_mask=attention_mask_forward_micro_batch,
+                        ).logits
+                        reference_action_log_probs = calc_action_log_probs(
+                            reference_model_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+
+                    per_token_kl = (
+                        torch.exp(reference_action_log_probs - action_log_probs)
+                        - (reference_action_log_probs - action_log_probs)
+                        - 1
+                    )
+                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                        action_mask_forward_micro_batch, dim=-1
+                    )
+
+                    loss, skip_update, _ = self.policy_loss_fn(
+                        action_log_probs,
+                        old_action_log_probs,
+                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        per_token_kl,
+                        action_mask_forward_micro_batch,
+                        loss_mask=loss_mask_forward_micro_batch,
+                    )
+
+                    if not skip_update:
+                        self.booster.backward(loss, self.optimizer)
+                    loss = all_reduce_mean(loss, self.plugin)
+                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    # Calculate accumulated value.
+                    mean_kl.append(kl.data)
+                    mean_loss.append(loss.data)
 
             reward = all_reduce_mean(reward.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bb719a13c405..2b6faaa4ab90 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -31,7 +31,13 @@
         default=1,
         help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
     )
-    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2, help="Number of samples per device.")
+    parser.add_argument(
+        "-tmbs",
+        "--train-microbatch-size",
+        type=int,
+        default=2,
+        help="Number of samples per device. PP micro batch size when PP is activated.",
+    )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()
@@ -45,11 +51,7 @@
     ray.init(address="local", namespace="ray-example")
 
     inference_model_config = dict(path=args.model)
-    train_model_config = dict(
-        path=args.model,
-        # use_flash_attention_2=True,
-        # use_cache=False
-    )
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
 
     if args.backend == "transformers":
@@ -107,7 +109,7 @@
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        plugin_config={},
+        plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0},
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=29505,

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702e70..74349091b4d4 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1411,8 +1411,10 @@ def execute_pipeline(
         )
 
         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs

diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 71e3557fe214..27571309e453 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -284,6 +284,7 @@ def qwen2_for_causal_lm_forward(
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
     shard_config: ShardConfig = None,
+    **kwargs,
 ):
     r"""
     Args:

From 6604654b2f3a37bea2a1dc314498f498aa1b6dde is superseded by the following commit.

From 09a3173a4920ee64292b9c638611adaa3c2d427f Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 4 Apr 2025 10:05:16 +0800
Subject: [PATCH 06/11] add pp support

---
 .gitignore                                    |   1 +
 .../coati/distributed/consumer.py             |   1 -
 .../coati/distributed/grpo_consumer.py        | 311 +++++++++++-------
 applications/ColossalChat/rl_example.py       |   9 +-
 4 files changed, 200 insertions(+), 122 deletions(-)

diff --git a/.gitignore b/.gitignore
index 16f764c1b1ef..533450a7cce1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ coverage.xml
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
+applications/ColossalChat/model

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 4e1cd1f3179a..c6ae7be2daf9 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -94,7 +94,6 @@ def loop(self) -> None:
             i = 0
             for _ in range(self.num_recv_per_update):
                 # receive data from producers
-
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                     self.buffer.extend(

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index d05709febf52..fbc06edc2aa1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -94,9 +94,7 @@ def __init__(
 
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
-        if use_wandb and self.rank == 0:
-            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
+        self.use_wandb = use_wandb
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -107,10 +105,19 @@ def __init__(
 
     def setup(self):
        super().setup()
+        if self.use_wandb and (
+            (not self.plugin.pp_size > 1 and self.rank == 0)
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+        ):
+            # Initialize wandb.
+            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
+
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
         self.reference_model, *_ = self.booster.boost(self.reference_model)
+        self.plugin.logger.set_level("ERROR")
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         """
@@ -168,72 +175,130 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                 ).repeat_interleave(self.num_generations, dim=0)
             )
             mean_kl, mean_loss = [], []
-            if self.plugin.pp_size > 1:
-                # Support training with PP.
-                data_iter = iter([data])
-
-                with torch.no_grad():
-                    reference_model_outputs = self.booster.execute_pipeline(
-                        data_iter,
-                        self.reference_model,
-                        criterion=lambda outputs, inputs: outputs.logits.mean(),  # dummy criterion
-                        optimizer=None,
-                        return_loss=False,
-                        return_outputs=True,
-                    )
-
-                if self.booster.plugin.stage_manager.is_last_stage():
-                    reference_model_logits = reference_model_outputs["outputs"]["logits"]
-                    reference_action_log_probs = calc_action_log_probs(
-                        reference_model_logits / self.generate_config["temperature"],
-                        data["input_ids"],
-                        num_action,
-                        self.plugin.shard_config,
-                    )
-                else:
-                    # Dummy reference logprobs for data iterator.
-                    reference_action_log_probs = torch.zeros(
-                        (old_action_log_probs.size(0), old_action_log_probs.size(1))
-                    )
-
-                data["reference_action_log_probs"] = reference_action_log_probs
-
-                data_iter = iter([data])
-
-                def _criterion(outputs, inputs):
-                    pass
-
-                outputs = self.booster.execute_pipeline(
-                    data_iter,
-                    self.policy_model,
-                    criterion=_criterion,
-                    optimizer=self.optimizer,
-                    return_loss=True,
-                )
-                loss = outputs["loss"]
-
-                if self.booster.plugin.stage_manager.is_last_stage():
-                    loss = all_reduce_mean(loss, self.plugin)
-                    mean_loss.append(loss.data)
-            else:
-                for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
-                    input_ids_forward_micro_batch = data["input_ids"][
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                    ]
-                    attention_mask_forward_micro_batch = data["attention_mask"][
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                    ]
-                    action_mask_forward_micro_batch = action_mask[
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                    ]
-                    loss_mask_forward_micro_batch = (
-                        loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
-                        if loss_mask is not None
-                        else None
-                    )
-                    advantages_forward_micro_batch = advantages[
-                        forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
-                    ]
-                    policy_model_logits = self.policy_model(
-                        input_ids=input_ids_forward_micro_batch,
-                        attention_mask=attention_mask_forward_micro_batch,
-                    ).logits
-                    action_log_probs = calc_action_log_probs(
-                        policy_model_logits / self.generate_config["temperature"],
-                        input_ids_forward_micro_batch,
-                        num_action,
-                        self.plugin.shard_config,
-                    )
+            for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
+                input_ids_forward_micro_batch = data["input_ids"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                attention_mask_forward_micro_batch = data["attention_mask"][
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                action_mask_forward_micro_batch = action_mask[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+                loss_mask_forward_micro_batch = (
+                    loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
+                    if loss_mask is not None
+                    else None
+                )
+                advantages_forward_micro_batch = advantages[
+                    forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
+                ]
+
+                if self.plugin.pp_size > 1:
+                    # Support training with PP.
+
+                    with torch.no_grad():
+                        reference_model_outputs = self.booster.execute_pipeline(
+                            iter(
+                                [
+                                    {
+                                        "input_ids": input_ids_forward_micro_batch,
+                                        "attention_mask": attention_mask_forward_micro_batch,
+                                    }
+                                ]
+                            ),
+                            self.reference_model,
+                            criterion=lambda outputs, inputs: torch.tensor(
+                                [0.0], device=action_mask.device
+                            ),  # dummy criterion
+                            optimizer=None,
+                            return_loss=False,
+                            return_outputs=True,
+                        )
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
+                        reference_action_log_probs = calc_action_log_probs(
+                            reference_model_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                    else:
+                        # Dummy reference logprobs for data iterator.
+                        reference_action_log_probs = None
+
+                    data_policy_forward = {
+                        "input_ids": input_ids_forward_micro_batch,
+                        "attention_mask": attention_mask_forward_micro_batch,
+                        "action_mask": action_mask_forward_micro_batch,
+                        "reference_action_log_probs": reference_action_log_probs,
+                        "advantages": advantages_forward_micro_batch,
+                        "loss_mask": loss_mask_forward_micro_batch,
+                        "source": self.rank,
+                    }
+
+                    def _criterion(outputs, inputs):
+                        action_logits = outputs.logits
+                        action_log_probs = calc_action_log_probs(
+                            action_logits / self.generate_config["temperature"],
+                            inputs["input_ids"],
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                        per_token_kl = (
+                            torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
+                            - (inputs["reference_action_log_probs"] - action_log_probs)
+                            - 1
+                        )
+                        decode_tokens_100 = self.tokenizer.batch_decode(
+                            input_ids_forward_micro_batch[:, -num_action:],
+                            skip_special_tokens=False,
+                        )
+                        loss, skip_update, _ = self.policy_loss_fn(
+                            action_log_probs,
+                            action_log_probs,
+                            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                            per_token_kl,
+                            inputs["action_mask"],
+                            loss_mask=inputs["loss_mask"],
+                        )
+                        return loss
+
+                    policy_model_outputs = self.booster.execute_pipeline(
+                        iter([data_policy_forward]),
+                        self.policy_model,
+                        criterion=_criterion,
+                        optimizer=self.optimizer,
+                        return_loss=True,
+                        return_outputs=True,
+                    )
+                    loss = policy_model_outputs["loss"]
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        # calculate kl
+                        action_logits = policy_model_outputs["outputs"]["logits"]
+                        action_log_probs = calc_action_log_probs(
+                            action_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                        per_token_kl = (
+                            torch.exp(reference_action_log_probs - action_log_probs)
+                            - (reference_action_log_probs - action_log_probs)
+                            - 1
+                        )
+                        kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                            action_mask_forward_micro_batch, dim=-1
+                        )
+                        kl = all_reduce_mean(kl.mean(), self.plugin)
+                        loss = all_reduce_mean(loss, self.plugin)
+                        mean_loss.append(loss.data)
+                        mean_kl.append(kl)
+                else:
+                    policy_model_logits = self.policy_model(
+                        input_ids=input_ids_forward_micro_batch,
+                        attention_mask=attention_mask_forward_micro_batch,
+                    ).logits
+                    action_log_probs = calc_action_log_probs(
+                        policy_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
 
                     with torch.no_grad():
                         reference_model_logits = self.reference_model(
@@ -256,7 +321,6 @@ def _criterion(outputs, inputs):
                             input_ids_forward_micro_batch,
                             num_action,
                             self.plugin.shard_config,
                         )
-
                     per_token_kl = (
                         torch.exp(reference_action_log_probs - action_log_probs)
                         - (reference_action_log_probs - action_log_probs)
@@ -282,64 +346,71 @@ def _criterion(outputs, inputs):
                     # Calculate accumulated value.
                     mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
-
-            reward = all_reduce_mean(reward.mean(), self.plugin)
-            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
-            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
-            advantages = all_reduce_mean(advantages.mean(), self.plugin)
-            response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
-            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-            self.accum_reward.add_(reward.data)
-            self.accum_format_reward.add_(format_reward.data)
-            self.accum_acc_reward.add_(acc_reward.data)
-            self.accum_advantages.add_(advantages.data)
-            self.accum_response_length.add_(response_length.data)
-            self.accum_count += 1
+            if not self.plugin.pp_size > 1 or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            ):
+                reward = all_reduce_mean(reward.mean(), self.plugin)
+                format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+                acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+                advantages = all_reduce_mean(advantages.mean(), self.plugin)
+                response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
+                self.accum_reward.add_(reward.data)
+                self.accum_format_reward.add_(format_reward.data)
+                self.accum_acc_reward.add_(acc_reward.data)
+                self.accum_advantages.add_(advantages.data)
+                self.accum_response_length.add_(response_length.data)
+                self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
-            loss_scalar = self.accum_loss.item()
-            if self.rank == 0:
-                print(
-                    "Loss:",
-                    self.accum_loss.item() / self.accum_count,
-                    "\nReward:",
-                    self.accum_reward.item() / self.accum_count,
-                    "\nFormat Reward:",
-                    self.accum_format_reward.item() / self.accum_count,
-                    "\nAcc Reward:",
-                    self.accum_acc_reward.item() / self.accum_count,
-                    "\nKL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "\nAdvantages:",
-                    self.accum_advantages.item() / self.accum_count,
-                    "\nResponse Length:",
-                    self.accum_response_length.item() / self.accum_count,
-                )
-                self.wandb_run.log(
-                    {
-                        "metrics/reward": self.accum_reward.item() / self.accum_count,
-                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
-                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
-                        "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
-                        "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
-                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
-                    }
-                )
-            self.accum_loss.zero_()
-            self.accum_reward.zero_()
-            self.accum_acc_reward.zero_()
-            self.accum_format_reward.zero_()
-            self.accum_kl.zero_()
-            self.accum_advantages.zero_()
-            self.accum_response_length.zero_()
-
-            self.accum_count = 0
-        return loss_scalar
+            if not self.plugin.pp_size > 1 or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            ):
+                loss_scalar = self.accum_loss.item()
+                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                ):
+                    print(
+                        "Loss:",
+                        self.accum_loss.item() / self.accum_count,
+                        "\nReward:",
+                        self.accum_reward.item() / self.accum_count,
+                        "\nFormat Reward:",
+                        self.accum_format_reward.item() / self.accum_count,
+                        "\nAcc Reward:",
+                        self.accum_acc_reward.item() / self.accum_count,
+                        "\nKL:",
+                        self.accum_kl.item() / self.accum_count,
+                        "\nAdvantages:",
+                        self.accum_advantages.item() / self.accum_count,
+                        "\nResponse Length:",
+                        self.accum_response_length.item() / self.accum_count,
+                    )
+                    self.wandb_run.log(
+                        {
+                            "metrics/reward": self.accum_reward.item() / self.accum_count,
+                            "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                            "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                            "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                            "train/loss": self.accum_loss.item() / self.accum_count,
+                            "train/kl": self.accum_kl.item() / self.accum_count,
+                            "train/advantages": self.accum_advantages.item() / self.accum_count,
+                            "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                            "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
+                        }
+                    )
+                self.accum_loss.zero_()
+                self.accum_reward.zero_()
+                self.accum_acc_reward.zero_()
+                self.accum_format_reward.zero_()
+                self.accum_kl.zero_()
+                self.accum_advantages.zero_()
+                self.accum_response_length.zero_()
+
+                self.accum_count = 0
+            return loss_scalar
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 2b6faaa4ab90..bf7a657e56c4 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -109,7 +109,14 @@
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        plugin_config={"pp_size": 2, "tp_size": 1, "microbatch_size": 2, "zero_stage": 0},
+        # plugin_config={},  # for zero
+        plugin_config={
+            "pp_size": 2,
+            "tp_size": 1,
+            "microbatch_size": args.train_microbatch_size // 2,
+            "zero_stage": 0,
+            "max_norm": 1.0,
+        },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=29505,

From 061d8cb3b6485787e7c2f75868342b3a243c44e1 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Fri, 4 Apr 2025 10:11:11 +0800
Subject: [PATCH 07/11] remove unused code

---
 .../ColossalChat/coati/distributed/grpo_consumer.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index fbc06edc2aa1..f4174261ad2c 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -252,10 +252,6 @@ def _criterion(outputs, inputs):
                             - (inputs["reference_action_log_probs"] - action_log_probs)
                             - 1
                         )
-                        decode_tokens_100 = self.tokenizer.batch_decode(
-                            input_ids_forward_micro_batch[:, -num_action:],
-                            skip_special_tokens=False,
-                        )
                         loss, skip_update, _ = self.policy_loss_fn(
                             action_log_probs,
                             action_log_probs,
@@ -277,7 +273,7 @@ def _criterion(outputs, inputs):
                     loss = policy_model_outputs["loss"]
 
                     if self.booster.plugin.stage_manager.is_last_stage():
-                        # calculate kl
+                        # calculate kl; as we cannot do this inside the callback, kl needs to be calculated again
                         action_logits = policy_model_outputs["outputs"]["logits"]
                         action_log_probs = calc_action_log_probs(
                             action_logits / self.generate_config["temperature"],

From a40d82f6293113cc73ebac7c2d14f49caf5e41a4 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 9 Apr 2025 12:53:40 +0800
Subject: [PATCH 08/11] address conversation

---
 .../coati/distributed/grpo_consumer.py        | 31 +++++++----------
 .../ColossalChat/coati/distributed/launch.py  |  2 ++
 applications/ColossalChat/rl_example.py       | 22 +++++++++----
 3 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index f4174261ad2c..a282439cbd33 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -39,6 +39,7 @@ def __init__(
         use_wandb=True,
         generate_config=None,
         training_config={},
+        project_name=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +70,7 @@ def __init__(
         self.accum_count = 0
         self.generate_config = generate_config
         self.training_config = training_config
+        self.project_name = project_name
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -111,7 +113,7 @@ def setup(self):
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1-PP", sync_tensorboard=True, dir="./wandb", name=name)
+            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
 
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
@@ -239,6 +241,8 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                         "source": self.rank,
                     }
 
+                    kl = []
+
                    def _criterion(outputs, inputs):
                        action_logits = outputs.logits
                        action_log_probs = calc_action_log_probs(
@@ -252,6 +256,10 @@ def _criterion(outputs, inputs):
                             - (inputs["reference_action_log_probs"] - action_log_probs)
                             - 1
                         )
+                        approx_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+                            inputs["action_mask"], dim=-1
+                        )
+                        kl.append(approx_kl.mean())
                         loss, skip_update, _ = self.policy_loss_fn(
                             action_log_probs,
                             action_log_probs,
@@ -273,26 +281,11 @@ def _criterion(outputs, inputs):
                     loss = policy_model_outputs["loss"]
 
                     if self.booster.plugin.stage_manager.is_last_stage():
-                        # calculate kl; as we cannot do this inside the callback, kl needs to be calculated again
-                        action_logits = policy_model_outputs["outputs"]["logits"]
-                        action_log_probs = calc_action_log_probs(
-                            action_logits / self.generate_config["temperature"],
-                            input_ids_forward_micro_batch,
-                            num_action,
-                            self.plugin.shard_config,
-                        )
-                        per_token_kl = (
-                            torch.exp(reference_action_log_probs - action_log_probs)
-                            - (reference_action_log_probs - action_log_probs)
-                            - 1
-                        )
-                        kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                            action_mask_forward_micro_batch, dim=-1
-                        )
-                        kl = all_reduce_mean(kl.mean(), self.plugin)
+                        if len(kl) > 0:
+                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                            mean_kl.append(kl)
                         loss = all_reduce_mean(loss, self.plugin)
                         mean_loss.append(loss.data)
-                        mean_kl.append(kl)
 
                 else:
                     policy_model_logits = self.policy_model(

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index ba5d3a9d4fd8..699d90a8cdff 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):
 
     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bf7a657e56c4..f87f12ed23ca 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -11,32 +11,41 @@
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
     parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
     parser.add_argument(
-        "-ibs", "--inference-batch-size", type=int, default=64, help="Number of prompts to generate per step."
+        "-ibs",
+        "--inference-batch-size",
+        type=int,
+        default=64,
+        help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
         "-imbs",
         "--inference-microbatch-size",
         type=int,
         default=8,
-        help="Number of prompts to send from the producer to the consumer.",
+        help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
     )
     parser.add_argument(
-        "-tbs", "--train-batch-size", type=int, default=32, help="Number of prompts to update policy model."
+        "-tbs",
+        "--train-batch-size",
+        type=int,
+        default=32,
+        help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples.",
     )
     parser.add_argument(
         "-tMbs",
         "--train-minibatch-size",
         type=int,
         default=1,
-        help="Number of prompts per device. Number of samples = tMbs * num of generation per prompt.",
+        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. It must satisfy tMbs * g >= tmbs.",
     )
     parser.add_argument(
         "-tmbs",
         "--train-microbatch-size",
         type=int,
         default=2,
-        help="Number of samples per device. PP micro batch size when PP is activated.",
+        help="Effective batch size per dp group for the forward and backward passes. Please select based on the available memory.",
     )
     parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
@@ -119,6 +128,7 @@
         },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29505,
+        master_port=29506,
         core_algo=args.algo,
+        project_name=args.project,
     )

From 1ea3b72c2299e7c6cf82a0565fbe8ee37a9340b3 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Wed, 9 Apr 2025 17:11:55 +0800
Subject: [PATCH 09/11] fix memory leakage support tp+pp

---
 .../ColossalChat/coati/distributed/consumer.py |  4 ++++
 .../coati/distributed/grpo_consumer.py         | 14 +++++++-------
 applications/ColossalChat/rl_example.py        |  2 +-
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index c6ae7be2daf9..6372e2c3a898 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -72,6 +72,8 @@ def setup(self) -> None:
             self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
+        self.tp_rank = dist.get_rank(self.plugin.tp_group)
+
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
 
         self.buffer = []
@@ -132,6 +134,8 @@ def loop(self) -> None:
                         ray_broadcast_tensor_dict(
                             state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                         )
+                        del state_dict
+                        torch.cuda.empty_cache()
 
 
 @ray.remote

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index a282439cbd33..68a01528e2d1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -109,7 +109,7 @@ def setup(self):
         super().setup()
         if self.use_wandb and (
             (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -282,10 +282,10 @@ def _criterion(outputs, inputs):
                     if self.booster.plugin.stage_manager.is_last_stage():
                         if len(kl) > 0:
-                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                             mean_kl.append(kl)
-                        loss = all_reduce_mean(loss, self.plugin)
-                        mean_loss.append(loss.data)
+                        mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                        torch.cuda.empty_cache()
 
                 else:
                     policy_model_logits = self.policy_model(
@@ -336,7 +336,7 @@ def _criterion(outputs, inputs):
                     mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 reward = all_reduce_mean(reward.mean(), self.plugin)
                 format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
@@ -355,11 +355,11 @@ def _criterion(outputs, inputs):
             self.optimizer.step()
             self.optimizer.zero_grad()
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 loss_scalar = self.accum_loss.item()
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
                     print(
                         "Loss:",

diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index f87f12ed23ca..6c43ccd1960f 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -121,7 +121,7 @@
         # plugin_config={},  # for zero
         plugin_config={
             "pp_size": 2,
-            "tp_size": 1,
+            "tp_size": 2,
             "microbatch_size": args.train_microbatch_size // 2,
             "zero_stage": 0,
             "max_norm": 1.0,

From f7e532511ce766bb1f5661c397fb468c3bad5523 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 10 Apr 2025 10:22:28 +0800
Subject: [PATCH 10/11] move empty cache

---
 applications/ColossalChat/coati/distributed/grpo_consumer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 68a01528e2d1..e23254d1b07a 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -285,7 +285,6 @@ def _criterion(outputs, inputs):
                             kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                        torch.cuda.empty_cache()
 
                 else:
                     policy_model_logits = self.policy_model(

From 1723a0286023132437001773865f1410e9f5e4a0 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 10 Apr 2025 10:22:43 +0800
Subject: [PATCH 11/11] move empty cache

---
 applications/ColossalChat/coati/distributed/consumer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 6372e2c3a898..79beb2a2dba6 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -129,6 +129,7 @@ def loop(self) -> None:
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
                     print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    torch.cuda.empty_cache()
                     state_dict = self.state_dict()
                     if self.rank == 0:
                         ray_broadcast_tensor_dict(
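
For reviewers: the pipeline-parallel control flow that this series converges on (patches 06-10) is easier to follow stripped of the GRPO bookkeeping. The sketch below is illustrative only, not code from the PR: the helper names (calc_action_log_probs, policy_loss_fn) and the execute_pipeline keyword arguments are taken from the diffs above, while the wrapper function and its argument list are assumptions made for the example. It shows the pattern of running the reference model and the policy update through Booster.execute_pipeline, recording the approximate KL inside the loss criterion (the only place the last stage sees the policy logits), and gating results on stage_manager.is_last_stage().

    import torch

    def pp_grpo_micro_step(booster, policy_model, reference_model, optimizer,
                           micro_batch, num_action, temperature,
                           policy_loss_fn, shard_config):
        # 1. Reference forward: only the last pipeline stage receives logits.
        with torch.no_grad():
            ref_out = booster.execute_pipeline(
                iter([{"input_ids": micro_batch["input_ids"],
                       "attention_mask": micro_batch["attention_mask"]}]),
                reference_model,
                criterion=lambda outputs, inputs: torch.tensor(
                    [0.0], device=micro_batch["action_mask"].device),  # dummy criterion
                optimizer=None, return_loss=False, return_outputs=True,
            )
        if booster.plugin.stage_manager.is_last_stage():
            ref_logp = calc_action_log_probs(
                ref_out["outputs"]["logits"] / temperature,
                micro_batch["input_ids"], num_action, shard_config)
        else:
            ref_logp = None  # intermediate stages never consume it
        micro_batch["reference_action_log_probs"] = ref_logp

        kl = []  # filled by the criterion via closure, as in patch 08

        def _criterion(outputs, inputs):
            logp = calc_action_log_probs(outputs.logits / temperature,
                                         inputs["input_ids"], num_action, shard_config)
            per_token_kl = (torch.exp(inputs["reference_action_log_probs"] - logp)
                            - (inputs["reference_action_log_probs"] - logp) - 1)
            kl.append((torch.sum(per_token_kl * inputs["action_mask"], dim=-1)
                       / torch.sum(inputs["action_mask"], dim=-1)).mean())
            loss, _, _ = policy_loss_fn(
                logp, logp,  # old log-probs == current on the first mini-step
                inputs["advantages"].repeat_interleave(logp.size(-1), dim=-1),
                per_token_kl, inputs["action_mask"],
                loss_mask=inputs["loss_mask"])
            return loss

        # 2. Policy forward + backward, scheduled across stages by the booster.
        out = booster.execute_pipeline(iter([micro_batch]), policy_model,
                                       criterion=_criterion, optimizer=optimizer,
                                       return_loss=True, return_outputs=True)
        # Only the last stage returns a meaningful loss/KL pair.
        return out["loss"], (kl[0] if kl else None)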
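With the series applied, the example can be launched roughly as follows (illustrative: the -m/--model flag is inferred from args.model in the diffs and is not shown in this series; paths are placeholders):

    python rl_example.py -m <model_path> -d data.jsonl -b vllm -a GRPO -p GRPO-PP-test

Note that patches 06 and 09 hard-code plugin_config to pp_size=2 and tp_size=2, so the trainer side must be launched with enough GPUs to cover pp_size * tp_size per data-parallel replica.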