diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index 827fd9cbeb..72aad000ce 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -1,14 +1,14 @@
 # GRPO Algorithm Configuration
 grpo:
   num_prompts_per_step: 32
-  num_generations_per_prompt: 8
+  num_generations_per_prompt: 16
   max_num_steps: 1000000
   normalize_rewards: true
   use_leave_one_out_baseline: true
   val_period: 10
-  val_at_start: true
+  val_at_start: false
   max_val_samples: 256
-  val_batch_size: 16
+  val_batch_size: 256

 loss_fn:
   reference_policy_kl_penalty: 0.01
@@ -24,7 +24,7 @@ checkpointing:

 policy:
   model_name: "meta-llama/Llama-3.2-1B-Instruct"
-  train_global_batch_size: 32
+  train_global_batch_size: 512
   train_micro_batch_size: 4
   generation_batch_size: 32
   logprob_batch_size: 4
diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml
index 69802553c1..d747e9249f 100644
--- a/examples/configs/grpo_math_8B.yaml
+++ b/examples/configs/grpo_math_8B.yaml
@@ -3,7 +3,7 @@ defaults: "grpo_math_1B.yaml"

 policy:
   model_name: "meta-llama/Llama-3.1-8B-Instruct"
-  train_global_batch_size: 32
+  train_global_batch_size: 512
   train_micro_batch_size: 1
   generation_batch_size: 32
   logprob_batch_size: 2
@@ -13,7 +13,7 @@ policy:
   optimizer:
     name: "torch.optim.AdamW"
     kwargs:
-      lr: 5.0e-6
+      lr: 3.0e-7
       weight_decay: 0.01
       betas: [0.9, 0.999]
       eps: 1e-8
diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py
index eab407c0a3..ebbe53cff5 100644
--- a/nemo_reinforcer/models/generation/vllm.py
+++ b/nemo_reinforcer/models/generation/vllm.py
@@ -166,6 +166,7 @@ def __init__(

         self.llm = LLM(
             model=self.model_name,
+            load_format="dummy",
             tensor_parallel_size=self.tensor_parallel_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             enable_prefix_caching=True,
diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py
index b27b575c10..34a621063d 100644
--- a/nemo_reinforcer/models/policy/hf_policy.py
+++ b/nemo_reinforcer/models/policy/hf_policy.py
@@ -107,11 +107,17 @@ def __init__(

         def do_fsdp(model):
             # Create a device mesh with 'world_size' GPUs in a 1D arrangement.
             mesh = init_device_mesh("cuda", (world_size,))
+            mp_policy = MixedPrecision(
+                param_dtype=self.dtype,
+                reduce_dtype=torch.float32,
+                buffer_dtype=torch.float32,
+            )
             return FullyShardedDataParallel(
                 model,
                 device_mesh=mesh,
                 auto_wrap_policy=size_based_auto_wrap_policy,
+                mixed_precision=mp_policy,
             )

         self.model.to("cuda")
@@ -676,16 +682,31 @@ def report_device_id(self) -> str:
         self.device_uuid = current_platform.get_device_uuid(torch.cuda.current_device())
         return self.device_uuid

-    def get_weight_ipc_handles(self):
+    @torch.no_grad()
+    def get_weight_ipc_handles(self, offload_model=True):
         from torch.multiprocessing.reductions import reduce_tensor

         # TODO @sahilj: do this without an allgather (maybe FSDP2)
         params = self.model.state_dict()
+
+        # Create a copy of parameters in the desired dtype (bfloat16 or float32)
+        dtype_params = {}
+        for name, param in params.items():
+            # Convert parameters to the configured dtype
+            dtype_params[name] = param.to(self.dtype, non_blocking=True)
+
+        # Replace the original params with the converted ones
+        params = dtype_params
         self._held_reference_model_params = params
         data = {}
         self.device_uuid = self.report_device_id()
         for name, p in params.items():
             data[name] = reduce_tensor(p.detach())
+
+        if offload_model:
+            self.model = self.move_to_cpu(self.model)
+            gc.collect()
+            torch.cuda.empty_cache()
         return {self.device_uuid: data}

     def prepare_for_lp_inference(self):
@@ -707,13 +728,19 @@ def prepare_for_training(self, *args, **kwargs):
         torch.cuda.empty_cache()

+    @torch.no_grad()
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
+        torch.randn(1).cuda()  # wake up torch allocator
         if hasattr(self, "optimizer") and self.optimizer is not None:
             for state in self.optimizer.state.values():
                 for k, v in state.items():
                     if torch.is_tensor(v):
                         state[k] = v.to("cpu")
+
+        for buffer in self.model.buffers():
+            buffer.data = buffer.data.to("cpu")
+
         gc.collect()
         torch.cuda.empty_cache()

@@ -724,10 +751,12 @@ def offload_before_refit(self):
             f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved"
         )

+    @torch.no_grad()
     def offload_after_refit(self):
         # Offload as much as possible on the CPU
         self.model = self.move_to_cpu(self.model)
         self.model.eval()
+        torch.randn(1).cuda()  # wake up torch allocator
         self.offload_before_refit()  # rerun the old offload function

         if self._held_reference_model_params is not None:
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 9f6724923a..9c89046af2 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -231,10 +231,16 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     # Create both policies
     print("Creating vLLM policy...")
     vllm_policy = VllmGeneration(cluster, vllm_config)
+    vllm_policy.finish_generation()

     print("Creating HF policy...")
     hf_policy = HfPolicy(cluster, hf_config)

+    print("Refitting vLLM policy...")
+    ipc_handles = hf_policy.get_weight_ipc_handles()
+    vllm_policy.prepare_for_generation()
+    vllm_policy.update_weights(ipc_handles)
+
     # Step 1: Use vLLM for generation
     print("Using vLLM policy for fast generation...")
     generation_results = vllm_policy.generate(test_input_data)
@@ -262,6 +268,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
         }
     )
     # Get logprobs from HF policy
+    hf_policy.prepare_for_lp_inference()
     fprop_results = hf_policy.get_logprobs(fprop_logprob_data)

     # Zero out logprobs for input tokens
@@ -327,6 +334,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     print(f"Training loss: {results['loss']}")

     hf_policy.finish_training()
+    hf_policy.offload_after_refit()

     # Step 4: Use vLLM for generation again to complete the workflow
     print("Using vLLM for generation again...")
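A note on the refit path these changes wire together: `load_format="dummy"` tells vLLM to initialize the engine with placeholder weights, which is safe here because the weights are immediately overwritten via `update_weights(ipc_handles)`. The handles come from `torch.multiprocessing.reductions.reduce_tensor`, which turns a CUDA tensor into a picklable `(rebuild_fn, args)` pair backed by a CUDA IPC handle rather than a data copy. Below is a minimal sketch of the export side; the toy tensor is illustrative, not part of the patch.

```python
import torch
from torch.multiprocessing.reductions import reduce_tensor

# Export side: no tensor data is serialized, only a CUDA IPC handle
# plus the metadata needed to rebuild a view of the same GPU memory.
weight = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
rebuild_fn, args = reduce_tensor(weight.detach())

# A *different* process on the same device reconstructs the view with
# `rebuild_fn(*args)`; CUDA IPC handles cannot be reopened inside the
# exporting process, which is why the consumer is the vLLM worker.
```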
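For the `do_fsdp` change, here is a minimal sketch of FSDP's `MixedPrecision` policy, assuming a single-rank process group, a toy module, and bf16 as a stand-in for `self.dtype` (all illustrative): parameters are stored and computed in the low-precision dtype, while gradient all-reduces and buffers stay in fp32 so repeated reductions across ranks do not accumulate rounding error.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision

# Single-rank process group purely for illustration; the real code
# shards across `world_size` ranks launched by the training cluster.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

mesh = init_device_mesh("cuda", (1,))
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # storage/compute dtype for parameters
    reduce_dtype=torch.float32,   # gradient all-reduce in full precision
    buffer_dtype=torch.float32,   # e.g. norm statistics stay in fp32
)
model = FullyShardedDataParallel(
    torch.nn.Linear(16, 16).cuda(),
    device_mesh=mesh,
    mixed_precision=mp_policy,
)
dist.destroy_process_group()
```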