[feat] Resume from ckpt #135
Conversation
Code Review
This pull request implements a comprehensive "Strict Resume" feature for Transformers models, enabling the restoration of full training state including optimizer, scheduler, scaler, RNG states, and data progress. Key changes involve implementing load_training_state and read_training_progress across the model, server, and client layers, alongside dataloader enhancements to support sample-level skipping for map-style datasets. Feedback highlights several critical improvements: ensuring deterministic RNG in distributed settings by avoiding unseeded random states, addressing the deprecated use of StopIteration in generators, improving security by using weights_only=True during checkpoint loading, and removing an accidental BOM character in the client generator. Additionally, a more robust approach for re-initializing the dataloader is suggested to avoid modifying private PyTorch attributes.
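Two of the flagged fixes are easy to illustrate in isolation. The snippet below is a hedged sketch, not twinkle's actual code: it shows the safer `torch.load(..., weights_only=True)` call and the PEP 479-compliant way to end a generator (`return` instead of raising `StopIteration`); the path and function names are illustrative.

```python
import torch

# Safer checkpoint loading: weights_only=True restricts unpickling to
# tensors and primitive containers instead of arbitrary Python objects.
# The checkpoint path here is illustrative.
optim_state = torch.load('ckpt/optimizer.pt', map_location='cpu', weights_only=True)

def resumed_stream(it):
    # PEP 479: since Python 3.7, a StopIteration raised inside a generator
    # body becomes a RuntimeError; end the generator with `return` instead.
    while True:
        try:
            yield next(it)
        except StopIteration:
            return
```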
/gemini summary
```python
def skip_consumed_samples(self, consumed_train_samples: int):
    response = http_post(
        url=f'{self.server_url}/call',
        json_data={
            'processor_id': self.processor_id,
            'function': 'skip_consumed_samples',
            **{'consumed_train_samples': consumed_train_samples},
        }
    )
    response.raise_for_status()
    return response.json()["result"]

def resume_from_checkpoint(self, consumed_train_samples, **kwargs):
    response = http_post(
        url=f'{self.server_url}/call',
        json_data={
            'processor_id': self.processor_id,
            'function': 'resume_from_checkpoint',
            **{'consumed_train_samples': consumed_train_samples},
            **kwargs
        }
    )
    response.raise_for_status()
    return response.json()["result"]

def get_state(self):
    response = http_post(
        url=f'{self.server_url}/call',
        json_data={
            'processor_id': self.processor_id,
            'function': 'get_state',
            **{},
        }
    )
    response.raise_for_status()
    return response.json()["result"]
```
This part of the code should be modified in client_generator.py, and the script should then be run to regenerate the client.
The DataLoader client file was regenerated, not manually edited.
These methods are generated from the new @remote_function() methods in src/twinkle/dataloader/dataloader.py; no generator logic change is needed for this case because client_tools/client_generator.py already scans remote functions and emits the client wrappers.
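For readers unfamiliar with that flow, here is a minimal sketch of the pattern, under the assumption that @remote_function() merely tags methods for the generator to discover; the decorator internals, attribute name, and class name below are guesses for illustration, not the actual twinkle implementation.

```python
# Hypothetical stand-in for twinkle's remote_function decorator: it only
# marks a method so client_tools/client_generator.py can find it.
def remote_function():
    def wrap(fn):
        fn._is_remote_function = True  # attribute name is an assumption
        return fn
    return wrap

class DataLoader:  # stands in for the class in src/twinkle/dataloader/dataloader.py
    @remote_function()
    def skip_consumed_samples(self, consumed_train_samples: int):
        # Server-side logic: fast-forward the sampler past samples that
        # were consumed before the checkpoint was written.
        ...

# The generator scans for tagged methods and emits HTTP wrappers like the
# ones shown in the diff above.
remote_methods = [
    name for name, attr in vars(DataLoader).items()
    if getattr(attr, '_is_remote_function', False)
]
assert remote_methods == ['skip_consumed_samples']
```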
Pull request overview
This PR adds end-to-end checkpoint resumption for Twinkle training, covering model weights plus full training state (optimizer/scheduler/scaler/RNG) and dataloader progress (skipping already-consumed samples), with corresponding client/server API surface, documentation, and examples.
Changes:
- Add resume_from_checkpoint(...) to model interfaces and implementations (Transformers + Megatron), and persist/restore training state (trainer_state.json, optimizer/scheduler/scaler/RNG); a usage sketch follows this list.
- Add dataloader-side consumed-sample tracking plus skip/resume helpers, and sampler changes to ensure skipping happens before DP slicing (illustrated after the file table below).
- Expose the feature via a Twinkle server endpoint and generated client types/wrappers, and add tests/docs/cookbook examples.
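Put together, a client-side resume might look like the sketch below. This assumes the generated wrappers from this PR and a trainer_state.json that records consumed_train_samples; the checkpoint path, JSON key, and client objects (`model`, `dataloader`) are illustrative, not taken verbatim from the PR.

```python
import json

ckpt_dir = 'checkpoints/step_135'  # path is illustrative

# Read back the progress metadata persisted at save time.
with open(f'{ckpt_dir}/trainer_state.json') as f:
    consumed = json.load(f)['consumed_train_samples']  # key name assumed

# Restore optimizer/scheduler/scaler/RNG state on the model workers...
model.resume_from_checkpoint(consumed_train_samples=consumed)

# ...then fast-forward the dataloader so already-seen samples are skipped
# before batches are sliced across data-parallel ranks.
dataloader.skip_consumed_samples(consumed_train_samples=consumed)
```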
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tests/dataloader/test_sampler.py | Adds ordering test for skip vs device-mesh slicing. |
| tests/dataloader/test_dataloader.py | Adds tests for map-style skipping + iterable-dataset warning behavior. |
| src/twinkle_client/types/model.py | Adds request/response models for resume endpoint. |
| src/twinkle_client/types/__init__.py | Re-exports new resume-related types. |
| src/twinkle_client/model/multi_lora_transformers.py | Adds generated client method resume_from_checkpoint. |
| src/twinkle_client/dataloader/dataloader.py | Adds generated client methods for skip/resume/get_state on dataloader processor. |
| src/twinkle/server/model/twinkle_handlers.py | Adds /twinkle/resume_from_checkpoint endpoint to trigger server-side resume. |
| src/twinkle/model/base.py | Extends base model interface with abstract resume_from_checkpoint. |
| src/twinkle/model/transformers/transformers.py | Saves/restores training state (optimizer/scheduler/scaler/RNG) and implements resume_from_checkpoint. |
| src/twinkle/model/transformers/multi_lora_transformers.py | Updates load flow to restore training state (and accept local paths). |
| src/twinkle/model/transformers/strategy/native_fsdp.py | Adds strategy hooks to save/load optimizer checkpoints (incl. wrapped/FSDP2 cases). |
| src/twinkle/model/transformers/strategy/accelerate.py | Adds optimizer checkpoint save/load support for Accelerate (incl. FSDP2 plugin). |
| src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py | Adds default optimizer checkpoint save/load hooks for SP strategy. |
| src/twinkle/model/megatron/megatron.py | Persists trainer_state metadata and adds resume_from_checkpoint. |
| src/twinkle/model/megatron/multi_lora_megatron.py | Adds multi-LoRA Megatron resume: per-rank optimizer + RNG + trainer_state. |
| src/twinkle/dataloader/dataloader.py | Adds consumed-sample tracking, skip/resume API, and rebuildable sampler stack. |
| src/twinkle/dataloader/retry_sampler.py | Adds deterministic seeding + skip-aware emission behavior. |
| src/twinkle/dataloader/device_mesh_sampler.py | Adds skip-before-slice support at the batch-sampler layer. |
| docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md | Updates client-server example to use resume flow and persist consumed samples. |
| docs/source_en/Usage Guide/Quick-Start.md | Adds resume-from-checkpoint guide and notes for client-server training. |
| docs/source_en/Components/Model/TransformersModel.md | Documents save/resume semantics and dataloader progress tracking. |
| cookbook/transformers/fsdp2.py | Updates example to demonstrate resume + dataloader skipping + checkpoint metadata. |
| cookbook/megatron/tp_resume.py | New example demonstrating Megatron TP/PP/DP resume flow. |
| cookbook/client/twinkle/self_host/self_cognition.py | Updates client-server training example to use resume flow and persist consumed samples. |
| client_tools/client_generator.py | Updates generator to include resume_from_checkpoint in generated clients. |
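The sampler rows above are the subtle part: consumed_train_samples is a global count, so skipping must happen before the device mesh slices indices across data-parallel ranks. A minimal illustration of why the ordering matters (not the twinkle implementation; the function below is a simplified stand-in for the sampler stack):

```python
def rank_indices(dataset_len, dp_rank, dp_size, consumed_train_samples):
    indices = list(range(dataset_len))
    # 1) Skip globally first: drop the samples all ranks have jointly seen.
    indices = indices[consumed_train_samples:]
    # 2) Then slice per rank: each DP rank takes its interleaved shard.
    return indices[dp_rank::dp_size]

# Reversing the order would treat the global count as a per-rank count:
# each of dp_size ranks would drop consumed_train_samples from its own
# shard, skipping dp_size times too many samples and misaligning the epoch.
```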
PR type
PR information
Implements restoration of the full training state, including the optimizer, scheduler, RNG configuration, and dataset skipping.