Refactor checkpoint system with StateDictTransforms and converter protocol#2898
mori360 wants to merge 11 commits into gh/mori360/11/base
Conversation
Refactor checkpoint system with StateDictTransforms and converter protocol

Separate state dict content transforms from checkpoint I/O orchestration into StateDictTransforms. Extend ModelConverter protocol with key_filter(), state_dict_transform(), and state_dict_adapter() for per-converter checkpoint support. Refactor CheckpointManager: extract load helpers, add multi-source loading (additional_load_path), split save/do_save for subclass control, and add ModelWrapper modes (full/base/export). Wire converter params through all trainers. Add 19 unit tests for ModelWrapper, path resolution, multi-source loading, and save/purge guards.

[ghstack-poisoned]
@claude review the pr
Claude finished @mori360's task in 2m 55s

Review: Refactor checkpoint system with StateDictTransforms and converter protocol
**Overall Assessment**

This is a substantial refactor that cleanly separates concerns that were previously tangled in the checkpoint manager. The overall direction is good. However, there are several bugs and issues that need to be addressed before merging.

**Bugs**

**1. Dead code after the `to_hf` validation:**
```python
if to_hf and self.sd_transforms.sd_adapter is None:
    raise ValueError(
        "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
    )
# The state_dict is already in HF format ...  <-- UNREACHABLE
if fqn_to_index_mapping:
    storage_writer = HuggingFaceStorageWriter(...)
else:
    storage_writer = HuggingFaceStorageWriter(...)
```

This means the HF storage-writer setup never runs. The old code had the correct structure:

```python
if to_hf:
    if self.sd_transforms.sd_adapter is None:
        raise ValueError(...)
    if fqn_to_index_mapping:
        storage_writer = HuggingFaceStorageWriter(...)
    else:
        storage_writer = HuggingFaceStorageWriter(...)
else:
    checkpoint_save_id = checkpoint_id
```

**2. Dead code after the `last_save_in_hf` validation:**
```python
if self.last_save_in_hf and not self.last_save_model_only:
    raise ValueError(
        "Only model can be saved when saving in HF safetensors format."
    )
# Split by converter key_filters ...  <-- UNREACHABLE
remaining = dict(states)
mapped: dict[str, Any] = {}
...
```

This should be restructured: the validation check raises on the invalid combination, and the HF-only work runs under its own condition:

```python
if self.last_save_in_hf and not self.last_save_model_only:
    raise ValueError(...)
if self.last_save_in_hf:
    # Split by converter key_filters ...
    ...
```

**Design / Correctness Issues**

**3.** In the old code, …

**4.**
```python
assert adapter is not None  # Should be ValueError
```

This is a user-facing scenario: if …

**5.** The old code used …:

```python
if from_hf and (
    self.sd_transforms.sd_adapter is None
    or self.sd_transforms.hf_assets_path is None
):
    raise ValueError(...)
```

This is checked even when …

**Style / Maintainability**

**6.** Duplicated …

**7.** …

**8.** …

**9.** …

**Test Coverage Observations**

**10.** Tests don't cover the HF save/load paths. The new … Given the critical bugs in the HF paths, adding tests there would be valuable.

**11.** The LoRA-focused …

**Minor Nits**

…

**Summary**

The architectural direction of this refactor is sound: separating content transforms from I/O, giving converters explicit hooks, and removing the eager cache. However, bugs #1 and #2 are critical: they make all HF safetensors save paths broken (dead code after …).
Refactor checkpoint system with StateDictTransforms and converter protocol
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) now use the new CheckpointManager.__init__ signature with the sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```python
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```python
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
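The call-count claim can be checked with a toy counter. `Counter`, `old_scheme`, and `new_scheme` are hypothetical stand-ins for illustration; `Counter.get()` represents the O(N) `get_model_state_dict()` walk.

```python
# Toy accounting of state-dict fetches in one save-then-load cycle; Counter
# and both *_scheme functions are illustrative, not the PR's code.
class Counter:
    def __init__(self):
        self.calls = 0

    def get(self):
        # Stands in for the expensive O(N) get_model_state_dict() walk.
        self.calls += 1
        return {"w": 1.0}


def old_scheme(c):
    cache = c.get()   # (A) eager fetch at __init__
    _ = cache         # (B) save: return cached ref, no fetch
    c.get()           # (C) refresh after load_state_dict()
    return c.calls


def new_scheme(c):
    _ = c.get()       # save: on-demand fetch
    _ = c.get()       # load-container creation: on-demand fetch
    return c.calls


print(old_scheme(Counter()), new_scheme(Counter()))  # 2 2
```

Both schemes perform two fetches per cycle; only the timing moves, which is why no regression is expected.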
Then we ran experiments saving/loading a llama3_8b model on 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, for 10 training steps with a checkpoint at step 5 (skipping the first round to avoid noise from system warm-up).
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summary:
Same total get_model_state_dict() calls, just redistributed from init+post-load to save+load-container.
**The checkpoint refactor has no performance regression.**
[ghstack-poisoned]
…tocol Separate state dict content transforms from checkpoint I/O orchestration into StateDictTransforms. Extend ModelConverter protocol with key_filter(), state_dict_transform(), and state_dict_adapter() for per-converter checkpoint support. Refactor CheckpointManager: extract load helpers, add multi-source loading (additional_load_path), split save/do_save for subclass control, and add ModelWrapper modes (full/base/export). Wire converter params through all trainers. Add 19 unit tests for ModelWrapper, path resolution, multi-source loading, and save/purge guards. ghstack-source-id: 2a1566e Pull Request resolved: #2898
…nverter protocol"
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use new CheckpointManager.__init__ signature with sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
Then we run experiments to save/load llama3_8b model 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, 10 training steps with checkpoint at step 5. (skip the first round to avoid noise from system warm up)
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summery:
Same total get_model_state_dict() calls, just redistributed from init+post-load to save+load-container.
**The checkpoint refactor has no performance regression.**
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
[ghstack-poisoned]
…tocol Separate state dict content transforms from checkpoint I/O orchestration into StateDictTransforms. Extend ModelConverter protocol with key_filter(), state_dict_transform(), and state_dict_adapter() for per-converter checkpoint support. Refactor CheckpointManager: extract load helpers, add multi-source loading (additional_load_path), split save/do_save for subclass control, and add ModelWrapper modes (full/base/export). Wire converter params through all trainers. Add 19 unit tests for ModelWrapper, path resolution, multi-source loading, and save/purge guards. ghstack-source-id: 08fbeff Pull Request resolved: #2898
…nverter protocol"
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use new CheckpointManager.__init__ signature with sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
Then we run experiments to save/load llama3_8b model 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, 10 training steps with checkpoint at step 5. (skip the first round to avoid noise from system warm up)
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summery:
Same total get_model_state_dict() calls, just redistributed from init+post-load to save+load-container.
**The checkpoint refactor has no performance regression.**
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
[ghstack-poisoned]
…tocol Separate state dict content transforms from checkpoint I/O orchestration into StateDictTransforms. Extend ModelConverter protocol with key_filter(), state_dict_transform(), and state_dict_adapter() for per-converter checkpoint support. Refactor CheckpointManager: extract load helpers, add multi-source loading (additional_load_path), split save/do_save for subclass control, and add ModelWrapper modes (full/base/export). Wire converter params through all trainers. Add 19 unit tests for ModelWrapper, path resolution, multi-source loading, and save/purge guards. ghstack-source-id: 465a91f Pull Request resolved: #2898
…nverter protocol"
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use new CheckpointManager.__init__ signature with sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
Then we run experiments to save/load llama3_8b model 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, 10 training steps with checkpoint at step 5. (skip the first round to avoid noise from system warm up)
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summery:
Same total get_model_state_dict() calls, just redistributed from init+post-load to save+load-container.
**The checkpoint refactor has no performance regression.**
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
[ghstack-poisoned]
…nverter protocol"
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use new CheckpointManager.__init__ signature with sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
Then we run experiments to save/load llama3_8b model 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, 10 training steps with checkpoint at step 5. (skip the first round to avoid noise from system warm up)
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summery:
Same total get_model_state_dict() calls, just redistributed from init+post-load to save+load-container.
**The checkpoint refactor has no performance regression.**
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
[ghstack-poisoned]
…nverter protocol"
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use new CheckpointManager.__init__ signature with sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
Then we run experiments to save/load llama3_8b model 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, 10 training steps with checkpoint at step 5. (skip the first round to avoid noise from system warm up)
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summery:
Same total get_model_state_dict() calls, just redistributed from init+post-load to save+load-container.
**The checkpoint refactor has no performance regression.**
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
[ghstack-poisoned]
…nverter protocol"
### Summary
This commit refactors the CheckpointManager to cleanly separate concerns:
1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
- key_filter(fqn) -> bool: Identifies converter-owned keys
- state_dict_transform(sd) -> sd: Applies reverse transforms for export
- state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
3. ModelWrapper modes: state_dict(mode) supports three modes:
- "full": Complete state dict for interval saves and resume
- "base": Excludes converter-owned keys (for HF container creation)
- "export": Applies state_dict_transform for last-step export saves
4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). State dict is fetched on-demand via _get_state_dict().
5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use new CheckpointManager.__init__ signature with sd_transforms parameter.
### Test plan
1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
2. Performance Analysis — cache_state_dict Removal
Old code (main):
```
class ModelWrapper(Stateful):
def __init__(self, model):
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict() # (A) eager O(N) at init, 1 call (eager cache)
def _get_state_dict(self):
return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
def state_dict(self):
return self.cache_state_dict # (B) O(1) — return cached ref, 0 calls for both save or load
def load_state_dict(self, state_dict):
set_model_state_dict(...)
self.cache_state_dict = self._get_state_dict() # (C) O(N) refresh after load, 1 call
```
New code:
```
class ModelWrapper(Stateful):
def __init__(self, model, *, key_filter=None, state_dict_transform=None):
pass # no cache, no eager call, 0 call
def state_dict(self, mode="full"):
sd = self._get_state_dict() # O(N) on-demand, 1 call for save or load
# optional filtering/transform based on mode
return sd
def load_state_dict(self, state_dict):
set_model_state_dict(...) # no cache refresh, 0 call
```
Then we run experiments to save/load llama3_8b model 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, 10 training steps with checkpoint at step 5. (skip the first round to avoid noise from system warm up)
Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s
Summary:
The total number of get_model_state_dict() calls is unchanged; they are just redistributed from init + post-load to save + load-container.
**The checkpoint refactor has no performance regression.**
Stack from ghstack (oldest at bottom):