
Improve Model checkpointing (hf) #34

@SahilJain314

Description


We currently have a very naive checkpointing implementation: we simply gather the full state dict and save it. For loading, every rank reads the entire checkpoint from disk into CPU memory and proceeds from there.
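The naive flow described above amounts to something like the following single-process sketch (function names are illustrative, not the repo's actual code; in the real multi-rank setting the save side is preceded by a gather to one rank, and the load side is repeated on every rank):

```python
import torch
import torch.nn as nn

def save_naive(model: nn.Module, path: str) -> None:
    # Write the full (gathered) state dict as one monolithic file.
    torch.save(model.state_dict(), path)

def load_naive(model: nn.Module, path: str) -> None:
    # Every rank reads the *entire* checkpoint from disk into CPU memory,
    # even though it only needs its own shard -- this is what hammers the disk.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)

model = nn.Linear(4, 4)
save_naive(model, "ckpt.pt")
load_naive(model, "ckpt.pt")
```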

This also seems to inexplicably corrupt the optimizer state for SFT (somehow it doesn't occur with GRPO). @ashors1 to add a repro for SFT.

The naive system is quite slow (especially on load, when every rank hammers the disk). A `dist_cp` (`torch.distributed.checkpoint`) or similar sharded checkpointing mechanism would be ideal for mid-training checkpoints, as long as a universal converter script from that format to a normal HF checkpoint exists.

Metadata

Labels

Performance: Related to improving performance
