diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 0d06ad6366..5ff77e11e9 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -182,6 +182,7 @@ def setup( if last_checkpoint_path else None, init_optimizer=True, + init_reference_model=False, ) loss_fn = NLLLoss() print(f" ✓ Model initialized") diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index b2b004bddb..ebc9e879f2 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -68,6 +68,7 @@ def __init__( weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, init_optimizer: bool = True, + init_reference_model: bool = True, ): self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call @@ -88,12 +89,14 @@ def __init__( device_map="cpu", # load weights onto CPU initially torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) - self.reference_model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed - ) - + if init_reference_model: + self.reference_model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed + ) + else: + self.reference_model = None self.tokenizer = AutoTokenizer.from_pretrained(model_name) # If no pad token is defined, you might need: if self.tokenizer.pad_token is None: @@ -123,9 +126,10 @@ def do_fsdp(model): self.model.to("cuda") self.model = do_fsdp(self.model) self.model = self.move_to_cpu(self.model) - self.reference_model.to("cuda") - self.reference_model = do_fsdp(self.reference_model) - self.reference_model = self.move_to_cpu(self.reference_model) + if self.reference_model is not None: + self.reference_model.to("cuda") + self.reference_model = do_fsdp(self.reference_model) + self.reference_model = self.move_to_cpu(self.reference_model) self.model.to("cuda") self._held_reference_model_params = None # register_fsdp_forward_method(self.model, "generate") @@ -893,6 +897,7 @@ def __init__( init_optimizer: bool = True, weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, + init_reference_model: bool = True, ): if weights_path: weights_path = os.path.abspath(weights_path) @@ -905,6 +910,7 @@ def __init__( init_optimizer=init_optimizer, weights_path=weights_path, optimizer_path=optimizer_path, + init_reference_model=init_reference_model, ) self.worker_group = RayWorkerGroup( cluster, diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 29d8cbcbee..ded244feac 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -201,7 +201,7 @@ def training_setup(): config = basic_llama_test_config print("Creating training HfPolicy...") - policy = HfPolicy(cluster=cluster, config=config) + policy = HfPolicy(cluster=cluster, config=config, init_reference_model=False) # Create a test batch print("Creating test batch...") @@ -278,7 +278,7 @@ def verify_loss_tensor(loss_tensor): @pytest.fixture -def generation_setup(): +def generation_setup(request): """Setup and teardown specifically for generation tests.""" policy = None cluster = None @@ -300,7 +300,9 @@ def generation_setup(): config = basic_llama_test_config print("Creating generation HfPolicy...") - policy = HfPolicy(cluster=cluster, config=config) + policy = HfPolicy( + cluster=cluster, config=config, init_reference_model=request.param + ) # Create a test batch print("Creating test batch...") @@ -364,6 +366,7 @@ def generation_setup(): @pytest.mark.timeout(180) +@pytest.mark.parametrize("generation_setup", [False], indirect=True) def test_hf_policy_generation(generation_setup, tracker): policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup @@ -450,6 +453,7 @@ def test_hf_policy_generation(generation_setup, tracker): @pytest.mark.timeout(180) +@pytest.mark.parametrize("generation_setup", [True], indirect=True) def test_all_hf_policy_generation_lps_ref_training(generation_setup): policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup