42 changes: 38 additions & 4 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -1,4 +1,3 @@
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional

@@ -7,11 +6,13 @@
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from coati.utils import save_checkpoint
from tqdm import tqdm
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -55,16 +56,19 @@ def __init__(
self.enable_profiling = enable_profiling
        assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
self.checkpoint_path = model_config.pop("checkpoint_path", None)

self.model_config = model_config
self.plugin_config = plugin_config

self.device = get_current_device()
self.lr_scheduler = None
self.n_behind = n_behind
        self.total_prompt_trained = 0  # for setting the sampler start index when resuming training

def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
self.coordinator = DistCoordinator()

plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if (
@@ -143,6 +147,26 @@ def calculate_effective_group_to_raw_group_mapping(self, step):
return effective_group_to_raw_group_mapping

def loop(self) -> None:
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
if self.tp_rank == 0 and self.dp_rank == 0:
ray_broadcast_tensor_dict(
state_dict,
src=self.num_producers,
device=self.device,
group_name=f"sync_model_{self.pp_rank}",
)
else:
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")

print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
)
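This send pairs with a matching receive added to the producer's loop later in this PR: producers occupy collective ranks 0 through num_producers - 1, and the consumer's sending worker joins as rank num_producers, which is why src=self.num_producers appears on both sides. A minimal sketch of the assumed handshake (the allocate-on-None receive behavior is an assumption about ray_broadcast_tensor_dict):

# Consumer side: the rank-0 worker broadcasts the updated weights.
# With pipeline parallelism, each pp rank uses its own group "sync_model_{pp_rank}".
ray_broadcast_tensor_dict(
    state_dict, src=num_producers, device=device, group_name="sync_model"
)

# Producer side: passing None means "allocate buffers and receive the broadcast".
state_dict = ray_broadcast_tensor_dict(
    None, num_producers, device=device, group_name="sync_model"
)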
@@ -208,6 +232,7 @@ def loop(self) -> None:
for k, v in raw_batch.items()
}
                    # running prompt count, used to derive the sampler start index on resume
                    self.total_prompt_trained += raw_batch["reward"].size(0)
                    # [batch_size, num_generations] -> [batch_size]
reward = raw_batch["reward"][:, :, 0]
format_acc = raw_batch["format_acc"][:, :, 0]
ans_acc = raw_batch["ans_acc"][:, :, 0]
@@ -285,10 +310,19 @@ def loop(self) -> None:
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.policy_model,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
epoch=episode,
step=step,
batch_size=int(self.total_prompt_trained / step),
coordinator=self.coordinator,
) # for setting start index when resuming training
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")

if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
episode != 0 or step >= self.n_behind
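One subtlety worth flagging: the batch_size passed to save_checkpoint is not the configured batch size but is back-computed from the running prompt counter, so that step * batch_size reconstructs the number of prompts already consumed (this also assumes step >= 1 at the first save). A worked sketch of the arithmetic, assuming save_checkpoint records sample_start_index = step * batch_size in running_states.json as coati.utils.ckpt_io suggests; the numbers are hypothetical:

total_prompt_trained = 4096                    # prompts consumed so far
step = 8                                       # step index passed to save_checkpoint
batch_size = int(total_prompt_trained / step)  # 512; int() truncation can drop a remainder
sample_start_index = step * batch_size         # 4096, i.e. total_prompt_trained again
# On resume, the producer feeds sample_start_index to
# StatefulDistributedSampler.set_start_index() to skip already-trained prompts.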
applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -8,6 +8,7 @@
from coati.distributed.loss import PolicyLoss
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -157,6 +158,14 @@ def setup(self):
)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
if self.checkpoint_path is not None:
load_checkpoint(
self.checkpoint_path,
self.booster,
self.policy_model,
self.optimizer,
self.lr_scheduler,
)
self.plugin.logger.set_level("ERROR")

def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
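Ordering matters here: load_checkpoint runs after booster.boost(...), so the sharded model and optimizer states are restored into the already-wrapped objects. The on-disk layout this load path assumes, based on coati.utils.ckpt_io (directory and key names are illustrative):

checkpoint_path/
├── modeling/             # restored via booster.load_model
├── optimizer/            # restored via booster.load_optimizer
├── lr_scheduler/         # restored via booster.load_lr_scheduler
└── running_states.json   # {"epoch": ..., "step": ..., "sample_start_index": ...}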
35 changes: 34 additions & 1 deletion applications/ColossalChat/coati/distributed/producer.py
@@ -8,10 +8,12 @@
import torch
import tqdm
import wandb
from coati.dataset import StatefulDistributedSampler
from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.profiling_utils import CustomProfiler
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.utils import load_checkpoint
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
@@ -68,6 +70,7 @@ def __init__(
self.profiler = CustomProfiler(f"P{self.producer_idx}", disabled=not enable_profiling)

self.train_dataset_config = train_dataset_config
self.checkpoint_path = model_config.pop("checkpoint_path", None)
self.model_config = model_config
self.generate_config = generate_config
self.tokenizer_config = tokenizer_config
@@ -121,7 +124,7 @@ def __init__(
self.train_dataloader = DataLoader(
self.train_dataset,
batch_size=microbatch_size,
sampler=DistributedSampler(
sampler=StatefulDistributedSampler(
self.train_dataset,
num_replicas=num_producers,
rank=producer_idx,
@@ -133,6 +136,13 @@
drop_last=True,
collate_fn=collate_fn_grpo,
)
if self.checkpoint_path is not None:
# resume training from checkpoint
start_epoch, start_step, sampler_start_idx = load_checkpoint(self.checkpoint_path, None, None, None, None)
self.train_dataloader.sampler.set_start_index(sampler_start_idx)
print(
f"[P{self.producer_idx}] Resume training from checkpoint {self.checkpoint_path}, start epoch {start_epoch}, start step {start_step}, sampler start index {sampler_start_idx}"
)
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
@@ -203,6 +213,29 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
raise NotImplementedError

def loop(self) -> None:

torch.cuda.empty_cache()
self.profiler.enter("sync_model")
if self.consumer_pp_size > 1:
for pp_idx in range(self.consumer_pp_size):
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
else:
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
if "consumer_global_step" in state_dict:
self.consumer_global_step = state_dict.pop("consumer_global_step").item()
self.load_state_dict(state_dict)
self.profiler.exit("sync_model")
print(f"[P{self.producer_idx}] Sync initial model done.")
del state_dict
torch.cuda.empty_cache()

num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
num_valid_microbatches = num_update_per_episode * self.num_microbatches

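For context, a sketch of how the stateful sampler ties the resume together on the producer side; set_start_index comes from coati.dataset.StatefulDistributedSampler and set_epoch from its DistributedSampler base, while the reset to zero after the first resumed epoch is an assumption about intended use:

from coati.dataset import StatefulDistributedSampler
from torch.utils.data import DataLoader

sampler = StatefulDistributedSampler(
    train_dataset, num_replicas=num_producers, rank=producer_idx, shuffle=True
)
train_dataloader = DataLoader(
    train_dataset, batch_size=microbatch_size, sampler=sampler, drop_last=True
)

# On resume: skip the prompts consumed before the checkpoint was written.
sampler.set_start_index(sampler_start_idx)
for episode in range(start_epoch, num_episodes):
    sampler.set_epoch(episode)       # keep per-epoch shuffling deterministic
    for batch in train_dataloader:
        ...                          # generate rollouts as usual
    sampler.set_start_index(0)       # only the first resumed epoch is partial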
9 changes: 6 additions & 3 deletions applications/ColossalChat/coati/utils/ckpt_io.py
@@ -81,9 +81,12 @@ def load_checkpoint(
"""

# Update booster params states.
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
if model is not None:
booster.load_model(model=model, checkpoint=os.path.join(load_dir, "modeling"))
if optimizer is not None:
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
if lr_scheduler is not None:
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
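The new None-guards make a metadata-only load possible, which is exactly the pattern the producer uses. Both call shapes under that reading:

# Trainer side: full restore of model, optimizer, and scheduler through the booster.
start_epoch, start_step, sampler_start_idx = load_checkpoint(
    load_dir, booster, model, optimizer, lr_scheduler
)

# Producer side: skip every booster load and read only running_states.json.
start_epoch, start_step, sampler_start_idx = load_checkpoint(
    load_dir, None, None, None, None
)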
13 changes: 11 additions & 2 deletions applications/ColossalChat/rl_example.py
@@ -18,6 +18,13 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument(
"-cp",
"--checkpoint-path",
type=str,
default=None,
help="Path to the checkpoint to load the model from. If not provided, the model will be loaded from the model path.",
)
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument(
"-ed",
@@ -226,8 +233,10 @@

os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock

inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
inference_model_config = dict(path=args.model, checkpoint_path=args.checkpoint_path)
train_model_config = dict(
path=args.model, use_flash_attention_2=True, use_cache=False, checkpoint_path=args.checkpoint_path
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

if args.backend == "transformers":
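checkpoint_path rides along inside model_config, which is otherwise forwarded to transformers, so both consumer and producer pop it back out before building the model (see the model_config.pop calls earlier in this PR). A sketch of that plumbing; the from_pretrained forwarding is an assumption about how model_config is consumed downstream, and the checkpoint path is a placeholder:

train_model_config = dict(
    path="Qwen/Qwen2.5-7B",
    use_flash_attention_2=True,
    use_cache=False,
    checkpoint_path="/path/to/checkpoint",  # hypothetical value, e.g. passed via -cp
)

# Inside consumer/producer __init__: checkpoint_path is not a transformers kwarg,
# so it must be removed before model_config reaches from_pretrained.
checkpoint_path = train_model_config.pop("checkpoint_path", None)
path = train_model_config.pop("path")
# model = AutoModelForCausalLM.from_pretrained(path, **train_model_config)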