Describe the bug
In nemo_reinforcer/algorithms/loss_functions.py, there's a potential risk of getting NaN values when the mask tensor contains all zeros. This occurs in the following code section:
mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item()
...
with torch.no_grad():
    probs_ratio = masked_mean(ratios.detach(), mask).item()
    probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
If mask contains all zeros, mask.sum() is 0, so both the mult_prob_error computation and masked_mean divide by zero, producing NaN values.
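A minimal sketch of the failure mode. The masked_mean below is a stand-in assuming the usual sum-over-mask-count implementation; the actual helper in nemo_reinforcer may differ:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Assumed implementation: sum of masked values divided by mask count.
    return (values * mask).sum() / mask.sum()

ratios = torch.tensor([1.0, 2.0, 3.0])
mask = torch.zeros(3)  # all-zero mask, e.g. a fully padded/filtered batch

result = masked_mean(ratios, mask).item()  # 0.0 / 0.0 -> NaN
print(result)
```

With a non-empty mask the same function behaves normally, which is why the bug only surfaces on degenerate batches.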
Steps/Code to reproduce bug
NA
Expected behavior
Add a check to handle the case when the mask is all zeros:
with torch.no_grad():
    if mask.sum() > 0:
        probs_ratio = masked_mean(ratios.detach(), mask).item()
        probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
    else:
        probs_ratio = 0.0  # Default value when mask is all zeros
        probs_ratio_clamped = 0.0  # Default value when mask is all zeros
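An alternative to branching at every call site would be to guard the denominator inside masked_mean itself. This is a hedged sketch, not the project's actual helper; the eps default is an assumption:

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor,
                eps: float = 1e-8) -> torch.Tensor:
    # Clamp the denominator so an all-zero mask yields 0.0 instead of NaN,
    # while leaving the result unchanged for any non-empty mask.
    return (values * mask).sum() / mask.sum().clamp(min=eps)

# All-zero mask now degrades gracefully to 0.0:
safe = masked_mean(torch.tensor([1.0, 2.0]), torch.zeros(2)).item()
```

The branch-based fix above is more explicit about the degenerate case; the eps-clamp keeps call sites untouched. Either resolves the NaN.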
Environment overview (please complete the following information)
- Environment location: Docker
- Method of install: from source, via uv pip install -e '.[dev,test]'