Enable mixed precision training #13

@parthchadha

Description

Describe the bug

Mixed precision training, as enabled in https://github.com/NVIDIA/reinforcer/blob/main/nemo_reinforcer/models/policy/hf_policy.py#L104, is causing convergence issues in GRPO and SFT. There is also a bug in PyTorch FSDP where the requested optimizer state precision is not respected and the states are kept in the model precision (bf16): pytorch/pytorch#143900
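
For context, here is a minimal sketch of how bf16 mixed precision is typically enabled with FSDP's `MixedPrecision` policy. The toy `Linear` module, the dtype choices, and the optimizer are illustrative assumptions, not the exact configuration in hf_policy.py.

```python
# Minimal sketch of bf16 mixed precision under FSDP (illustrative only, not
# the hf_policy.py configuration). Run under torchrun with one GPU per rank.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()  # stand-in for the policy model

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # forward/backward compute and all-gathers in bf16
    reduce_dtype=torch.float32,   # gradient reduce-scatter in fp32
    buffer_dtype=torch.bfloat16,
)
model = FSDP(model, mixed_precision=bf16_policy)

# The intent is for the sharded master weights to stay fp32 so the optimizer
# accumulates updates in full precision; per pytorch/pytorch#143900 the
# optimizer states can instead end up in the model precision (bf16).
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
```

Inspecting the dtypes in `optim.state` after an optimizer step should show whether the states actually landed in fp32 or bf16.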

The PR that resolves this issue should study convergence for both GRPO and SFT with mixed precision enabled.

Labels: bug (Something isn't working)