
te.autocast() and fully_shard don't de-allocate quantized weights until the backward pass #2681

@pstjohn

Description


When running a model with FSDP2 and te.autocast(), quantized weights are created during the forward pass but not de-allocated until the backward pass. This effectively undoes the memory savings of FSDP2: each rank ends up accumulating the entire model's worth of quantized weights instead of holding only its own shard.

Repro recipe: https://github.com/NVIDIA/bionemo-framework/blob/main/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py

Related internal Slack thread: https://nvidia.slack.com/archives/C03V462SAMS/p1771004012335309
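
A minimal sketch of the pattern described above, condensed from the linked recipe. The layer sizes, layer count, and script name are illustrative assumptions, and the no-argument `te.autocast()` call follows the issue title (older TransformerEngine releases spell the context manager `te.fp8_autocast`):

```python
# repro_fsdp2_te_autocast.py -- illustrative sketch; requires an FP8-capable
# GPU and a distributed launch, e.g.:
#   torchrun --nproc-per-node=2 repro_fsdp2_te_autocast.py
import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from torch.distributed.fsdp import fully_shard  # torch >= 2.6; earlier: torch.distributed._composable.fsdp

def main():
    torch.distributed.init_process_group("nccl")
    device = torch.device("cuda", torch.distributed.get_rank() % torch.cuda.device_count())
    torch.cuda.set_device(device)

    # A stack of TE layers, each wrapped separately so FSDP2 can reshard
    # (free) one layer's parameters as soon as its forward completes.
    model = nn.Sequential(*[te.Linear(4096, 4096) for _ in range(8)]).to(device)
    for layer in model:
        fully_shard(layer)
    fully_shard(model)

    x = torch.randn(16, 4096, device=device)
    with te.autocast():  # older TE releases: te.fp8_autocast()
        y = model(x)

    # Expected: after the forward pass each layer's quantized weights are
    # freed along with its resharded parameters, so allocation stays near
    # one layer plus activations. Observed: the quantized weights of every
    # layer persist until backward, scaling with the whole model.
    print(f"rank {torch.distributed.get_rank()}: "
          f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated after forward")

    y.sum().backward()
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    main()
```

With per-layer resharding working as intended, the memory reported after the forward pass should be roughly independent of the layer count; in the behavior reported here it grows with the number of layers.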


Labels: bug
