3 changes: 3 additions & 0 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -57,6 +57,7 @@ def __init__(
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

self.device = get_current_device()
self.lr_scheduler = None

def setup(self) -> None:
for i in range(self.num_producers):
@@ -121,6 +122,8 @@ def loop(self) -> None:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
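Note on the consumer.py change: `loop` now steps the LR scheduler once per optimizer update, after the gradient-accumulation buffer is drained, and `lr_scheduler` defaults to `None` so subclasses without a schedule keep working. A minimal sketch of the pattern, with a hypothetical model and a stock PyTorch scheduler standing in for the booster-wrapped objects:

```python
import torch

# Hypothetical minimal loop illustrating the pattern above: the scheduler,
# when present, advances once per optimizer update, not once per microbatch.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for step in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # stand-in loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()  # once per update, mirroring the diff
```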
185 changes: 172 additions & 13 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,8 +1,11 @@
import json
import os
from contextlib import nullcontext
from typing import Optional

import ray
import torch
import torch.distributed as dist
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
@@ -12,6 +15,7 @@
from coati.trainer.utils import all_reduce_mean
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam


@@ -31,8 +35,10 @@ def __init__(
model_config,
plugin_config,
microbatch_size=1,
num_generations=4,
num_generations=8,
use_wandb=True,
generate_config=None,
training_config={},
):
super().__init__(
num_producers,
@@ -52,7 +58,7 @@ def __init__(
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
@@ -61,6 +67,7 @@ def __init__(
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
self.generate_config = generate_config

# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -69,6 +76,9 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
self.filter_range = training_config.get("filter_range", None)
if self.filter_range is not None:
assert len(self.filter_range) == 2, "Filter range should have 2 values."

# Initialize verifiable reward.
response_format_tags = {
@@ -84,11 +94,21 @@ def __init__(
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
if use_wandb and self.rank == 0:
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)
name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)

self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
warmup_steps=0,
eta_min=0.1 * training_config.get("lr", 1e-6),
)

def setup(self):
super().setup()
self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
self.reference_model, *_ = self.booster.boost(self.reference_model)

def step(self, step_idx: int, **kwargs) -> Optional[float]:
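Note on the schedule: the constructor builds a `CosineAnnealingWarmupLR` spanning `min(num_episodes, 4) * num_update_per_episode` steps with no warmup and a floor of 10% of the base LR, and `setup` now passes it through `booster.boost` so the plugin wraps it alongside the model and optimizer. A reference sketch of the implied schedule, assuming colossalai's scheduler follows the standard warmup-then-cosine form (an assumption about its internals):

```python
import math

def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-6,
              eta_min_ratio: float = 0.1, warmup_steps: int = 0) -> float:
    """Hypothetical reference formula for a warmup-then-cosine schedule.

    Assumes linear warmup to base_lr, then cosine decay to
    eta_min = eta_min_ratio * base_lr, matching the arguments in the diff.
    """
    eta_min = eta_min_ratio * base_lr
    if warmup_steps and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))
```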
@@ -113,15 +133,17 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
response_length = torch.sum(action_mask, dim=1).to(torch.float32)

need_update = (step_idx + 1) % self.num_microbatches == 0

ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
with ctx:
policy_model_logits = self.policy_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
policy_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)

with torch.no_grad():
@@ -130,7 +152,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
reference_model_logits / self.generate_config["temperature"],
data["input_ids"],
num_action,
self.plugin.shard_config,
)

per_token_kl = (
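Note on the temperature division: both the policy and reference log-prob computations now scale logits by 1/temperature, so the log-probs used for the policy ratio and the KL term match the distribution the rollouts were actually sampled from. A minimal sketch of the idea (hypothetical helper, not the repo's `calc_action_log_probs`):

```python
import torch

def temperature_scaled_log_probs(logits: torch.Tensor, labels: torch.Tensor,
                                 temperature: float) -> torch.Tensor:
    """Log-probs of sampled tokens under the temperature-scaled policy.

    Scaling logits by 1/temperature before log_softmax reproduces the
    distribution the sampler drew from; skipping it biases the ratio
    between policy and reference log-probs.
    """
    log_probs = torch.log_softmax(logits / temperature, dim=-1)  # [B, T, V]
    return log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [B, T]
```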
@@ -149,21 +174,31 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

# [batch_size, num_generations]

group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# Filter out groups whose mean reward is too high (every sample gets full score) or too low (every sample gets zero score).
loss_mask = (
None
if self.filter_range is None
else torch.logical_and(
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
).repeat_interleave(self.num_generations, dim=0)
)

# [batch_size x num_generations]
reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = (reward - reward_mean) / (reward_std + 1e-4)

# Calculate Loss
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask,
loss_mask=loss_mask,
)

if not skip_update:
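Note on the advantage block: rewards are reshaped to [batch_size, num_generations], each group is normalized by its own mean and std, and groups whose mean reward falls outside `filter_range` are masked out of the loss, since all-correct or all-wrong groups yield near-zero advantages and no learning signal. A self-contained sketch of the computation (hypothetical function name):

```python
import torch

def grpo_advantages(reward: torch.Tensor, num_generations: int,
                    filter_range=None, eps: float = 1e-4):
    """Group-relative advantages plus an optional loss mask.

    reward: flat tensor of shape [batch_size * num_generations].
    Returns (advantages, loss_mask), both with that flat shape.
    """
    group = reward.view(-1, num_generations)  # [batch, G]
    mean = group.mean(dim=1)                  # [batch]
    std = group.std(dim=1)                    # [batch]
    loss_mask = None
    if filter_range is not None:
        # Keep only groups with mixed outcomes; degenerate groups carry
        # no learning signal.
        keep = (mean > filter_range[0]) & (mean < filter_range[1])
        loss_mask = keep.repeat_interleave(num_generations)
    mean = mean.repeat_interleave(num_generations)  # [batch * G]
    std = std.repeat_interleave(num_generations)
    advantages = (reward - mean) / (std + eps)
    return advantages, loss_mask
```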
@@ -207,13 +242,15 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
)
self.wandb_run.log(
{
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/reward": self.accum_reward.item() / self.accum_count,
"train/format_reward": self.accum_format_reward.item() / self.accum_count,
"train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/response_length": self.accum_response_length.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
)
self.accum_loss.zero_()
@@ -232,3 +269,125 @@ def state_dict(self):
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict


@ray.remote
class GRPOEvalConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
num_generations=4,
use_wandb=True,
log_dir="./results",
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = torch.zeros(1, device=self.device)

self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations

# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
)

self.log_dir = log_dir
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
else:
os.system(f"rm -rf {self.log_dir}/*")

def setup(self):
super().setup()
self.policy_model, _, *_ = self.booster.boost(self.policy_model)

def step(self, step_idx: int, **kwargs) -> Optional[float]:
rank = dist.get_rank()
data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
kwargs["input_ids"].size(0)
reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = [value[0].item() for value in reward_group]
format_reward = [value[1].item() for value in reward_group]
acc_reward = [value[2].item() for value in reward_group]
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]

response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
for i in range(len(response)):
f.write(
json.dumps(
{
"response": response[i],
"reward": reward[i],
"format_reward": format_reward[i],
"acc_reward": acc_reward[i],
"response_length": response_length[i],
},
ensure_ascii=False,
)
+ "\n"
)

self.accum_reward += sum(reward)
self.accum_format_reward += sum(format_reward)
self.accum_acc_reward += sum(acc_reward)
self.accum_response_length += sum(response_length)
self.accum_count += len(reward)

# print results
total_count = all_reduce_mean(self.accum_count, self.plugin)
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
if rank == 0:
print(
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
)
return None

def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict
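Note on the metric aggregation in `GRPOEvalConsumer.step`: reward sums and the sample count are accumulated locally, reduced separately across ranks, and then divided; because the reduction treats numerator and denominator alike, any world-size factor cancels and the ratio is the exact global mean. A sketch with plain torch.distributed (assumes an initialized process group; `all_reduce_mean` in the repo wraps the same idea):

```python
import torch
import torch.distributed as dist

def global_mean(local_sum: torch.Tensor, local_count: torch.Tensor) -> torch.Tensor:
    """Exact global mean from per-rank sums and counts.

    Reducing sums and counts separately, then dividing, equals
    (sum of sums) / (sum of counts) regardless of world size.
    """
    dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
    dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
    return local_sum / local_count
```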
@@ -53,15 +53,21 @@ class TransformersInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)

def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
path = model_config.pop("path")
self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
self.num_generations = 8
self.num_generations = num_generations

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -120,7 +126,13 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:


class SGLangInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
@@ -175,27 +187,38 @@ class VLLMInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(
logprobs=0,
n=8,
)

def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
path = model_config.pop("path")
self.llm = LLM(path, **model_config)
self.llm = LLM(model=path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
self.num_generations = num_generations

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
micro_batch_size = input_ids.size(0)
response_start_idx = input_ids.size(1)
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
micro_batch_input_ids = input_ids.tolist()
micro_batch_input_ids_no_padding = [
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]
outputs = self.llm.generate(
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
)
out_tokens = []
out_len = []
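Note on the vLLM changes: `n` is no longer hard-coded to 8 but taken from `num_generations`, the `LLM` constructor gets the explicit `model=` keyword, and prompts are stripped of left padding before `llm.generate`, since vLLM consumes raw token-id lists and would otherwise treat pad tokens as prompt content. A minimal sketch of the padding strip (hypothetical helper name):

```python
import torch

def strip_left_padding(input_ids: torch.Tensor, pad_token_id: int):
    """Drop left padding from a [batch, seq_len] tensor of prompt ids.

    argmax over (ids != pad) finds the first non-pad position per row;
    each row is sliced from there so vLLM sees only real prompt tokens.
    """
    first_non_pad = (input_ids != pad_token_id).int().argmax(dim=1).tolist()
    rows = input_ids.tolist()
    return [row[start:] for row, start in zip(rows, first_non_pad)]
```

One caveat, matching the inline code above: for an all-padding row, argmax returns 0, so the row passes through unchanged.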