When running a model with FSDP2 and `te.autocast()`, quantized copies of the weights are created during the forward pass and are not de-allocated until the backward pass. This effectively undoes FSDP2's memory savings: over the course of the forward pass, each rank accumulates quantized weights for the entire model, even though FSDP2 has resharded the corresponding high-precision parameters.
Training script where this was observed: https://github.com/NVIDIA/bionemo-framework/blob/main/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py

Slack discussion: https://nvidia.slack.com/archives/C03V462SAMS/p1771004012335309
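For illustration, a minimal repro sketch (not taken from the script above): it assumes a stack of `te.Linear` layers sharded per-layer with FSDP2's `fully_shard`, run under the `te.autocast()` context manager from the report (`te.fp8_autocast()` in older TE releases); the layer count, hidden size, and batch shape are arbitrary. If the issue reproduces, allocated memory should grow by roughly one quantized weight per layer during the forward pass instead of staying flat as FSDP2 reshards each layer.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
import transformer_engine.pytorch as te
from torch.distributed.fsdp import fully_shard  # import path for PyTorch >= 2.6


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    # Assumes a single node; multi-node launches should use LOCAL_RANK instead.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    hidden = 4096  # divisible by 16, as required for TE's FP8 GEMMs
    layers = nn.Sequential(
        *[te.Linear(hidden, hidden, params_dtype=torch.bfloat16) for _ in range(16)]
    )
    # Shard each layer individually so FSDP2 unshards/reshards weights per-forward.
    for layer in layers:
        fully_shard(layer)
    fully_shard(layers)

    x = torch.randn(8, 1024, hidden, device="cuda", dtype=torch.bfloat16)

    # `te.autocast()` as in the report; older TE exposes `te.fp8_autocast()`.
    with te.autocast():
        out = x
        for i, layer in enumerate(layers):
            out = layer(out)
            # Expected under FSDP2: roughly flat allocated memory across layers.
            # Reported behavior: growth of one quantized weight per layer.
            print(
                f"rank {rank} after layer {i}: "
                f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated"
            )
        # Per the report, the quantized weights are only freed during backward.
        out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with e.g. `torchrun --nproc-per-node=8 repro.py`, the per-layer memory printout makes the accumulation visible without a profiler.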