From 0030eaaedaea756e322be43121f7ee65158652b9 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 15 Apr 2025 13:23:06 +0800 Subject: [PATCH 1/4] fix: prevent division by zero in ClippedPGLossFn calculation Signed-off-by: Zhaopeng Qiu --- nemo_reinforcer/algorithms/loss_functions.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 158c9824eb..97f24711c4 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -91,7 +91,9 @@ def __call__( mask = token_mask * sample_mask.unsqueeze(-1) lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) - mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item() + mult_prob_error = ( + (torch.exp(lp_error) * mask).sum() / (mask.sum() + 1e-10) + ).item() next_token_logits = next_token_logits[:, :-1] # Remove last position's logits next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) @@ -127,13 +129,14 @@ def __call__( if mask.sum() > 0: actor_loss = masked_mean(torch.max(loss1, loss2), mask) loss = actor_loss + kl + with torch.no_grad(): + probs_ratio = masked_mean(ratios.detach(), mask).item() + probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item() else: # disable this update since there are no valid tokens loss = loss1.view(-1)[0] * 0 - - with torch.no_grad(): - probs_ratio = masked_mean(ratios.detach(), mask).item() - probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item() + probs_ratio = 0 + probs_ratio_clamped = 0 return ( loss, From 94e135b8567d79c3aedeea2963c019c95747aa36 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 15 Apr 2025 13:23:06 +0800 Subject: [PATCH 2/4] handle all zero mask inside masked_mean func; add unit test Signed-off-by: Zhaopeng Qiu --- nemo_reinforcer/algorithms/loss_functions.py | 20 ++++++-------------- nemo_reinforcer/algorithms/utils.py | 4 +++- tests/unit/algorithms/test_loss_functions.py | 14 ++++++++++++++ 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 97f24711c4..6db8b823d9 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -91,9 +91,7 @@ def __call__( mask = token_mask * sample_mask.unsqueeze(-1) lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) - mult_prob_error = ( - (torch.exp(lp_error) * mask).sum() / (mask.sum() + 1e-10) - ).item() + mult_prob_error = masked_mean(lp_error, mask).item() next_token_logits = next_token_logits[:, :-1] # Remove last position's logits next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) @@ -126,17 +124,11 @@ def __call__( loss1 = -advantages * ratios loss2 = -advantages * ratios_clamped - if mask.sum() > 0: - actor_loss = masked_mean(torch.max(loss1, loss2), mask) - loss = actor_loss + kl - with torch.no_grad(): - probs_ratio = masked_mean(ratios.detach(), mask).item() - probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item() - else: - # disable this update since there are no valid tokens - loss = loss1.view(-1)[0] * 0 - probs_ratio = 0 - probs_ratio_clamped = 0 + actor_loss = masked_mean(torch.max(loss1, loss2), mask) + loss = actor_loss + kl + with torch.no_grad(): + probs_ratio = masked_mean(ratios.detach(), mask).item() + probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item() return ( loss, diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index a3c42e2a19..c9dbf47b3c 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -118,9 +118,11 @@ def wrapper(*args, **kwargs): # need to surpress the masked tensor warnings from pytorch @surpress_user_warnings -def masked_mean(values, mask, dim=None): +def masked_mean(values, mask, dim=None, check_all_zero_mask=True): """Masks values with mask, and computes the mean of the values using the masked values.""" if dim is None: + if check_all_zero_mask and mask.sum() == 0: + return values.sum() * 0 return values[mask.bool()].mean() return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index af78baf34d..99fa4967ca 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -386,3 +386,17 @@ def test_clipped_pg_loss_zero_mask(): # Loss should be exactly zero torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) + + +def test_masked_mean_all_zeros(): + """Test masked_mean function with all zeros mask.""" + values = torch.tensor([1.0, 2.0, 3.0, 4.0]) + mask = torch.zeros_like(values) + + # With check_zero_mask=True (default) + result = masked_mean(values, mask) + assert torch.assert_allclose(result, torch.tensor(0.0)) + + # With check_zero_mask=False + result = masked_mean(values, mask, check_all_zero_mask=False) + assert torch.isnan(result) # Should be nan when mask is all zeros From f858b5c3547d83ad80c90bb1c17428b8a33e8754 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 15 Apr 2025 13:23:06 +0800 Subject: [PATCH 3/4] update logic for avoiding div-0; add unit test Signed-off-by: Zhaopeng Qiu --- nemo_reinforcer/algorithms/utils.py | 8 ++------ tests/unit/algorithms/test_loss_functions.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index c9dbf47b3c..e68ea1d6fa 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -118,13 +118,9 @@ def wrapper(*args, **kwargs): # need to surpress the masked tensor warnings from pytorch @surpress_user_warnings -def masked_mean(values, mask, dim=None, check_all_zero_mask=True): +def masked_mean(values, mask, dim=None): """Masks values with mask, and computes the mean of the values using the masked values.""" - if dim is None: - if check_all_zero_mask and mask.sum() == 0: - return values.sum() * 0 - return values[mask.bool()].mean() - return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) + return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8) def set_seed(seed: int): diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 99fa4967ca..447bd20a54 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -393,10 +393,18 @@ def test_masked_mean_all_zeros(): values = torch.tensor([1.0, 2.0, 3.0, 4.0]) mask = torch.zeros_like(values) - # With check_zero_mask=True (default) + # All zeros mask should return 0 result = masked_mean(values, mask) - assert torch.assert_allclose(result, torch.tensor(0.0)) + print(result) + torch.testing.assert_allclose(result, torch.tensor(0.0)) # With check_zero_mask=False - result = masked_mean(values, mask, check_all_zero_mask=False) - assert torch.isnan(result) # Should be nan when mask is all zeros + mask[0] = 1 + result = masked_mean(values, mask) + torch.testing.assert_allclose(result, torch.tensor(1.0)) + + # Case 2: dim is not None + values = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + mask = torch.zeros_like(values) + result = masked_mean(values, mask, dim=1) + torch.testing.assert_allclose(result, torch.tensor([0.0, 0.0])) From 91be42d6f8c1175c5fabf68dc84f4fdb0c208a2d Mon Sep 17 00:00:00 2001 From: Alex Qiu Date: Tue, 15 Apr 2025 13:23:06 +0800 Subject: [PATCH 4/4] fix mult_prob_error calc bug Signed-off-by: Alex Qiu --- nemo_reinforcer/algorithms/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 6db8b823d9..320ccbbe6f 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -91,7 +91,7 @@ def __call__( mask = token_mask * sample_mask.unsqueeze(-1) lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) - mult_prob_error = masked_mean(lp_error, mask).item() + mult_prob_error = masked_mean(torch.exp(lp_error), mask).item() next_token_logits = next_token_logits[:, :-1] # Remove last position's logits next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)