diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index ee72e029093c..40d362340474 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -101,6 +101,7 @@ def __init__(
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
+ adv=grpo_config.get("algo", "GRPO"),
)
# Reference model is initialized from policy model.
@@ -137,6 +138,8 @@ def __init__(
eta_min=0.1 * grpo_config.get("lr", 1e-6),
)
+ self.adv = grpo_config.get("algo", "GRPO")
+
def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -204,9 +207,23 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
- reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
- # [minibatch_size x num_generations]
- advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+ if self.adv in ("GRPO", "DAPO"):
+     # GRPO/DAPO: center on the group mean and normalize by the group std.
+     reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+     # [minibatch_size x num_generations]
+     advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+ elif self.adv == "REINFORCE_PPB":
+     # REINFORCE++-baseline: center on the group mean only; token-level
+     # whitening is applied later, per forward micro-batch.
+     # [minibatch_size x num_generations]
+     advantages = (reward - reward_mean).unsqueeze(dim=-1)
+ elif self.adv == "RLOO":
+     # RLOO: leave-one-out baseline r_i - mean_{j != i}(r_j),
+     # which equals (r_i - mean(r)) * n / (n - 1).
+     advantages = (
+         reward * self.num_generations / (self.num_generations - 1)
+         - reward_mean * self.num_generations / (self.num_generations - 1)
+     ).unsqueeze(dim=-1)
# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
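The three estimators added in the hunk above can be summarized as a single standalone function. This is a hedged sketch for reference, not the consumer's actual code path; the function name, the `eps` argument, and the flattening to `[minibatch_size * num_generations, 1]` are illustrative, while the shapes and the `1e-4` stabilizer follow the diff.

```python
import torch


def compute_advantages(group_reward: torch.Tensor, algo: str, eps: float = 1e-4) -> torch.Tensor:
    """Sketch of the per-group advantage estimators.

    group_reward: [minibatch_size, num_generations] rewards, one row per prompt group.
    Returns advantages flattened to [minibatch_size * num_generations, 1].
    """
    n = group_reward.size(1)
    mean = group_reward.mean(dim=1, keepdim=True)        # group baseline
    if algo in ("GRPO", "DAPO"):
        std = group_reward.std(dim=1, keepdim=True)
        adv = (group_reward - mean) / (std + eps)        # normalize by the group std
    elif algo == "REINFORCE_PPB":
        adv = group_reward - mean                        # whitening happens later, per micro-batch
    elif algo == "RLOO":
        adv = (group_reward - mean) * n / (n - 1)        # leave-one-out baseline
    else:
        raise ValueError(f"Unsupported algo: {algo}")
    return adv.reshape(-1, 1)
```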
@@ -358,10 +375,34 @@ def _criterion(outputs, inputs):
per_token_kl = 0.0
kl.append(torch.tensor(0.0))
+ inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)
+
+ if self.adv == "REINFORCE_PPB":
+
+ inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
+ advantages_forward_micro_batch_mean = torch.sum(
+ inputs["advantages"] * inputs["action_mask"]
+ ) / (torch.sum(inputs["action_mask"]) + 1e-4)
+ advantages_forward_micro_batch_std = torch.rsqrt(
+ torch.sum(
+ (inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
+ * inputs["action_mask"]
+ )
+ / (torch.sum(inputs["action_mask"]) + 1e-4)
+ + 1e-8
+ )
+ inputs["advantages"] = (
+ (inputs["advantages"] - advantages_forward_micro_batch_mean)
+ * inputs["action_mask"]
+ / (advantages_forward_micro_batch_std)
+ )
+
+ per_token_kl = 0.0
+
loss, _ = self.policy_loss_fn(
action_log_probs,
inputs["old_action_log_probs"],
- inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+ inputs["advantages"],
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
@@ -420,10 +461,39 @@ def _criterion(outputs, inputs):
per_token_kl = 0.0
kl = None
+ advantages_forward_micro_batch = advantages_forward_micro_batch.repeat_interleave(
+     action_log_probs.size(-1), dim=-1
+ )
+
+ if self.adv == "REINFORCE_PPB":
+     # REINFORCE++-baseline: fold the KL penalty into the advantages, then
+     # whiten them over the unmasked tokens of this forward micro-batch.
+     advantages_forward_micro_batch = (
+         advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
+     )
+     advantages_forward_micro_batch_mean = torch.sum(
+         advantages_forward_micro_batch * action_mask_forward_micro_batch
+     ) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+     advantages_forward_micro_batch_std = torch.sqrt(
+         torch.sum(
+             (advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
+             * action_mask_forward_micro_batch
+         )
+         / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+         + 1e-8
+     )
+     advantages_forward_micro_batch = (
+         (advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
+         * action_mask_forward_micro_batch
+         / advantages_forward_micro_batch_std
+     )
+     # KL is already folded into the advantages; do not apply it again in the loss.
+     per_token_kl = 0.0
+
loss, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs_micro_batch,
- advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+ advantages_forward_micro_batch,
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
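The two REINFORCE_PPB hunks above repeat the same fold-and-whiten step. As a compact reference, here is a hedged sketch of that step as a standalone helper; the function name and signature are illustrative and not part of this change, and it assumes `[batch, seq_len]` tensors with the same `1e-4` / `1e-8` stabilizers as the diff.

```python
import torch


def fold_kl_and_whiten(advantages, per_token_kl, action_mask, beta, eps=1e-4):
    """advantages, per_token_kl, action_mask: [batch, seq_len] tensors."""
    adv = advantages - beta * per_token_kl        # fold the KL penalty into the advantages
    denom = action_mask.sum() + eps
    mean = (adv * action_mask).sum() / denom      # statistics over unmasked tokens only
    var = (((adv - mean) ** 2) * action_mask).sum() / denom
    std = torch.sqrt(var + 1e-8)
    # Whiten and re-mask; the caller then passes per_token_kl = 0.0 to the loss
    # so the penalty is not applied a second time.
    return (adv - mean) * action_mask / std
```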
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a48246c87526..d60312e2b0b1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -9,7 +9,13 @@
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {
+ "Simple": SimpleConsumer,
+ "GRPO": GRPOConsumer,
+ "DAPO": GRPOConsumer,
+ "REINFORCE_PPB": GRPOConsumer,
+ "RLOO": GRPOConsumer,
+}
def get_jsonl_size_fast(path: str) -> int:
@@ -66,6 +72,7 @@ def launch_distributed(
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
+
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = train_dataset_config["path"]
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index 36057b24faf5..ab38f987f65a 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -16,6 +16,7 @@ def __init__(
clip_eps_high: float = 0.2,
beta: float = 0.01,
loss_variation: str = "sample_level",
+ adv: str = "GRPO",
) -> None:
super().__init__()
self.clip_eps_low = clip_eps_low
@@ -23,6 +24,7 @@ def __init__(
self.beta = beta
self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
+ self.adv = adv
def forward(
self,
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index fbec2319be10..38a85b9b1c4d 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -118,6 +118,9 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
# init dataloader
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
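The pad-token fallback above matters for checkpoints that ship without a pad token, since left-padded batched generation needs one. A minimal sketch of the same pattern in isolation; the model name is only a placeholder, not something from this change.

```python
from transformers import AutoTokenizer

# Placeholder model name; any causal-LM tokenizer works for the illustration.
tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
    # Reuse EOS as the padding token so batched left-padded generation works.
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(["short prompt", "a somewhat longer prompt"], padding=True, return_tensors="pt")
```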
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 42ec582f6b68..b584b940ccaa 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -137,7 +137,7 @@
)
# GRPO parameters
- parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+ parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@@ -292,6 +292,7 @@
if args.algo == "GRPO":
# Default Settings
grpo_config = {
+ "algo": "GRPO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
@@ -313,6 +314,7 @@
elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
+ "algo": "DAPO",
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
@@ -339,6 +341,50 @@
else None
),
}
+ elif args.algo == "REINFORCE_PPB":
+ # Default Settings
+ grpo_config = {
+ "algo": "REINFORCE_PPB",
+ "lr": args.learning_rate,
+ "train_microbatch_size": args.train_microbatch_size,
+ "beta": args.kl_coeff, # KL penalty coefficient
+ "loss_variation": "sample_level",
+ "reward_fn_type": args.reward_type,
+ "max_length": args.max_new_tokens + args.max_prompt_tokens,
+ "max_new_tokens": args.max_new_tokens,
+ "response_format_tags": (
+ {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ if args.reward_type == "think_answer_tags"
+ else None
+ ),
+ }
+ elif args.algo == "RLOO":
+ # Default Settings
+ grpo_config = {
+ "algo": "RLOO",
+ "lr": args.learning_rate,
+ "train_microbatch_size": args.train_microbatch_size,
+ "beta": args.kl_coeff, # KL penalty coefficient
+ "loss_variation": "sample_level",
+ "reward_fn_type": args.reward_type,
+ "max_length": args.max_new_tokens + args.max_prompt_tokens,
+ "max_new_tokens": args.max_new_tokens,
+ "response_format_tags": (
+ {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ if args.reward_type == "think_answer_tags"
+ else None
+ ),
+ }
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.reward_type == "code":
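As a quick numerical sanity check of the three estimators (illustrative numbers, not taken from this diff), one prompt with four sampled completions and rewards [1, 0, 0, 1] gives:

```python
import torch

group_reward = torch.tensor([[1.0, 0.0, 0.0, 1.0]])    # one group, num_generations = 4
mean = group_reward.mean(dim=1, keepdim=True)           # 0.5
std = group_reward.std(dim=1, keepdim=True)             # ~0.577 (torch.std is unbiased by default)

grpo_adv = (group_reward - mean) / (std + 1e-4)         # ~[ 0.866, -0.866, -0.866,  0.866]
rppb_adv = group_reward - mean                          #  [ 0.500, -0.500, -0.500,  0.500]
rloo_adv = (group_reward - mean) * 4 / 3                #  [ 0.667, -0.667, -0.667,  0.667]

# RLOO check against its leave-one-out definition r_i - mean_{j != i}(r_j):
# first sample: 1.0 - (0.0 + 0.0 + 1.0) / 3 = 0.667, matching the scaled form above.
```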