Skip to content
1 change: 1 addition & 0 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def setup(
if last_checkpoint_path
else None,
init_optimizer=True,
init_reference_model=False,
)
loss_fn = NLLLoss()
print(f" ✓ Model initialized")
Expand Down
24 changes: 15 additions & 9 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_optimizer: bool = True,
init_reference_model: bool = True,
):
self.cfg = config
# torch distributed init. Environment variables for rank, world_size, master_addr, and master_port are set from the Ray remote call
Expand All @@ -88,12 +89,14 @@ def __init__(
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)
self.reference_model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)

if init_reference_model:
self.reference_model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)
else:
self.reference_model = None
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# If no pad token is defined, you might need:
if self.tokenizer.pad_token is None:
Expand Down Expand Up @@ -123,9 +126,10 @@ def do_fsdp(model):
self.model.to("cuda")
self.model = do_fsdp(self.model)
self.model = self.move_to_cpu(self.model)
self.reference_model.to("cuda")
self.reference_model = do_fsdp(self.reference_model)
self.reference_model = self.move_to_cpu(self.reference_model)
if self.reference_model is not None:
self.reference_model.to("cuda")
self.reference_model = do_fsdp(self.reference_model)
self.reference_model = self.move_to_cpu(self.reference_model)
self.model.to("cuda")
self._held_reference_model_params = None
# register_fsdp_forward_method(self.model, "generate")
Expand Down Expand Up @@ -893,6 +897,7 @@ def __init__(
init_optimizer: bool = True,
weights_path: Optional[str] = None,
optimizer_path: Optional[str] = None,
init_reference_model: bool = True,
):
if weights_path:
weights_path = os.path.abspath(weights_path)
Expand All @@ -905,6 +910,7 @@ def __init__(
init_optimizer=init_optimizer,
weights_path=weights_path,
optimizer_path=optimizer_path,
init_reference_model=init_reference_model,
)
self.worker_group = RayWorkerGroup(
cluster,
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/models/policy/test_hf_ray_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def training_setup():
config = basic_llama_test_config

print("Creating training HfPolicy...")
policy = HfPolicy(cluster=cluster, config=config)
policy = HfPolicy(cluster=cluster, config=config, init_reference_model=False)

# Create a test batch
print("Creating test batch...")
Expand Down Expand Up @@ -278,7 +278,7 @@ def verify_loss_tensor(loss_tensor):


@pytest.fixture
def generation_setup():
def generation_setup(request):
"""Setup and teardown specifically for generation tests."""
policy = None
cluster = None
Expand All @@ -300,7 +300,9 @@ def generation_setup():
config = basic_llama_test_config

print("Creating generation HfPolicy...")
policy = HfPolicy(cluster=cluster, config=config)
policy = HfPolicy(
cluster=cluster, config=config, init_reference_model=request.param
)

# Create a test batch
print("Creating test batch...")
Expand Down Expand Up @@ -364,6 +366,7 @@ def generation_setup():


@pytest.mark.timeout(180)
@pytest.mark.parametrize("generation_setup", [False], indirect=True)
def test_hf_policy_generation(generation_setup, tracker):
policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup

Expand Down Expand Up @@ -450,6 +453,7 @@ def test_hf_policy_generation(generation_setup, tracker):


@pytest.mark.timeout(180)
@pytest.mark.parametrize("generation_setup", [True], indirect=True)
def test_all_hf_policy_generation_lps_ref_training(generation_setup):
policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup

Expand Down