From 10c114f5b171aef5c165dc38bdf5ec5df1e4be48 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Sat, 22 Mar 2025 18:06:41 -0700 Subject: [PATCH 1/5] Multiprocessing memory improvements and better defaults Signed-off-by: Sahil Jain --- examples/configs/grpo_math_1B.yaml | 8 +++--- examples/configs/grpo_math_8B.yaml | 2 +- nemo_reinforcer/models/generation/vllm.py | 1 + nemo_reinforcer/models/policy/hf_policy.py | 31 +++++++++++++++++++++- 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index ab2fbdf59c..2e25d76142 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -1,14 +1,14 @@ # GRPO Algorithm Configuration grpo: num_prompts_per_step: 32 - num_generations_per_prompt: 8 + num_generations_per_prompt: 16 max_num_steps: 1000000 normalize_rewards: true use_leave_one_out_baseline: true val_period: 10 - val_at_start: true + val_at_start: false max_val_samples: 256 - val_batch_size: 16 + val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 @@ -24,7 +24,7 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 32 + train_global_batch_size: 512 train_micro_batch_size: 4 generation_batch_size: 32 logprob_batch_size: 4 diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 69802553c1..2b8f8eb5fa 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -3,7 +3,7 @@ defaults: "grpo_math_1B.yaml" policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" - train_global_batch_size: 32 + train_global_batch_size: 512 train_micro_batch_size: 1 generation_batch_size: 32 logprob_batch_size: 2 diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index eab407c0a3..ebbe53cff5 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -166,6 +166,7 @@ def 
__init__( self.llm = LLM( model=self.model_name, + load_format="dummy", tensor_parallel_size=self.tensor_parallel_size, gpu_memory_utilization=self.gpu_memory_utilization, enable_prefix_caching=True, diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 2eb5598f6b..2ea3477ab6 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -104,11 +104,17 @@ def __init__( def do_fsdp(model): # Create a device mesh with 'world_size' GPUs in a 1D arrangement. mesh = init_device_mesh("cuda", (world_size,)) + mp_policy = MixedPrecision( + param_dtype=self.dtype, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, + ) return FullyShardedDataParallel( model, device_mesh=mesh, auto_wrap_policy=size_based_auto_wrap_policy, + mixed_precision=mp_policy, ) self.model.to("cuda") @@ -674,16 +680,31 @@ def report_device_id(self) -> str: self.device_uuid = current_platform.get_device_uuid(torch.cuda.current_device()) return self.device_uuid - def get_weight_ipc_handles(self): + @torch.no_grad() + def get_weight_ipc_handles(self, offload_model=True): from torch.multiprocessing.reductions import reduce_tensor # TODO @sahilj: do this without an allgather (maybe FSDP2) params = self.model.state_dict() + + # Create a copy of parameters in the desired dtype (bfloat16 or float32) + dtype_params = {} + for name, param in params.items(): + # Convert parameters to the configured dtype + dtype_params[name] = param.to(self.dtype, non_blocking=True) + + # Replace the original params with the converted ones + params = dtype_params self._held_reference_model_params = params data = {} self.device_uuid = self.report_device_id() for name, p in params.items(): data[name] = reduce_tensor(p.detach()) + + if offload_model: + self.model = self.move_to_cpu(self.model) + gc.collect() + torch.cuda.empty_cache() return {self.device_uuid: data} def prepare_for_lp_inference(self): @@ -705,13 +726,19 @@ def 
prepare_for_training(self, *args, **kwargs): torch.cuda.empty_cache() + @torch.no_grad() def offload_before_refit(self): """Offload the optimizer and buffers to the CPU.""" + torch.randn(1).cuda() # wake up torch allocator if hasattr(self, "optimizer") and self.optimizer is not None: for state in self.optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to("cpu") + + for buffer in self.model.buffers(): + buffer.data = buffer.data.to("cpu") + gc.collect() torch.cuda.empty_cache() @@ -722,10 +749,12 @@ def offload_before_refit(self): f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" ) + @torch.no_grad() def offload_after_refit(self): # Offload as much as possible on the CPU self.model = self.move_to_cpu(self.model) self.model.eval() + torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function if self._held_reference_model_params is not None: From e89b13ef213e5c7b7c3f8edc40898c6f7496b424 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 23 Mar 2025 14:28:58 -0700 Subject: [PATCH 2/5] fix: ray.sub race condition when overlapping srun commands on same node trying a different approach Signed-off-by: Terry Kong --- ray.sub | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ray.sub b/ray.sub index a600d96063..7258aec045 100644 --- a/ray.sub +++ b/ray.sub @@ -59,7 +59,11 @@ ip_head=$head_node_ip:$port # First we start the head of the ray cluster on one of the physical nodes # Set GPU/CPU resources to 0 to avoid scheduling on the head node + head_cmd=$(cat < Date: Sun, 23 Mar 2025 22:21:43 -0700 Subject: [PATCH 3/5] Fixed default LR 8B Signed-off-by: Sahil Jain --- examples/configs/grpo_math_8B.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 2b8f8eb5fa..d747e9249f 100644 --- a/examples/configs/grpo_math_8B.yaml +++ 
b/examples/configs/grpo_math_8B.yaml @@ -13,7 +13,7 @@ policy: optimizer: name: "torch.optim.AdamW" kwargs: - lr: 5.0e-6 + lr: 3.0e-7 weight_decay: 0.01 betas: [0.9, 0.999] eps: 1e-8 From 06c73ad309eb5323c79ffbc52bf305344faab342 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 24 Mar 2025 11:58:06 -0700 Subject: [PATCH 4/5] lint fix Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/loss_functions.py | 6 +++++- nemo_reinforcer/algorithms/sft.py | 1 + nemo_reinforcer/algorithms/utils.py | 1 + nemo_reinforcer/models/policy/hf_policy.py | 9 ++++----- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 90230a06ab..8504dac007 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -166,4 +166,8 @@ def __call__( num_unmasked_tokens = torch.tensor(1) loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens - return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()} + return loss, { + "loss": loss.item(), + "num_unmasked_tokens": num_unmasked_tokens.item(), + "total_tokens": mask.numel(), + } diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index b216c02724..923437ac1c 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -61,6 +61,7 @@ class SFTConfig(TypedDict): val_at_start: bool seed: int + class MasterConfig(TypedDict): policy: PolicyConfig data: DataConfig diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index a568dbcda6..138c3802d1 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -123,6 +123,7 @@ def masked_mean(values, mask, dim=None): return values[mask.bool()].mean() return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) + def set_seed(seed: int): 
"""Sets the seed for python, numpy, and pytorch.""" random.seed(seed) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 2ea3477ab6..6917173c68 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -129,8 +129,7 @@ def do_fsdp(model): if init_optimizer: optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) self.optimizer = optimizer_cls( - self.model.parameters(), - **self.cfg["optimizer"]["kwargs"] + self.model.parameters(), **self.cfg["optimizer"]["kwargs"] ) else: self.optimizer = None @@ -692,7 +691,7 @@ def get_weight_ipc_handles(self, offload_model=True): for name, param in params.items(): # Convert parameters to the configured dtype dtype_params[name] = param.to(self.dtype, non_blocking=True) - + # Replace the original params with the converted ones params = dtype_params self._held_reference_model_params = params @@ -729,7 +728,7 @@ def prepare_for_training(self, *args, **kwargs): @torch.no_grad() def offload_before_refit(self): """Offload the optimizer and buffers to the CPU.""" - torch.randn(1).cuda() # wake up torch allocator + torch.randn(1).cuda() # wake up torch allocator if hasattr(self, "optimizer") and self.optimizer is not None: for state in self.optimizer.state.values(): for k, v in state.items(): @@ -754,7 +753,7 @@ def offload_after_refit(self): # Offload as much as possible on the CPU self.model = self.move_to_cpu(self.model) self.model.eval() - torch.randn(1).cuda() # wake up torch allocator + torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function if self._held_reference_model_params is not None: From 79427e4e088587ffc06d24f8e7d8cf1549cd6a1c Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Mon, 24 Mar 2025 17:39:15 -0700 Subject: [PATCH 5/5] updated tests for dummy vllm init Signed-off-by: Sahil Jain --- tests/unit/models/generation/test_vllm_generation.py | 8 
++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 9f6724923a..9c89046af2 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -231,10 +231,16 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): # Create both policies print("Creating vLLM policy...") vllm_policy = VllmGeneration(cluster, vllm_config) + vllm_policy.finish_generation() print("Creating HF policy...") hf_policy = HfPolicy(cluster, hf_config) + print(f"refitting vllm policy...") + ipc_handles = hf_policy.get_weight_ipc_handles() + vllm_policy.prepare_for_generation() + vllm_policy.update_weights(ipc_handles) + # Step 1: Use vLLM for generation print("Using vLLM policy for fast generation...") generation_results = vllm_policy.generate(test_input_data) @@ -262,6 +268,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): } ) # Get logprobs from HF policy + hf_policy.prepare_for_lp_inference() fprop_results = hf_policy.get_logprobs(fprop_logprob_data) # Zero out logprobs for input tokens @@ -327,6 +334,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): print(f"Training loss: {results['loss']}") hf_policy.finish_training() + hf_policy.offload_after_refit() # Step 4: Use vLLM for generation again to complete the workflow print("Using vLLM for generation again...")