80 changes: 75 additions & 5 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -101,6 +101,7 @@ def __init__(
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
adv=grpo_config.get("algo"),
)

# Reference model is initialized from policy model.
@@ -137,6 +138,8 @@ def __init__(
eta_min=0.1 * grpo_config.get("lr", 1e-6),
)

self.adv = grpo_config.get("algo")

def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -204,9 +207,23 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
if self.adv == "GRPO" or self.adv == "DAPO":

reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)

elif self.adv == "REINFORCE_PPB":

# [minibatch_size x num_generations]
advantages = ((reward - reward_mean)).unsqueeze(dim=-1)

elif self.adv == "RLOO":

advantages = (
reward * self.num_generations / (self.num_generations - 1)
- reward_mean * self.num_generations / (self.num_generations - 1)
).unsqueeze(dim=-1)

# [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
@@ -358,10 +375,34 @@ def _criterion(outputs, inputs):
per_token_kl = 0.0
kl.append(torch.tensor(0.0))

inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)

if self.adv == "REINFORCE_PPB":

inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
advantages_forward_micro_batch_mean = torch.sum(
inputs["advantages"] * inputs["action_mask"]
) / (torch.sum(inputs["action_mask"]) + 1e-4)
advantages_forward_micro_batch_std = torch.rsqrt(
torch.sum(
(inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
* inputs["action_mask"]
)
/ (torch.sum(inputs["action_mask"]) + 1e-4)
+ 1e-8
)
inputs["advantages"] = (
(inputs["advantages"] - advantages_forward_micro_batch_mean)
* inputs["action_mask"]
/ (advantages_forward_micro_batch_std)
)

per_token_kl = 0.0

loss, _ = self.policy_loss_fn(
action_log_probs,
inputs["old_action_log_probs"],
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
inputs["advantages"],
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
@@ -420,10 +461,39 @@ def _criterion(outputs, inputs):
per_token_kl = 0.0
kl = None

advantages_forward_micro_batch = advantages_forward_micro_batch.repeat_interleave(
action_log_probs.size(-1), dim=-1
)

if self.adv == "REINFORCE_PPB":

advantages_forward_micro_batch = (
advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
)
advantages_forward_micro_batch_mean = torch.sum(
advantages_forward_micro_batch * action_mask_forward_micro_batch
) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
advantages_forward_micro_batch_std = torch.rsqrt(
torch.sum(
(advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
* action_mask_forward_micro_batch
)
/ (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+ 1e-8
)
advantages_forward_micro_batch = (
(advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
* action_mask_forward_micro_batch
/ (advantages_forward_micro_batch_std)
)

per_token_kl = 0.0

loss, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs_micro_batch,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
advantages_forward_micro_batch,
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
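For reference, the three advantage estimators selected by the new algo field can be summarized in a standalone sketch. It mirrors the formulas in the -204/+207 and REINFORCE_PPB hunks above, but it is an illustration only: the tensor names (group_reward, num_generations, action_mask) follow the diff, while the masked_whiten helper is an assumed restatement of the inline REINFORCE_PPB normalization, not a function in the codebase.

import torch

def group_advantages(group_reward: torch.Tensor, algo: str, num_generations: int) -> torch.Tensor:
    # group_reward: [minibatch_size, num_generations], one row of rewards per prompt
    reward = group_reward.reshape(-1)  # [minibatch_size * num_generations], row-major
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
    if algo in ("GRPO", "DAPO"):
        # group-normalized: (r - mean) / (std + eps)
        reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
        advantages = (reward - reward_mean) / (reward_std + 1e-4)
    elif algo == "REINFORCE_PPB":
        # mean-centered only; per-token whitening happens later, after the KL shift
        advantages = reward - reward_mean
    elif algo == "RLOO":
        # leave-one-out baseline: r_i - (sum_{j != i} r_j) / (K - 1), i.e. K / (K - 1) * (r_i - mean)
        k = num_generations
        advantages = reward * k / (k - 1) - reward_mean * k / (k - 1)
    else:
        raise ValueError(f"Unsupported algo: {algo}")
    return advantages.unsqueeze(-1)  # [minibatch_size * num_generations, 1]

def masked_whiten(adv: torch.Tensor, mask: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # per-token normalization over unmasked positions, as in the REINFORCE_PPB branches above
    mean = torch.sum(adv * mask) / (torch.sum(mask) + eps)
    inv_std = torch.rsqrt(torch.sum((adv - mean) ** 2 * mask) / (torch.sum(mask) + eps) + 1e-8)
    return (adv - mean) * mask * inv_std

In short: GRPO/DAPO divide by the group standard deviation up front; REINFORCE_PPB defers normalization to a per-token whitening applied after the KL penalty is folded into the advantage (with per_token_kl then zeroed so the loss does not count it twice); and RLOO replaces the plain group mean with a leave-one-out baseline built from the other K - 1 samples in the group.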
9 changes: 8 additions & 1 deletion applications/ColossalChat/coati/distributed/launch.py
@@ -9,7 +9,13 @@
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer

ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
"DAPO": GRPOConsumer,
"REINFORCE_PPB": GRPOConsumer,
"RLOO": GRPOConsumer,
}


def get_jsonl_size_fast(path: str) -> int:
@@ -66,6 +72,7 @@ def launch_distributed(
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)

train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)

assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

dataset_path = train_dataset_config["path"]
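The assertion visible in the launch_distributed hunk roughly ties producer throughput to consumer consumption: the rollouts generated per step must be divisible by the samples the trainers consume per step across data-parallel ranks. A quick numeric check of the constraint (the numbers are made up for illustration, not taken from this PR):

inference_batch_size, num_producers = 32, 4   # 128 rollouts produced per step
train_batch_size, train_dp_size = 16, 2       # 32 rollouts consumed per optimizer step
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0  # 128 % 32 == 0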
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/distributed/loss.py
@@ -16,13 +16,15 @@ def __init__(
clip_eps_high: float = 0.2,
beta: float = 0.01,
loss_variation: str = "sample_level",
adv: str = "GRPO",
) -> None:
super().__init__()
self.clip_eps_low = clip_eps_low
self.clip_eps_high = clip_eps_high
self.beta = beta
self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
self.adv = adv

def forward(
self,
3 changes: 3 additions & 0 deletions applications/ColossalChat/coati/distributed/producer.py
@@ -118,6 +118,9 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"

if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

# init dataloader
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
48 changes: 47 additions & 1 deletion applications/ColossalChat/rl_example.py
@@ -137,7 +137,7 @@
)

# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@@ -292,6 +292,7 @@
if args.algo == "GRPO":
# Default Settings
grpo_config = {
"algo": "GRPO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
@@ -313,6 +314,7 @@
elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
"algo": "DAPO",
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
@@ -339,6 +341,50 @@
else None
),
}
elif args.algo == "REINFORCE_PPB":
# Default Settings
grpo_config = {
"algo": "REINFORCE_PPB",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
elif args.algo == "RLOO":
# Default Settings
grpo_config = {
"algo": "RLOO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.reward_type == "code":