Describe the bug
Mixed precision training, as enabled in https://github.com/NVIDIA/reinforcer/blob/main/nemo_reinforcer/models/policy/hf_policy.py#L104, is causing convergence issues on GRPO and SFT. There is also a bug in PyTorch + FSDP where the configured optimizer-state precision is not respected: optimizer states are kept in the model precision (bf16). See pytorch/pytorch#143900.
This PR should study convergence for both GRPO and SFT while mixed precision is enabled.
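As a minimal illustration of the optimizer-state precision issue (a sketch of the general PyTorch behavior, not the exact FSDP code path tracked in pytorch/pytorch#143900): Adam allocates its moment buffers with `zeros_like(param)`, so if the model is held in bf16, the optimizer states silently end up in bf16 as well, which can degrade convergence.

```python
import torch

# Assumption for illustration: a single bf16 parameter stands in for a model
# wrapped with full-bf16 mixed precision. Adam creates its moment buffers via
# zeros_like(param), so the states inherit the parameter's bf16 dtype.
param = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.bfloat16))
opt = torch.optim.Adam([param], lr=1e-3)

param.grad = torch.randn_like(param)
opt.step()

state = opt.state[param]
# Both Adam moments are kept in bf16, not fp32, because the param is bf16.
print(state["exp_avg"].dtype, state["exp_avg_sq"].dtype)
```

Keeping the master weights and optimizer states in fp32 (and casting only the compute to bf16) avoids this; the FSDP bug referenced above is that the configured fp32 state precision is not honored.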