
Refactor checkpoint system with StateDictTransforms and converter protocol#2898

Draft
mori360 wants to merge 11 commits into gh/mori360/11/base from gh/mori360/11/head

Conversation


@mori360 mori360 commented Apr 8, 2026

Summary

This commit refactors the CheckpointManager to cleanly separate concerns:

  1. StateDictTransforms: New class (state_dict_transforms.py) that separates content transforms (dtype conversion, HF adapter conversion) from checkpoint I/O orchestration. Previously these were mixed into dcp_save/dcp_load.
  2. Converter protocol: Model converters (Float8, QAT, etc.) now implement three hooks:
  • key_filter(fqn) -> bool: Identifies converter-owned keys
  • state_dict_transform(sd) -> sd: Applies reverse transforms for export
  • state_dict_adapter() -> BaseStateDictAdapter | None: Provides adapter for HF conversion
  3. ModelWrapper modes: state_dict(mode) supports three modes:
  • "full": Complete state dict for interval saves and resume
  • "base": Excludes converter-owned keys (for HF container creation)
  • "export": Applies state_dict_transform for last-step export saves
  4. Eliminated cache_state_dict: ModelWrapper no longer eagerly caches the model state dict at construction or refreshes it after every load_state_dict(). The state dict is fetched on demand via _get_state_dict().
  5. Multi-source loading: Structured _resolve_initial_load() with from_hf/from_quantized flags and a _CheckpointLoadSpec named tuple.
  6. Save/do_save split: save() handles gating logic (interval, last step), _do_save() handles execution. _save_last_step() is a separate method for model-only/HF saves.
  7. Updated all trainers: train.py, graph_trainer, and ft (fault tolerance) all updated to use the new CheckpointManager.__init__ signature with the sd_transforms parameter.
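
The three converter hooks can be sketched as a structural protocol. This is a minimal illustration, not the project's actual code: the hook names come from the list above, while `CheckpointConverter`, `ToyScaleConverter`, and the `.scale`-suffix convention are hypothetical.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class CheckpointConverter(Protocol):
    """Illustrative protocol for the three hooks described above."""

    def key_filter(self, fqn: str) -> bool: ...
    def state_dict_transform(self, sd: dict[str, Any]) -> dict[str, Any]: ...
    def state_dict_adapter(self) -> Optional[Any]: ...


class ToyScaleConverter:
    """Hypothetical converter that owns every '<fqn>.scale' key."""

    def key_filter(self, fqn: str) -> bool:
        return fqn.endswith(".scale")

    def state_dict_transform(self, sd: dict[str, Any]) -> dict[str, Any]:
        # Reverse transform for export: drop converter-owned keys.
        return {k: v for k, v in sd.items() if not self.key_filter(k)}

    def state_dict_adapter(self) -> Optional[Any]:
        return None  # this toy converter has no HF adapter
```

Any converter implementing the three methods satisfies the protocol structurally, without inheriting from a base class.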

Test plan

  1. Checkpoint Save/Load is Lossless — 50 + 50 Steps
<img width="2684" height="770" alt="fig_commit2" src="https://github.com/user-attachments/assets/8cf7e1b7-96cf-4624-a585-a0b96cfb42cf" />
  2. Performance Analysis — cache_state_dict Removal
    Old code (main):
```
  class ModelWrapper(Stateful):
      def __init__(self, model):
          self.model = [model] if isinstance(model, nn.Module) else model
          self.cache_state_dict = self._get_state_dict()   # (A) eager O(N) at init, 1 call (eager cache)

      def _get_state_dict(self):
          return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}

      def state_dict(self):
          return self.cache_state_dict                      # (B) O(1) — return cached ref, 0 calls for save or load

      def load_state_dict(self, state_dict):
          set_model_state_dict(...)
          self.cache_state_dict = self._get_state_dict()    # (C) O(N) refresh after load, 1 call
```
New code:

```
  class ModelWrapper(Stateful):
      def __init__(self, model, *, key_filter=None, state_dict_transform=None):
          pass                                              # no cache, no eager call, 0 calls

      def state_dict(self, mode="full"):
          sd = self._get_state_dict()                       # O(N) on-demand, 1 call for save or load
          # optional filtering/transform based on mode
          return sd

      def load_state_dict(self, state_dict):
          set_model_state_dict(...)                         # no cache refresh, 0 calls
```
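
To make the mode semantics concrete, here is a toy stand-in that operates on a plain dict instead of real module state. The function name `select_state_dict` and the dict-based setup are hypothetical; only the mode names and their meanings come from the PR description.

```python
from typing import Any, Callable, Optional

VALID_MODES = frozenset({"full", "base", "export"})


def select_state_dict(
    raw: dict[str, Any],
    mode: str = "full",
    key_filter: Optional[Callable[[str], bool]] = None,
    state_dict_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
) -> dict[str, Any]:
    """Toy version of ModelWrapper.state_dict(mode) over a plain dict."""
    if mode not in VALID_MODES:
        raise ValueError(f"invalid mode {mode!r}, expected one of {sorted(VALID_MODES)}")
    if mode == "base" and key_filter is not None:
        # "base": exclude converter-owned keys (HF container creation).
        return {k: v for k, v in raw.items() if not key_filter(k)}
    if mode == "export" and state_dict_transform is not None:
        # "export": apply the reverse transform for last-step export.
        return state_dict_transform(raw)
    return raw  # "full": everything, for interval saves and resume
```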

We then run save/load experiments on a llama3_8b model across 8 GPUs with FSDP, selective activation checkpointing, seq_len=2048, batch_size=1, and 10 training steps with a checkpoint at step 5 (skipping the first round to avoid noise from system warm-up).

Before this PR:
Save: 78.74s, Load: 20.81s
After this PR:
Save: 76.58s, Load: 21.7s

Summary:
Total get_model_state_dict() calls are unchanged; they are simply redistributed from init + post-load to save + load-container.
The checkpoint refactor introduces no performance regression.

Stack from ghstack (oldest at bottom):


…tocol

Separate state dict content transforms from checkpoint I/O orchestration
into StateDictTransforms. Extend ModelConverter protocol with key_filter(),
state_dict_transform(), and state_dict_adapter() for per-converter checkpoint
support. Refactor CheckpointManager: extract load helpers, add multi-source
loading (additional_load_path), split save/do_save for subclass control, and
add ModelWrapper modes (full/base/export). Wire converter params through all
trainers. Add 19 unit tests for ModelWrapper, path resolution, multi-source
loading, and save/purge guards.

[ghstack-poisoned]
mori360 added a commit that referenced this pull request Apr 8, 2026
ghstack-source-id: 7c168ac
Pull Request resolved: #2898
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 8, 2026
mori360 added a commit that referenced this pull request Apr 8, 2026
ghstack-source-id: ae00ce9
Pull Request resolved: #2898
mori360 added a commit that referenced this pull request Apr 8, 2026
ghstack-source-id: 7152ab2
Pull Request resolved: #2898
Contributor Author

mori360 commented Apr 9, 2026

@claude review the pr


claude bot commented Apr 9, 2026

Claude finished @mori360's task in 2m 55s


Review: Refactor checkpoint system with StateDictTransforms and converter protocol

  • Read all changed files and understand the diff
  • Review state_dict_transforms.py (new file)
  • Review checkpoint.py changes
  • Review model_converter.py protocol changes
  • Review trainer integration changes
  • Review test coverage
  • Post comprehensive review

Overall Assessment

This is a substantial refactor that cleanly separates concerns that were previously tangled in the checkpoint manager. The overall direction is good: StateDictTransforms for content transforms, the converter protocol for per-converter hooks, and the ModelWrapper modes for different state dict views. The test coverage is reasonable for the new abstractions.

However, there are several bugs and issues that need to be addressed before merging.


Bugs

1. Dead code after raise in dcp_save (Critical)

checkpoint.py:506-531 — The raise ValueError(...) at line 508 makes all code from line 512 onward unreachable. The entire HF storage writer setup branch is dead code:

```
if to_hf and self.sd_transforms.sd_adapter is None:
    raise ValueError(
        "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
    )
    # The state_dict is already in HF format ...  <-- UNREACHABLE
    if fqn_to_index_mapping:
        storage_writer = HuggingFaceStorageWriter(...)
    else:
        storage_writer = HuggingFaceStorageWriter(...)
```

This means to_hf=True saves will never use HuggingFaceStorageWriter — they'll fall through to the else branch at line 532 and save as DCP format. This breaks HF safetensors saving entirely.

The old code had the correct structure: if to_hf: with the writer setup inside, and the assert before the writer construction. The refactor inverted the condition. This should be:

```
if to_hf:
    if self.sd_transforms.sd_adapter is None:
        raise ValueError(...)
    if fqn_to_index_mapping:
        storage_writer = HuggingFaceStorageWriter(...)
    else:
        storage_writer = HuggingFaceStorageWriter(...)
else:
    checkpoint_save_id = checkpoint_id
```


2. Dead code after raise in _save_last_step (Critical)

checkpoint.py:1005-1033 — Same pattern: raise ValueError(...) at line 1007 makes the entire HF-format save logic (converter key splitting, FQN mapping merging) unreachable:

```
if self.last_save_in_hf and not self.last_save_model_only:
    raise ValueError(
        "Only model can be saved when saving in HF safetensors format."
    )
    # Split by converter key_filters ...  <-- UNREACHABLE
    remaining = dict(states)
    mapped: dict[str, Any] = {}
    ...
```

This should be restructured: the validation check (last_save_in_hf and not last_save_model_only) should raise, and the actual HF save logic should be in a separate elif self.last_save_in_hf: block:

```
if self.last_save_in_hf and not self.last_save_model_only:
    raise ValueError(...)
if self.last_save_in_hf:
    # Split by converter key_filters ...
    ...
```



Design / Correctness Issues

3. dcp_save now takes fqn_to_index_mapping but never uses the model adapter's mapping

In the old code, dcp_save internally called self.sd_adapter.to_hf(state_dict) and used self.sd_adapter.fqn_to_index_mapping. The new code removes the to_hf content transform from dcp_save (correct — separation of concerns), but now requires the caller to pass the correct fqn_to_index_mapping. This is only done in _save_last_step, but due to the dead code bug above, it never actually executes. Once bug #2 is fixed, verify that the merged fqn_to_index_mapping from all adapters is correct.

4. _load_with_adapter uses assert for user-facing validation

checkpoint.py:666 — Per project guidelines (CLAUDE.md), ValueError should be used for user-facing errors, not assert:

assert adapter is not None  # Should be ValueError

This is a user-facing scenario: if _load_with_adapter is called without an adapter and the model's sd_adapter is also None, the user gets an obscure AssertionError instead of a helpful message.
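
A minimal sketch of the suggested fix; the helper name `resolve_adapter` and its signature are hypothetical stand-ins for the real call site in `_load_with_adapter`:

```python
def resolve_adapter(explicit_adapter, model_adapter):
    """Hypothetical helper: raise ValueError (not assert) when no adapter exists."""
    adapter = explicit_adapter if explicit_adapter is not None else model_adapter
    if adapter is None:
        raise ValueError(
            "_load_with_adapter requires an adapter: pass one explicitly or "
            "configure the model's sd_adapter."
        )
    return adapter
```

Besides giving a helpful message, ValueError survives `python -O`, which strips assert statements entirely.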

5. _resolve_initial_load validation vs. old code

The old code used assert model_only and assert from_hf for the HF/quantized validations. The new code correctly uses raise ValueError(...), which is an improvement. However, the _resolve_initial_load method at line 862-868 has a combined check:

```
if from_hf and (
    self.sd_transforms.sd_adapter is None
    or self.sd_transforms.hf_assets_path is None
):
    raise ValueError(...)
```

This is checked even when initial_load_path is provided, but the return at line 858 means this code is only reached when initial_load_path is None. This is correct behavior but could be clearer with a comment.


Style / Maintainability

6. Duplicated StateDictTransforms construction across 3 trainers

trainer.py:453-462, forge/engine.py:264-273, ft/trainer.py:306-318 — The identical StateDictTransforms(...) construction is duplicated in all three trainer files. Per project guidelines on reuse over duplication, consider extracting this into a helper or making it part of the CheckpointManager construction.

7. _STEP_DIR_RE extraction

checkpoint.py:58 — _STEP_DIR_RE = re.compile(r"step-(\d+)") was previously inline (or defined differently). The extraction to a module-level constant is a good improvement.
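
For reference, such a constant typically backs a small parsing helper; the pattern below is the one quoted in the review, while `parse_step` and the use of `fullmatch` are assumptions (the real code may use `match` or `search`):

```python
import re

_STEP_DIR_RE = re.compile(r"step-(\d+)")  # pattern as quoted above


def parse_step(dirname: str):
    """Return the step number encoded in a checkpoint dir name, else None."""
    m = _STEP_DIR_RE.fullmatch(dirname)
    return int(m.group(1)) if m else None
```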

8. FTCheckpointManager.__init__ uses **kwargs forwarding

ft/checkpoint.py:84 — **kwargs: Any to forward new CheckpointManager params is pragmatic but fragile. If someone adds a required parameter to CheckpointManager.__init__ that conflicts with a kwargs key, the error will be confusing. Consider explicitly listing the forwarded params (key_filter, state_dict_transform, converter_sd_adapters) for clarity.

9. ModelWrapper.state_dict(mode) — mode as string vs. enum

checkpoint.py:109 — mode: str = "full" with validation against a frozenset is fine, but a Literal["full", "base", "export"] type hint would give better IDE support and static analysis than a plain str.
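
A sketch of the suggested typing change, assuming the three mode names above; `StateDictMode` and this simplified `state_dict` are illustrative, not the project's code:

```python
from typing import Literal, get_args

StateDictMode = Literal["full", "base", "export"]


def state_dict(mode: StateDictMode = "full") -> dict:
    # A type checker flags state_dict("exprot") statically; the runtime
    # check below still guards untyped or dynamic callers.
    if mode not in get_args(StateDictMode):
        raise ValueError(f"invalid mode {mode!r}")
    return {"mode": mode}
```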


Test Coverage Observations

10. Tests don't cover the HF save/load paths

The new TestMultiSourceLoading is a good addition, but there are no tests exercising the HF save/load paths themselves.

Given the critical bugs in those paths, adding tests there would be valuable.

11. TestModelWrapper tests are well-structured

The LoRA-focused TestModelWrapper tests (lines 706-800) covering full/base/export modes with a concrete LoRA merge transform are a good pattern. These effectively validate the ModelWrapper abstraction.


Minor Nits

  • checkpoint.py:586 — has_converter_keys = self.states[MODEL]._key_filter is not None accesses a private attribute directly. Consider adding a property to ModelWrapper like has_key_filter instead.
  • state_dict_transforms.py:77 — The apply_dtype_convert docstring says "No-op when export_dtype is float32 (the training default)" but doesn't explain why float32 is special (answer: it's the native training dtype, so no conversion needed).
  • checkpoint.py:438-441 — The validation if self.last_save_in_hf and sd_transforms.sd_adapter is None: raise ValueError(...) is good — this catches the error early at init time rather than at save time.

Summary

The architectural direction of this refactor is sound — separating content transforms from I/O, giving converters explicit hooks, and removing the eager cache. However, bugs #1 and #2 are critical: they make all HF safetensors save paths broken (dead code after raise). These must be fixed before this can land. Once those are fixed, the multi-adapter HF save logic in _save_last_step should be tested end-to-end.

mori360 added a commit that referenced this pull request Apr 9, 2026
ghstack-source-id: 2a1566e
Pull Request resolved: #2898
mori360 added a commit that referenced this pull request Apr 9, 2026
ghstack-source-id: 08fbeff
Pull Request resolved: #2898
mori360 added a commit that referenced this pull request Apr 11, 2026
ghstack-source-id: 465a91f
Pull Request resolved: #2898
mori360 added 4 commits April 13, 2026 13:29
Labels

ciflow/8gpu, CLA Signed
