From 83c5f985c9cbd69ca222aa98ad8ae21367694e1e Mon Sep 17 00:00:00 2001
From: ashors1
Date: Thu, 26 Jun 2025 09:26:05 -0700
Subject: [PATCH 1/2] disable overlap param gather during reference model forward

Signed-off-by: ashors1
---
 .../models/policy/megatron_policy_worker.py | 32 +++++++++++++------
 1 file changed, 23 insertions(+), 9 deletions(-)

diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index e0bd4373be..79745f171d 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -421,15 +421,6 @@ def __init__(
                 pretrained_path, "iter_0000000/run_config.yaml"
             )
 
-        assert not (
-            self.cfg["megatron_cfg"]["distributed_data_parallel_config"][
-                "overlap_param_gather"
-            ]
-            and self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"]
-        ), (
-            "Using overlap param gather together with distributed optimizer has known convergence issues. Please disable overlap param gather."
-        )
-
         self.tokenizer = tokenizer
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -633,6 +624,13 @@ def __init__(
         self._held_gather_buffer = None
         self.megatron_to_hf_converter = MegatronToHFConverter(hf_model_name, self.model)
 
+        self.should_disable_forward_pre_hook = (
+            self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"]
+            and self.cfg["megatron_cfg"]["distributed_data_parallel_config"][
+                "overlap_param_gather"
+            ]
+        )
+
     def configure_worker(self, num_gpus: int, bundle_indices: Optional[tuple] = None):
         USE_EXPANDABLE_SEGMENTS = False  # Disabling this right now as it seems to cause vLLM refit issues with Ampere
         if USE_EXPANDABLE_SEGMENTS:
@@ -650,6 +648,14 @@ def get_gpu_info(self):
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)
 
+    def enable_forward_pre_hook(self):
+        assert isinstance(self.model, DistributedDataParallel)
+        self.model.enable_forward_pre_hook()
+
+    def disable_forward_pre_hook(self, param_sync=True):
+        assert isinstance(self.model, DistributedDataParallel)
+        self.model.disable_forward_pre_hook(param_sync=param_sync)
+
     def train(
         self,
         data: BatchedDataDict,
@@ -989,6 +995,10 @@ def use_reference_model(self):
         On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references
         On exit: Restores original references and re-flips cuda/cpu
         """
+        ## disable overlap param gather when swapping weights
+        if self.should_disable_forward_pre_hook:
+            self.disable_forward_pre_hook()
+
         with torch.no_grad():
             try:
                 # Save original references
@@ -1023,6 +1033,10 @@ def use_reference_model(self):
             gc.collect()
             torch.cuda.empty_cache()
 
+        ## re-enable overlap param gather after weight swap
+        if self.should_disable_forward_pre_hook:
+            self.enable_forward_pre_hook()
+
     # Temporary fix, 'data' is a kwarg due to some sort of ray bug
     def get_reference_policy_logprobs(
         self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None

From b83657e00e69b68481333869c7f622a5c49bdf81 Mon Sep 17 00:00:00 2001
From: ashors1
Date: Thu, 26 Jun 2025 09:58:09 -0700
Subject: [PATCH 2/2] enable overlap_param_gather

Signed-off-by: ashors1
---
 examples/configs/dpo.yaml                                        | 2 +-
 examples/configs/grpo_math_1B_megatron.yaml                      | 2 +-
 .../recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml      | 2 +-
 .../llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml  | 2 +-
 .../recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml      | 2 +-
 examples/configs/sft.yaml                                        | 2 +-
 6 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
index ccddde43b0..db6fb7fa6d 100755
--- a/examples/configs/dpo.yaml
+++ b/examples/configs/dpo.yaml
@@ -134,7 +134,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml
index 5b14a7ff56..6b07317ed6 100644
--- a/examples/configs/grpo_math_1B_megatron.yaml
+++ b/examples/configs/grpo_math_1B_megatron.yaml
@@ -115,7 +115,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       use_custom_fsdp: false
       data_parallel_sharding_strategy: "optim_grads_params"
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
index 03bd0d7077..1fd336d0b4 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatron.yaml
@@ -91,7 +91,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
index 74c93bbae0..73008f3154 100644
--- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
+++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.yaml
@@ -91,7 +91,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
index f6ab46c997..ddd53920e6 100644
--- a/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.1-8b-instruct-1n8g-megatron.yaml
@@ -79,7 +79,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index 5be4451d3b..e3c614e2a7 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -109,7 +109,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
      overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
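
A minimal sketch of the pattern introduced in PATCH 1/2, not part of either patch: it assumes only what the diff itself shows, namely that Megatron's DistributedDataParallel wrapper exposes disable_forward_pre_hook(param_sync=...) and enable_forward_pre_hook(), and it mirrors the should_disable_forward_pre_hook toggle around a reference-model weight swap. The reference_model_forward helper and the cfg/model/reference_model names are hypothetical stand-ins, not identifiers from the repository.

# Illustrative sketch only: bracket a reference-model weight swap with the
# forward pre-hook toggle, as PATCH 1/2 does inside use_reference_model().
from contextlib import contextmanager


@contextmanager
def reference_model_forward(cfg, model, reference_model):
    # Only relevant when the distributed optimizer overlaps the param
    # all-gather with forward compute (the toggle added in PATCH 1/2).
    should_disable = (
        cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"]
        and cfg["megatron_cfg"]["distributed_data_parallel_config"][
            "overlap_param_gather"
        ]
    )
    if should_disable:
        # Sync params once up front so the pre-hook does not fire while the
        # policy weights are swapped out for the reference weights.
        model.disable_forward_pre_hook(param_sync=True)
    try:
        yield reference_model
    finally:
        if should_disable:
            model.enable_forward_pre_hook()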