From 455f46c8c5ce5ccc84552858a3111a50c8d3fa6b Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 24 Mar 2025 12:40:56 -0700 Subject: [PATCH 1/9] do not initialze reference model for sft Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 2 ++ nemo_reinforcer/models/policy/hf_policy.py | 25 ++++++++++++---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index b216c02724..35b2e94c55 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -61,6 +61,7 @@ class SFTConfig(TypedDict): val_at_start: bool seed: int + class MasterConfig(TypedDict): policy: PolicyConfig data: DataConfig @@ -175,6 +176,7 @@ def setup( if last_checkpoint_path else None, init_optimizer=True, + init_reference_model=False, ) loss_fn = NLLLoss() print(f" ✓ Model initialized") diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 2eb5598f6b..f805821e06 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -65,6 +65,7 @@ def __init__( weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, init_optimizer: bool = True, + init_reference_model: bool = True, ): self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call @@ -85,12 +86,14 @@ def __init__( device_map="cpu", # load weights onto CPU initially torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) - self.reference_model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed - ) - + if init_reference_model: + self.reference_model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed + ) + else: + self.reference_model = None self.tokenizer = AutoTokenizer.from_pretrained(model_name) # If no pad token is defined, you might need: if self.tokenizer.pad_token is None: @@ -114,17 +117,17 @@ def do_fsdp(model): self.model.to("cuda") self.model = do_fsdp(self.model) self.model = self.move_to_cpu(self.model) - self.reference_model.to("cuda") - self.reference_model = do_fsdp(self.reference_model) - self.reference_model = self.move_to_cpu(self.reference_model) + if self.reference_model is not None: + self.reference_model.to("cuda") + self.reference_model = do_fsdp(self.reference_model) + self.reference_model = self.move_to_cpu(self.reference_model) self.model.to("cuda") self._held_reference_model_params = None # register_fsdp_forward_method(self.model, "generate") if init_optimizer: optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) self.optimizer = optimizer_cls( - self.model.parameters(), - **self.cfg["optimizer"]["kwargs"] + self.model.parameters(), **self.cfg["optimizer"]["kwargs"] ) else: self.optimizer = None From 990f8a167395e989652ea537f539b1dcf5aff70c Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 26 Mar 2025 08:49:58 -0700 Subject: [PATCH 2/9] fixes Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index f805821e06..c5282f9125 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -844,6 +844,7 @@ def __init__( init_optimizer: bool = True, weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, + init_reference_model: bool = True, ): if weights_path: weights_path = os.path.abspath(weights_path) @@ -856,6 +857,7 @@ def __init__( init_optimizer=init_optimizer, weights_path=weights_path, optimizer_path=optimizer_path, + init_reference_model=init_reference_model, ) self.worker_group = RayWorkerGroup( cluster, From 6364d779c02c66afd57fbab388151e606b2e8345 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 13:48:36 -0700 Subject: [PATCH 3/9] update test Signed-off-by: ashors1 --- tests/unit/models/policy/test_hf_ray_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 9b825e5302..a9ca0eb883 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -192,7 +192,7 @@ def training_setup(): config = basic_llama_test_config print("Creating training HfPolicy...") - policy = HfPolicy(cluster=cluster, config=config) + policy = HfPolicy(cluster=cluster, config=config, init_reference_model=False) # Create a test batch print("Creating test batch...") From 897e6021cc8b86389619b8b301b66425d4cb81c3 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 13:56:32 -0700 Subject: [PATCH 4/9] update pure generation test Signed-off-by: ashors1 --- .../unit/models/policy/test_hf_ray_policy.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index a9ca0eb883..d9b929b6f4 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -268,8 +268,7 @@ def verify_loss_tensor(loss_tensor): assert losses[0] > losses[-1], "Loss should decrease over training iterations" -@pytest.fixture -def generation_setup(): +def generation_setup(init_reference_model=True): """Setup and teardown specifically for generation tests.""" policy = None cluster = None @@ -354,9 +353,21 @@ def generation_setup(): policy.worker_group.shutdown() +@pytest.fixture +def generation_setup_no_ref_model(): + return generation_setup(init_reference_model=False) + + +@pytest.fixture +def generation_setup_with_ref_model(): + return generation_setup(init_reference_model=True) + + @pytest.mark.timeout(180) -def test_hf_policy_generation(generation_setup): - policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup +def test_hf_policy_generation(generation_setup_no_ref_model): + policy, cluster, data, tokenizer, prompts, expected_generations = ( + generation_setup_no_ref_model + ) # Verify resources were created properly assert policy is not None, "Generation policy was not created properly" @@ -439,8 +450,10 @@ def test_hf_policy_generation(generation_setup): @pytest.mark.timeout(180) -def test_all_hf_policy_generation_lps_ref_training(generation_setup): - policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup +def test_all_hf_policy_generation_lps_ref_training(generation_setup_with_ref_model): + policy, cluster, data, tokenizer, prompts, expected_generations = ( + generation_setup_with_ref_model + ) # Verify resources were created properly assert policy is not None, "Generation policy was not created properly" From 45363fb4d5ad569b584b39b8e2f6b807fd2a5041 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 16:48:49 -0700 Subject: [PATCH 5/9] fix Signed-off-by: ashors1 --- tests/unit/models/policy/test_hf_ray_policy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index a3ca94ee68..e0a37ea585 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -374,7 +374,9 @@ def generation_setup_with_ref_model(): @pytest.mark.timeout(180) def test_hf_policy_generation(generation_setup_no_ref_model, tracker): - policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup + policy, cluster, data, tokenizer, prompts, expected_generations = ( + generation_setup_no_ref_model + ) # Verify resources were created properly assert policy is not None, "Generation policy was not created properly" From f018f832f46a0d4b760e3b62d13bb8c08619e2d2 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 31 Mar 2025 21:49:55 -0700 Subject: [PATCH 6/9] fix Signed-off-by: ashors1 --- tests/unit/models/policy/test_hf_ray_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index e0a37ea585..f182323adc 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -364,12 +364,12 @@ def generation_setup(init_reference_model=True): @pytest.fixture def generation_setup_no_ref_model(): - return generation_setup(init_reference_model=False) + yield generation_setup(init_reference_model=False) @pytest.fixture def generation_setup_with_ref_model(): - return generation_setup(init_reference_model=True) + yield generation_setup(init_reference_model=True) @pytest.mark.timeout(180) From 6182ec6785db93fc5e018cedd74d36a8cc0979eb Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Tue, 1 Apr 2025 09:03:34 -0700 Subject: [PATCH 7/9] fix tests Signed-off-by: Anna Shors --- tests/unit/models/policy/test_hf_ray_policy.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index f182323adc..21a76c9123 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -364,13 +364,14 @@ def generation_setup(init_reference_model=True): @pytest.fixture def generation_setup_no_ref_model(): - yield generation_setup(init_reference_model=False) + for item in generation_setup(init_reference_model=False): + yield item @pytest.fixture def generation_setup_with_ref_model(): - yield generation_setup(init_reference_model=True) - + for item in generation_setup(init_reference_model=False): + yield item @pytest.mark.timeout(180) def test_hf_policy_generation(generation_setup_no_ref_model, tracker): From af85ca9d0de3eba65dc6394a217ab115deabfca8 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 09:20:14 -0700 Subject: [PATCH 8/9] linting Signed-off-by: ashors1 --- tests/unit/models/policy/test_hf_ray_policy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 21a76c9123..23f457f06d 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -373,6 +373,7 @@ def generation_setup_with_ref_model(): for item in generation_setup(init_reference_model=False): yield item + @pytest.mark.timeout(180) def test_hf_policy_generation(generation_setup_no_ref_model, tracker): policy, cluster, data, tokenizer, prompts, expected_generations = ( From 122f37a9b014b3efa789feb1daaf245c125287ec Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 1 Apr 2025 09:55:50 -0700 Subject: [PATCH 9/9] use indirect for fixtures with args Signed-off-by: ashors1 --- .../unit/models/policy/test_hf_ray_policy.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 23f457f06d..ded244feac 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -277,7 +277,8 @@ def verify_loss_tensor(loss_tensor): assert losses[0] > losses[-1], "Loss should decrease over training iterations" -def generation_setup(init_reference_model=True): +@pytest.fixture +def generation_setup(request): """Setup and teardown specifically for generation tests.""" policy = None cluster = None @@ -299,7 +300,9 @@ def generation_setup(init_reference_model=True): config = basic_llama_test_config print("Creating generation HfPolicy...") - policy = HfPolicy(cluster=cluster, config=config) + policy = HfPolicy( + cluster=cluster, config=config, init_reference_model=request.param + ) # Create a test batch print("Creating test batch...") @@ -362,23 +365,10 @@ def generation_setup(init_reference_model=True): policy.worker_group.shutdown() -@pytest.fixture -def generation_setup_no_ref_model(): - for item in generation_setup(init_reference_model=False): - yield item - - -@pytest.fixture -def generation_setup_with_ref_model(): - for item in generation_setup(init_reference_model=False): - yield item - - @pytest.mark.timeout(180) -def test_hf_policy_generation(generation_setup_no_ref_model, tracker): - policy, cluster, data, tokenizer, prompts, expected_generations = ( - generation_setup_no_ref_model - ) +@pytest.mark.parametrize("generation_setup", [False], indirect=True) +def test_hf_policy_generation(generation_setup, tracker): + policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup # Verify resources were created properly assert policy is not None, "Generation policy was not created properly" @@ -463,10 +453,9 @@ def test_hf_policy_generation(generation_setup_no_ref_model, tracker): @pytest.mark.timeout(180) -def test_all_hf_policy_generation_lps_ref_training(generation_setup_with_ref_model): - policy, cluster, data, tokenizer, prompts, expected_generations = ( - generation_setup_with_ref_model - ) +@pytest.mark.parametrize("generation_setup", [True], indirect=True) +def test_all_hf_policy_generation_lps_ref_training(generation_setup): + policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup # Verify resources were created properly assert policy is not None, "Generation policy was not created properly"