171 changes: 125 additions & 46 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -6,6 +6,7 @@
import ray.util.collective as cc
import torch
import torch.distributed as dist
from coati.distributed.profiling_utils import CustomProfiler
from tqdm import tqdm
from transformers import AutoModelForCausalLM

@@ -36,6 +37,8 @@ def __init__(
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
enable_profiling: bool = False,
n_behind: int = 0,
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -49,6 +52,7 @@ def __init__(
self.minibatch_size = minibatch_size
self.save_interval = save_interval
self.save_dir = save_dir
self.enable_profiling = enable_profiling
assert batch_size % minibatch_size == 0, "batch_size should be divisible by minibatch_size"
self.num_microbatches = batch_size // minibatch_size
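For example, batch_size = 32 with minibatch_size = 8 passes the assertion above and yields num_microbatches = 4, while minibatch_size = 5 would be rejected.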

@@ -57,6 +61,7 @@ def __init__(

self.device = get_current_device()
self.lr_scheduler = None
self.n_behind = n_behind

def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -94,13 +99,49 @@ def setup(self) -> None:

self.buffer = []
self.recv_cnt = 0
self.profiler = CustomProfiler(f"C{self.rank}", disabled=not self.enable_profiling)

def state_dict(self) -> Dict[str, torch.Tensor]:
raise NotImplementedError

def step(self, step_idx: int, **kwargs) -> Optional[float]:
raise NotImplementedError

def prepare_mini_batch(self, effective_group_to_raw_group_mapping: Dict[int, int]) -> tuple[Dict[str, torch.Tensor], Dict[str, list]]:
"""
Build a training mini-batch from the effective-group-to-raw-group mapping.
Returns the bound batch together with the per-sample metrics of the raw groups it spans, for logging.
"""
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size)
]
# every dp_rank receives a complete mini-batch, so there is no need to sync within step() later
# each mini-batch uses the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
return batch, raw_mini_batches_metric_dict
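A minimal standalone sketch of the slicing above, with toy sizes and a hypothetical mapping (bind_batch and post_recv omitted): each dp_rank picks its own contiguous window of effective samples, which is why no cross-rank synchronization is needed inside step().

    dp_size, minibatch_size = 2, 2
    # effective (unfiltered) index -> raw buffer index, as built by the helper defined next
    effective_to_raw = {0: 0, 1: 2, 2: 3, 3: 5}
    for dp_rank in range(dp_size):
        picks = [effective_to_raw[i] for i in range(dp_rank * minibatch_size, (dp_rank + 1) * minibatch_size)]
        print(f"dp_rank {dp_rank} trains on raw buffer indices {picks}")
    # afterwards every rank drops buffer[: effective_to_raw[dp_size * minibatch_size - 1] + 1]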

def calculate_effective_group_to_raw_group_mapping(self, step):
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
if self.n_behind == 0:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
else:
if self.buffer[buffer_idx][-1] <= step - self.n_behind:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx
return effective_group_to_raw_group_mapping
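A toy illustration of the staleness filter (hypothetical values): each buffer entry records in its last field the step at which its rollout was produced, and with n_behind > 0 only entries at least n_behind steps old are eligible for training.

    buffer_steps = [0, 0, 1, 2, 2]  # production step of each raw group
    n_behind, current_step = 1, 2
    eligible = [i for i, s in enumerate(buffer_steps) if s <= current_step - n_behind]
    print(eligible)  # [0, 1, 2]: only the groups produced at steps 0 and 1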

def loop(self) -> None:
print(
f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
@@ -112,14 +153,53 @@ def loop(self) -> None:
disable=self.rank != 0,
) as pbar:
for step in pbar:
torch.cuda.reset_peak_memory_stats()
i = 0

self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
for _ in range(self.num_recv_per_update):
if self.n_behind > 0:
# after syncing the model, do not wait for new rollouts to arrive (generation takes time); train on already-buffered data
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
self.profiler.log(
f"Still have {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
)
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping
)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(
effective_group_to_raw_group_mapping
)
effective_group_to_raw_group_mapping = (
self.calculate_effective_group_to_raw_group_mapping(step=step)
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before
- self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1

# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.profiler.enter(f"recv_broadcast_data_P{r}")
raw_batch = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
self.profiler.exit(f"recv_broadcast_data_P{r}")
# Calculate group rewards and related metrics before filtering. Since only the filtered groups
# (an incomplete subset) are used for training, these metrics must be computed here for logging.
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
@@ -154,63 +234,52 @@ def loop(self) -> None:
format_acc[group_idx],
ans_acc[group_idx],
response_len[group_idx],
step,  # production step of this rollout group; read back as buffer[idx][-1] by the staleness filter
]
)
if effective_group_mask is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
)
# mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
effective_group_to_raw_group_mapping = self.calculate_effective_group_to_raw_group_mapping(
step=step
)
print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
)

while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
# on each dp_rank, we use minibatch_size effective samples to form a batch
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
if self.n_behind == 0:
# If n_behind is 0, we start training after receiving data from producers.
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
self.profiler.log(
f"Collect {len(effective_group_to_raw_group_mapping)} effective groups, greater than {self.dp_size * self.minibatch_size}, start training"
)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch use the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
batch, raw_mini_batches_metric_dict = self.prepare_mini_batch(
effective_group_to_raw_group_mapping
)
self.profiler.enter("step")
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.profiler.exit("step")
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(
effective_group_to_raw_group_mapping
)
effective_group_to_raw_group_mapping = (
self.calculate_effective_group_to_raw_group_mapping(step=step)
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before
- self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1

if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -221,13 +290,16 @@ def loop(self) -> None:
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
episode != 0 or step >= self.n_behind
):
if self.pp_size > 1:
print(
f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
)
else:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
self.profiler.enter("sync_model")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.pp_size > 1:
@@ -245,6 +317,13 @@
)
del state_dict
torch.cuda.empty_cache()
self.profiler.exit("sync_model")
self.profiler.log(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
self.profiler.exit(f"rollout_episode_{episode}_step_{step}")

def __del__(self):
if hasattr(self, "profiler"):
self.profiler.close()
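CustomProfiler's implementation is not part of this diff; as used here it only needs enter/exit/log/close plus a disabled flag. A minimal compatible sketch, an assumption rather than the actual coati.distributed.profiling_utils code:

    import time

    class CustomProfiler:
        """Hypothetical stand-in matching the enter/exit/log/close calls above."""

        def __init__(self, name: str, disabled: bool = True):
            self.name, self.disabled = name, disabled
            self.file = None if disabled else open(f"{name}.prof.log", "w")
            self.pending = {}  # tag -> timestamp of the matching enter()

        def _write(self, msg: str) -> None:
            if not self.disabled:
                self.file.write(f"[{self.name}] {time.time():.3f} {msg}\n")
                self.file.flush()

        def enter(self, tag: str) -> None:
            self.pending[tag] = time.time()
            self._write(f"enter {tag}")

        def exit(self, tag: str) -> None:
            start = self.pending.pop(tag, None)
            suffix = f" ({time.time() - start:.3f}s)" if start is not None else ""
            self._write(f"exit {tag}{suffix}")

        def log(self, msg: str) -> None:
            self._write(msg)

        def close(self) -> None:
            if self.file is not None:
                self.file.close()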


@ray.remote
@@ -38,6 +38,8 @@ def __init__(
project_name: str = None,
run_name: str = None,
wandb_group_name: str = None,
enable_profiling: bool = False,
n_behind: int = 0,
):
print(f"Using GRPO config: {grpo_config}")
if (
@@ -63,6 +65,8 @@
minibatch_size,
save_interval=save_interval,
save_dir=save_dir,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -375,7 +379,7 @@ def _criterion(outputs, inputs):
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
6 changes: 6 additions & 0 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -57,6 +57,8 @@ def launch_distributed(
eval_generation_config: Optional[Dict[str, Any]] = None,
log_rollout_interval: int = 20,
rollout_save_dir: str = "./rollout",
enable_profiling: bool = False,
n_behind: int = 0,
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
@@ -132,6 +134,8 @@
wandb_group_name=wandb_group_name,
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
@@ -171,6 +175,8 @@
project_name=project_name,
run_name=run_name,
wandb_group_name=wandb_group_name,
enable_profiling=enable_profiling,
n_behind=n_behind,
)
consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in consumer_procs])