Merged
8 changes: 4 additions & 4 deletions examples/configs/grpo_math_1B.yaml
@@ -1,14 +1,14 @@
 # GRPO Algorithm Configuration
 grpo:
   num_prompts_per_step: 32
-  num_generations_per_prompt: 8
+  num_generations_per_prompt: 16
   max_num_steps: 1000000
   normalize_rewards: true
   use_leave_one_out_baseline: true
   val_period: 10
-  val_at_start: true
+  val_at_start: false
   max_val_samples: 256
-  val_batch_size: 16
+  val_batch_size: 256

 loss_fn:
   reference_policy_kl_penalty: 0.01
@@ -24,7 +24,7 @@ checkpointing:

 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
-  train_global_batch_size: 32
+  train_global_batch_size: 512
   train_micro_batch_size: 4
   generation_batch_size: 32
   logprob_batch_size: 4
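Worth noting for review (not part of the diff): with the new values, one GRPO step generates num_prompts_per_step × num_generations_per_prompt = 32 × 16 = 512 rollouts, which exactly matches the new train_global_batch_size, so each optimizer step consumes one full set of fresh rollouts. A quick sanity check, assuming PyYAML and that the merged file keeps these key paths:

import yaml

with open("examples/configs/grpo_math_1B.yaml") as f:
    cfg = yaml.safe_load(f)

rollouts_per_step = (
    cfg["grpo"]["num_prompts_per_step"] * cfg["grpo"]["num_generations_per_prompt"]
)  # 32 * 16 = 512
assert rollouts_per_step == cfg["policy"]["train_global_batch_size"] == 512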
4 changes: 2 additions & 2 deletions examples/configs/grpo_math_8B.yaml
@@ -3,7 +3,7 @@ defaults: "grpo_math_1B.yaml"

 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"
-  train_global_batch_size: 32
+  train_global_batch_size: 512
   train_micro_batch_size: 1
   generation_batch_size: 32
   logprob_batch_size: 2
@@ -13,7 +13,7 @@ policy:
   optimizer:
     name: "torch.optim.AdamW"
     kwargs:
-      lr: 5.0e-6
+      lr: 3.0e-7
       weight_decay: 0.01
       betas: [0.9, 0.999]
       eps: 1e-8
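The 8B config only overrides leaves; everything else is inherited from grpo_math_1B.yaml through the defaults: line. The resolver itself is not part of this diff, so purely as a hedged illustration (the actual loader may also resolve the defaults path relative to the config's directory), a defaults chain like this is typically applied with a recursive deep merge so that a nested override such as policy.optimizer.kwargs.lr replaces only that one leaf:

import yaml

def deep_merge(base: dict, override: dict) -> dict:
    # Child keys win; nested dicts are merged rather than replaced wholesale.
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

def load_config(path: str) -> dict:
    with open(path) as f:
        cfg = yaml.safe_load(f)
    base_path = cfg.pop("defaults", None)
    return deep_merge(load_config(base_path), cfg) if base_path else cfg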
1 change: 1 addition & 0 deletions nemo_reinforcer/models/generation/vllm.py
@@ -166,6 +166,7 @@ def __init__(

         self.llm = LLM(
             model=self.model_name,
+            load_format="dummy",
             tensor_parallel_size=self.tensor_parallel_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             enable_prefix_caching=True,
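load_format="dummy" is a stock vLLM option that skips reading checkpoint weights from disk and initializes the model with random values. Engine startup becomes much faster, and it is safe here because real weights always arrive through the refit path before any generation. Roughly, using the policy objects and method names that appear in the test below:

# Engine starts with dummy (random) weights -- no checkpoint I/O.
vllm_policy = VllmGeneration(cluster, vllm_config)

# Real weights are pushed in from the trained HF policy via CUDA IPC
# before generation:
ipc_handles = hf_policy.get_weights_ipc_handles()
vllm_policy.prepare_for_generation()
vllm_policy.update_weights(ipc_handles)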
31 changes: 30 additions & 1 deletion nemo_reinforcer/models/policy/hf_policy.py
@@ -107,11 +107,17 @@ def __init__(
         def do_fsdp(model):
             # Create a device mesh with 'world_size' GPUs in a 1D arrangement.
             mesh = init_device_mesh("cuda", (world_size,))
+            mp_policy = MixedPrecision(
+                param_dtype=self.dtype,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )

             return FullyShardedDataParallel(
                 model,
                 device_mesh=mesh,
                 auto_wrap_policy=size_based_auto_wrap_policy,
+                mixed_precision=mp_policy,
             )

         self.model.to("cuda")
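For reviewers who have not used FSDP's MixedPrecision policy: parameters are cast to param_dtype (here self.dtype, typically bfloat16) for forward/backward compute, while gradient reductions and buffers stay in float32, so low-precision compute does not degrade the reduce-scatter accumulation feeding the optimizer. A standalone sketch, with a single-GPU process group purely to make it runnable:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision

# One-process group so the sketch runs on a single GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

# Same policy as the diff: bf16 params for compute, fp32 reductions/buffers.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
model = FullyShardedDataParallel(
    torch.nn.Linear(16, 16).cuda(),
    mixed_precision=mp_policy,
)
out = model(torch.randn(2, 16, device="cuda"))  # forward runs in bfloat16
dist.destroy_process_group()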
Expand Down Expand Up @@ -676,16 +682,31 @@ def report_device_id(self) -> str:
         self.device_uuid = current_platform.get_device_uuid(torch.cuda.current_device())
         return self.device_uuid

-    def get_weight_ipc_handles(self):
+    @torch.no_grad()
+    def get_weight_ipc_handles(self, offload_model=True):
         from torch.multiprocessing.reductions import reduce_tensor

         # TODO @sahilj: do this without an allgather (maybe FSDP2)
         params = self.model.state_dict()

+        # Create a copy of parameters in the desired dtype (bfloat16 or float32)
+        dtype_params = {}
+        for name, param in params.items():
+            # Convert parameters to the configured dtype
+            dtype_params[name] = param.to(self.dtype, non_blocking=True)
+
+        # Replace the original params with the converted ones
+        params = dtype_params
+        self._held_reference_model_params = params
         data = {}
         self.device_uuid = self.report_device_id()
         for name, p in params.items():
             data[name] = reduce_tensor(p.detach())

+        if offload_model:
+            self.model = self.move_to_cpu(self.model)
+            gc.collect()
+            torch.cuda.empty_cache()
         return {self.device_uuid: data}

     def prepare_for_lp_inference(self):
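reduce_tensor is the same machinery torch.multiprocessing uses to pickle CUDA tensors: for a GPU tensor it returns a rebuild function plus arguments wrapping a CUDA IPC memory handle, so a peer process on the same device can map the memory with zero copies. That is also why the converted params are parked in self._held_reference_model_params: the exporting process must keep the tensors alive while the handles are in use. A tiny illustration:

import torch
from torch.multiprocessing.reductions import reduce_tensor

weight = torch.randn(4, 4, device="cuda")
rebuild_fn, rebuild_args = reduce_tensor(weight)  # picklable, zero-copy IPC handle

# In a peer process on the same GPU (e.g. the vLLM worker receiving the
# handles), rebuild_fn(*rebuild_args) maps the very same device memory --
# if `weight` were freed here first, that mapping would dangle.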
@@ -707,13 +728,19 @@ def prepare_for_training(self, *args, **kwargs):

         torch.cuda.empty_cache()

+    @torch.no_grad()
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
+        torch.randn(1).cuda()  # wake up torch allocator
         if hasattr(self, "optimizer") and self.optimizer is not None:
             for state in self.optimizer.state.values():
                 for k, v in state.items():
                     if torch.is_tensor(v):
                         state[k] = v.to("cpu")
+
+        for buffer in self.model.buffers():
+            buffer.data = buffer.data.to("cpu")
+
         gc.collect()
         torch.cuda.empty_cache()

@@ -724,10 +751,12 @@ def offload_before_refit(self):
f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved"
)

@torch.no_grad()
def offload_after_refit(self):
# Offload as much as possible on the CPU
self.model = self.move_to_cpu(self.model)
self.model.eval()
torch.randn(1).cuda() # wake up torch allocator
self.offload_before_refit() # rerun the old offload function

if self._held_reference_model_params is not None:
Expand Down
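Both offload paths follow the same pattern: move tensors to CPU, drop Python references, then return cached blocks to the driver, with the torch.randn(1).cuda() poke ensuring the CUDA context and caching allocator are initialized before any of that runs. A minimal sketch of the pattern (an illustrative helper, not code from this PR):

import gc
import torch

def release_gpu_memory(module: torch.nn.Module) -> None:
    torch.randn(1).cuda()     # make sure the CUDA context/allocator is alive
    module.to("cpu")          # move params and buffers off the GPU
    gc.collect()              # drop lingering Python references to GPU tensors
    torch.cuda.empty_cache()  # hand freed cached blocks back to the driver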
8 changes: 8 additions & 0 deletions tests/unit/models/generation/test_vllm_generation.py
@@ -231,10 +231,16 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     # Create both policies
     print("Creating vLLM policy...")
     vllm_policy = VllmGeneration(cluster, vllm_config)
+    vllm_policy.finish_generation()

     print("Creating HF policy...")
     hf_policy = HfPolicy(cluster, hf_config)
+
+    print("Refitting vLLM policy...")
+    ipc_handles = hf_policy.get_weights_ipc_handles()
+    vllm_policy.prepare_for_generation()
+    vllm_policy.update_weights(ipc_handles)

     # Step 1: Use vLLM for generation
     print("Using vLLM policy for fast generation...")
     generation_results = vllm_policy.generate(test_input_data)
@@ -262,6 +268,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
         }
     )
     # Get logprobs from HF policy
+    hf_policy.prepare_for_lp_inference()
     fprop_results = hf_policy.get_logprobs(fprop_logprob_data)
     # Zero out logprobs for input tokens

@@ -327,6 +334,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
print(f"Training loss: {results['loss']}")

hf_policy.finish_training()
hf_policy.offload_after_refit()

# Step 4: Use vLLM for generation again to complete the workflow
print("Using vLLM for generation again...")
Expand Down
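Taken together, the test pins down the steady-state GPU handoff this PR enables: vLLM generates, releases the GPU, the HF policy trains and computes logprobs, then fresh weights are refit into vLLM and the HF side offloads again. As a sketch only (the train/generate signatures are assumed, not taken from this diff):

def grpo_cycle(hf_policy, vllm_policy, prompts):
    # 1) Refit: push current HF weights into the dummy-initialized vLLM engine
    #    (get_weights_ipc_handles also offloads the HF model by default).
    ipc_handles = hf_policy.get_weights_ipc_handles()
    vllm_policy.prepare_for_generation()
    vllm_policy.update_weights(ipc_handles)

    # 2) Generate rollouts with vLLM, then free its GPU state.
    rollouts = vllm_policy.generate(prompts)
    vllm_policy.finish_generation()

    # 3) Train the HF policy on the rollouts, then offload to make room
    #    for the next generation phase.
    hf_policy.prepare_for_training()
    results = hf_policy.train(rollouts)
    hf_policy.finish_training()
    hf_policy.offload_after_refit()
    return results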