From 0a3b8b050d2398a20cf4601030678bc70ade02f6 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 21 Mar 2025 18:30:44 +0800 Subject: [PATCH 1/8] add microbatch forwarding --- .../coati/distributed/grpo_consumer.py | 16 +++++++++++----- .../ColossalChat/coati/distributed/launch.py | 2 +- .../ColossalChat/coati/distributed/utils.py | 17 +++++++++++++++++ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 1c0773f4e9bc..cd45c3c3c5ea 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -11,7 +11,7 @@ from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward -from coati.distributed.utils import calc_action_log_probs +from coati.distributed.utils import calc_action_log_probs, get_logits_rebatched_forward from coati.trainer.utils import all_reduce_mean from transformers import AutoModelForCausalLM, AutoTokenizer @@ -68,6 +68,7 @@ def __init__( self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 self.generate_config = generate_config + self.training_config = training_config # Reference model is initialized from policy model. self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -131,14 +132,17 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) + forward_batch_size = self.training_config.get("forward_micro_batch_size", data["input_ids"].size(0)) need_update = (step_idx + 1) % self.num_microbatches == 0 ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) with ctx: - policy_model_logits = self.policy_model( + policy_model_logits = get_logits_rebatched_forward( + self.policy_model, + forward_batch_size, input_ids=data["input_ids"], attention_mask=data["attention_mask"], - )["logits"] + ) action_log_probs = calc_action_log_probs( policy_model_logits / self.generate_config["temperature"], data["input_ids"], @@ -147,10 +151,12 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ) with torch.no_grad(): - reference_model_logits = self.reference_model( + reference_model_logits = get_logits_rebatched_forward( + self.reference_model, + forward_batch_size, input_ids=data["input_ids"], attention_mask=data["attention_mask"], - )["logits"] + ) reference_action_log_probs = calc_action_log_probs( reference_model_logits / self.generate_config["temperature"], data["input_ids"], diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index c50db1378e16..eb97fbab3424 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -101,7 +101,7 @@ def launch_distributed( plugin_config=plugin_config, microbatch_size=train_microbatch_size, generate_config=generate_config_consumer, - training_config={"filter_range": [0.05, 9.0], "lr": 1e-6}, + training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": 4}, num_generations=num_generations, ) procs.append(consumer) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 919e4434faa6..1a559bd5672e 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean + + +def get_logits_rebatched_forward(model, batch_size, input_ids, attention_mask): + """ + Get logits from the model with rebatched forward. + Args: + model (torch.nn.Module): The model. + batch_size (int): The batch size. + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + """ + logits = [] + for i in range(0, input_ids.size(0), batch_size): + logits.append( + model(input_ids=input_ids[i : i + batch_size], attention_mask=attention_mask[i : i + batch_size])["logits"] + ) + return torch.cat(logits, dim=0) From 2434861044304363deaca56f0a613dc824c63e6d Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 24 Mar 2025 17:46:20 +0800 Subject: [PATCH 2/8] fix forward microbatch --- .../coati/distributed/consumer.py | 7 +- .../coati/distributed/grpo_consumer.py | 139 ++++++++++-------- .../ColossalChat/coati/distributed/utils.py | 17 --- applications/ColossalChat/rl_example.py | 9 +- 4 files changed, 89 insertions(+), 83 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index de289738347c..027acc2e7537 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,12 +66,7 @@ def setup(self) -> None: cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) - plugin_config = dict( - tp_size=1, - pp_size=1, - precision="bf16", - zero_stage=1, - ) + plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size plugin_config.update(self.plugin_config) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index cd45c3c3c5ea..46e7c2fa0ec5 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -11,7 +11,7 @@ from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward -from coati.distributed.utils import calc_action_log_probs, get_logits_rebatched_forward +from coati.distributed.utils import calc_action_log_probs from coati.trainer.utils import all_reduce_mean from transformers import AutoModelForCausalLM, AutoTokenizer @@ -96,7 +96,7 @@ def __init__( self.global_step = 0 if use_wandb and self.rank == 0: name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name) + self.wandb_run = wandb.init(project="GRPO-V1-debug", sync_tensorboard=True, dir="./wandb", name=name) self.lr_scheduler = CosineAnnealingWarmupLR( optimizer=self.optimizer, @@ -135,42 +135,13 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: forward_batch_size = self.training_config.get("forward_micro_batch_size", data["input_ids"].size(0)) need_update = (step_idx + 1) % self.num_microbatches == 0 - ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) + # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 + ctx = ( + nullcontext() + if need_update or self.booster.plugin.zero_stage == 2 + else self.booster.no_sync(self.policy_model, self.optimizer) + ) with ctx: - policy_model_logits = get_logits_rebatched_forward( - self.policy_model, - forward_batch_size, - input_ids=data["input_ids"], - attention_mask=data["attention_mask"], - ) - action_log_probs = calc_action_log_probs( - policy_model_logits / self.generate_config["temperature"], - data["input_ids"], - num_action, - self.plugin.shard_config, - ) - - with torch.no_grad(): - reference_model_logits = get_logits_rebatched_forward( - self.reference_model, - forward_batch_size, - input_ids=data["input_ids"], - attention_mask=data["attention_mask"], - ) - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - data["input_ids"], - num_action, - self.plugin.shard_config, - ) - - per_token_kl = ( - torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1) - reward_group = self.reward_model( data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] ) @@ -183,6 +154,11 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: group_reward = reward.view(-1, self.num_generations) reward_mean = group_reward.mean(dim=1) + # [batch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [batch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), loss_mask = ( None @@ -191,35 +167,82 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] ).repeat_interleave(self.num_generations, dim=0) ) + mean_kl, mean_loss = [], [] + for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): + input_ids_forward_micro_batch = data["input_ids"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + attention_mask_forward_micro_batch = data["attention_mask"][ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + action_mask_forward_micro_batch = action_mask[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + loss_mask_forward_micro_batch = ( + loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] + if loss_mask is not None + else None + ) + advantages_forward_micro_batch = advantages[ + forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + ] + policy_model_logits = self.policy_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + action_log_probs = calc_action_log_probs( + policy_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) - # [batch_size x num_generations] - reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) - reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) - # [batch_size x num_generations] - advantages = (reward - reward_mean) / (reward_std + 1e-4) - - loss, skip_update, _ = self.policy_loss_fn( - action_log_probs, - old_action_log_probs, - advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1), - per_token_kl, - action_mask, - loss_mask=loss_mask, - ) + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + + loss, skip_update, _ = self.policy_loss_fn( + action_log_probs, + old_action_log_probs, + advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), + per_token_kl, + action_mask_forward_micro_batch, + loss_mask=loss_mask_forward_micro_batch, + ) + + if not skip_update: + self.booster.backward(loss, self.optimizer) + loss = all_reduce_mean(loss, self.plugin) + kl = all_reduce_mean(kl.mean(), self.plugin) + # Calculate accumulate value. + mean_kl.append(kl.data) + mean_loss.append(loss.data) - if not skip_update: - self.booster.backward(loss, self.optimizer) - loss = all_reduce_mean(loss, self.plugin) reward = all_reduce_mean(reward.mean(), self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) format_reward = all_reduce_mean(format_reward.mean(), self.plugin) acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) - # Calculate accumulate value. - self.accum_loss.add_(loss.data) + self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_reward.add_(reward.data) - self.accum_kl.add_(kl.data) self.accum_format_reward.add_(format_reward.data) self.accum_acc_reward.add_(acc_reward.data) self.accum_advantages.add_(advantages.data) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 1a559bd5672e..919e4434faa6 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -113,20 +113,3 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean - - -def get_logits_rebatched_forward(model, batch_size, input_ids, attention_mask): - """ - Get logits from the model with rebatched forward. - Args: - model (torch.nn.Module): The model. - batch_size (int): The batch size. - input_ids (torch.Tensor): The input ids. - attention_mask (torch.Tensor): The attention mask. - """ - logits = [] - for i in range(0, input_ids.size(0), batch_size): - logits.append( - model(input_ids=input_ids[i : i + batch_size], attention_mask=attention_mask[i : i + batch_size])["logits"] - ) - return torch.cat(logits, dim=0) diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index a67a10bc5b35..04508ea2463e 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -33,7 +33,6 @@ ) train_model_config.update( dict( - use_flash_attention_2=True, torch_dtype=torch.bfloat16, use_cache=False, ) @@ -49,6 +48,12 @@ ) elif args.backend == "vllm": inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) + train_model_config.update( + dict( + use_flash_attention_2=True, + use_cache=False, + ) + ) generate_config.update( dict( max_tokens=2048, @@ -87,6 +92,6 @@ plugin_config={}, inference_backend=args.backend, master_addr="localhost", - master_port=29503, + master_port=29505, core_algo=args.algo, ) From b0ac3aa9b9e4dfba0e6554b3b85c42fc1f817624 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 26 Mar 2025 16:22:01 +0800 Subject: [PATCH 3/8] fix producer OOM --- applications/ColossalChat/coati/distributed/launch.py | 3 ++- applications/ColossalChat/coati/distributed/producer.py | 6 ++++++ applications/ColossalChat/rl_example.py | 2 ++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index eb97fbab3424..80ba76ca3ca5 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -33,6 +33,7 @@ def launch_distributed( inference_batch_size: int, inference_microbatch_size: int, train_batch_size: int, + forward_micro_batch_size: int, train_microbatch_size: int, dataset_config: Dict[str, Any], dataloaders_config: Dict[str, Any], @@ -101,7 +102,7 @@ def launch_distributed( plugin_config=plugin_config, microbatch_size=train_microbatch_size, generate_config=generate_config_consumer, - training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": 4}, + training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": forward_micro_batch_size}, num_generations=num_generations, ) procs.append(consumer) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 51a1af332f25..ae45d1ac2710 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -3,6 +3,7 @@ import ray import ray.util.collective as cc import torch +import torch.distributed as dist from coati.dataset.loader import RawConversationDataset from torch.utils.data import DataLoader, DistributedSampler from transformers import AutoTokenizer @@ -100,6 +101,7 @@ def loop(self) -> None: if i >= num_valid_microbatches: break outputs = self.rollout(**batch) + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( [self.model.generate_config.temperature] * outputs["input_ids"].size(0) @@ -116,10 +118,13 @@ def loop(self) -> None: print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) + state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) self.load_state_dict(state_dict) + del state_dict + torch.cuda.empty_cache() # linear annealing for 1 episode, temperature from initial to 0.7 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) @@ -172,3 +177,4 @@ def rollout(self, input_ids, attention_mask, **kwargs): def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) + diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 04508ea2463e..3b42a8303815 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -14,6 +14,7 @@ parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) + parser.add_argument("-fmb", "--forward-micro-batch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() @@ -84,6 +85,7 @@ inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, train_microbatch_size=args.train_microbatch_size, + forward_micro_batch_size=args.forward_micro_batch_size, dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, From 701c5fa79a6ac4f992c2300e9ce003be057b2f8c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Mar 2025 08:24:33 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/coati/distributed/launch.py | 6 +++++- applications/ColossalChat/coati/distributed/producer.py | 6 ++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 80ba76ca3ca5..d2de41e03845 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -102,7 +102,11 @@ def launch_distributed( plugin_config=plugin_config, microbatch_size=train_microbatch_size, generate_config=generate_config_consumer, - training_config={"filter_range": [0.05, 9.0], "lr": 1e-6, "forward_micro_batch_size": forward_micro_batch_size}, + training_config={ + "filter_range": [0.05, 9.0], + "lr": 1e-6, + "forward_micro_batch_size": forward_micro_batch_size, + }, num_generations=num_generations, ) procs.append(consumer) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index ae45d1ac2710..737a03cde6a8 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -3,7 +3,6 @@ import ray import ray.util.collective as cc import torch -import torch.distributed as dist from coati.dataset.loader import RawConversationDataset from torch.utils.data import DataLoader, DistributedSampler from transformers import AutoTokenizer @@ -101,7 +100,7 @@ def loop(self) -> None: if i >= num_valid_microbatches: break outputs = self.rollout(**batch) - + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( [self.model.generate_config.temperature] * outputs["input_ids"].size(0) @@ -118,7 +117,7 @@ def loop(self) -> None: print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) - + state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" ) @@ -177,4 +176,3 @@ def rollout(self, input_ids, attention_mask, **kwargs): def load_state_dict(self, state_dict): self.model.load_state_dict(state_dict) - From 3d77ff45963c55260de7c3c9b84c179591438aed Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 26 Mar 2025 17:22:32 +0800 Subject: [PATCH 5/8] change project name --- applications/ColossalChat/coati/distributed/grpo_consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 46e7c2fa0ec5..1bd3436d6e56 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -96,7 +96,7 @@ def __init__( self.global_step = 0 if use_wandb and self.rank == 0: name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}" - self.wandb_run = wandb.init(project="GRPO-V1-debug", sync_tensorboard=True, dir="./wandb", name=name) + self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name) self.lr_scheduler = CosineAnnealingWarmupLR( optimizer=self.optimizer, From 51b2d00d3b473d103fbf91461a4c4a749d917ac3 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 27 Mar 2025 11:40:52 +0800 Subject: [PATCH 6/8] fix temperature annealing --- applications/ColossalChat/coati/distributed/consumer.py | 1 - applications/ColossalChat/coati/distributed/producer.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 027acc2e7537..ef5ef055235c 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -135,7 +135,6 @@ def loop(self) -> None: state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) - @ray.remote class SimpleConsumer(BaseConsumer): def __init__( diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 737a03cde6a8..f744d326198d 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -128,7 +128,7 @@ def loop(self) -> None: if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) self.model.generate_config.temperature = ( - ratio * self.generate_config["temperature"] + (1 - ratio) * 0.7 + (1 - ratio) * self.generate_config["temperature"] + ratio * 0.7 ) From e7bdd846ae9986f93e27b2dc102eab2c73fcd995 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Mar 2025 03:42:15 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalChat/coati/distributed/consumer.py | 1 + applications/ColossalChat/coati/distributed/producer.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index ef5ef055235c..027acc2e7537 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -135,6 +135,7 @@ def loop(self) -> None: state_dict, src=self.num_producers, device=self.device, group_name="sync_model" ) + @ray.remote class SimpleConsumer(BaseConsumer): def __init__( diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index f744d326198d..2c6a24a36711 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -127,9 +127,9 @@ def loop(self) -> None: # linear annealing for 1 episode, temperature from initial to 0.7 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) - self.model.generate_config.temperature = ( - (1 - ratio) * self.generate_config["temperature"] + ratio * 0.7 - ) + self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.7 @ray.remote From 9172ee48a4ff5d290276319b49512f73b2c97631 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Fri, 28 Mar 2025 09:44:05 +0800 Subject: [PATCH 8/8] address conversation --- .../coati/distributed/grpo_consumer.py | 2 +- .../ColossalChat/coati/distributed/launch.py | 6 ++-- .../coati/distributed/producer.py | 6 ++-- applications/ColossalChat/rl_example.py | 32 +++++++++---------- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 1bd3436d6e56..4174f96514b8 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -132,7 +132,7 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) - forward_batch_size = self.training_config.get("forward_micro_batch_size", data["input_ids"].size(0)) + forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) need_update = (step_idx + 1) % self.num_microbatches == 0 # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index d2de41e03845..ba5d3a9d4fd8 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -33,8 +33,8 @@ def launch_distributed( inference_batch_size: int, inference_microbatch_size: int, train_batch_size: int, - forward_micro_batch_size: int, train_microbatch_size: int, + train_minibatch_size: int, dataset_config: Dict[str, Any], dataloaders_config: Dict[str, Any], inference_model_config: Dict[str, Any], @@ -100,12 +100,12 @@ def launch_distributed( batch_size=train_batch_size, model_config=train_model_config, plugin_config=plugin_config, - microbatch_size=train_microbatch_size, + microbatch_size=train_minibatch_size, generate_config=generate_config_consumer, training_config={ "filter_range": [0.05, 9.0], "lr": 1e-6, - "forward_micro_batch_size": forward_micro_batch_size, + "train_microbatch_size": train_microbatch_size, }, num_generations=num_generations, ) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index f744d326198d..2c6a24a36711 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -127,9 +127,9 @@ def loop(self) -> None: # linear annealing for 1 episode, temperature from initial to 0.7 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) - self.model.generate_config.temperature = ( - (1 - ratio) * self.generate_config["temperature"] + ratio * 0.7 - ) + self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.7 @ray.remote diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 3b42a8303815..4a4a4c3404e9 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,19 +10,30 @@ parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) + parser.add_argument("-g", "--num-generations", type=int, default=8) parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64) parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8) parser.add_argument("-tbs", "--train-batch-size", type=int, default=32) - parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1) - parser.add_argument("-fmb", "--forward-micro-batch-size", type=int, default=2) + parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1) + parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2) parser.add_argument("-b", "--backend", type=str, default="transformers") parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) args = parser.parse_args() + assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0" + assert ( + args.train_minibatch_size * args.num_generations >= args.train_microbatch_size + and args.train_microbatch_size > 0 + ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" + ray.init(address="local", namespace="ray-example") inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model) + train_model_config = dict( + path=args.model, + # use_flash_attention_2=True, + # use_cache=False + ) generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) if args.backend == "transformers": @@ -32,12 +43,6 @@ torch_dtype=torch.bfloat16, ) ) - train_model_config.update( - dict( - torch_dtype=torch.bfloat16, - use_cache=False, - ) - ) generate_config.update( dict( max_length=1024 + 512, @@ -49,12 +54,6 @@ ) elif args.backend == "vllm": inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) - train_model_config.update( - dict( - use_flash_attention_2=True, - use_cache=False, - ) - ) generate_config.update( dict( max_tokens=2048, @@ -84,12 +83,13 @@ inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size, + train_minibatch_size=args.train_minibatch_size, train_microbatch_size=args.train_microbatch_size, - forward_micro_batch_size=args.forward_micro_batch_size, dataset_config={"path": args.dataset, "max_length": 300}, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, + num_generations=args.num_generations, train_model_config=train_model_config, plugin_config={}, inference_backend=args.backend,