We currently have a very naive checkpointing implementation: to save, we just gather the full state dict and write it out. To load, every rank reads the full checkpoint from disk into CPU memory and goes from there.
This also seems to inexplicably break the optimizer for SFT (it doesn't occur on GRPO somehow). @ashors1 to add a repro for SFT.
The naive system is pretty slow, especially on load, when every rank reads the full checkpoint and hammers the disk. dist_cp or a similar checkpointing mechanism would be ideal for 'during training' checkpoints, as long as a universal converter script from that format to a normal HF checkpoint exists.
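To make the proposal concrete, here is a rough sketch of the idea: each rank writes and reads only its own shard, and a separate offline step merges the shards back into one full state dict (the role the converter script would play). This is a pure-Python stand-in, not dist_cp's actual API: lists of floats stand in for tensors, ranks are simulated in a loop, and `shard_state`, `save_shard`, `load_shard`, and `consolidate` are hypothetical names made up for illustration.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path


def shard_state(full_state: dict, world_size: int) -> list[dict]:
    """Split each flat parameter list into contiguous per-rank slices."""
    shards = [{} for _ in range(world_size)]
    for name, values in full_state.items():
        chunk = -(-len(values) // world_size)  # ceiling division
        for rank in range(world_size):
            shards[rank][name] = values[rank * chunk:(rank + 1) * chunk]
    return shards


def save_shard(ckpt_dir: Path, rank: int, shard: dict) -> None:
    """Each rank writes only its own slice, so no gather is needed."""
    (ckpt_dir / f"rank_{rank}.json").write_text(json.dumps(shard))


def load_shard(ckpt_dir: Path, rank: int) -> dict:
    """Each rank reads only its own file instead of the full checkpoint."""
    return json.loads((ckpt_dir / f"rank_{rank}.json").read_text())


def consolidate(ckpt_dir: Path, world_size: int) -> dict:
    """Offline converter: merge all shards back into one full state dict."""
    full: dict = {}
    for rank in range(world_size):
        for name, values in load_shard(ckpt_dir, rank).items():
            full.setdefault(name, []).extend(values)
    return full


def demo() -> tuple[dict, dict]:
    """Simulate a 2-rank sharded save followed by offline consolidation."""
    full_state = {
        "layer.weight": [0.1, 0.2, 0.3, 0.4, 0.5],
        "layer.bias": [1.0, 2.0],
    }
    world_size = 2
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp)
        for rank, shard in enumerate(shard_state(full_state, world_size)):
            save_shard(ckpt_dir, rank, shard)
        return full_state, consolidate(ckpt_dir, world_size)
```

The point of the split is that the training loop only ever touches per-rank files, while the full, single-file layout is produced once, offline, when a normal checkpoint is actually needed.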