diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 15bfead6..97b20056 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -449,6 +449,7 @@ def generate_models(): GetStateDictResponse, GetTrainConfigsResponse, SaveResponse, + TrainingProgressResponse, ) @@ -618,6 +619,15 @@ def load(self, name: str, **kwargs) -> None: ) response.raise_for_status() + def resume_from_checkpoint(self, name: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]: + response = http_post( + url=f'{self.server_url}/resume_from_checkpoint', + json_data={'name': name, 'adapter_name': self.adapter_name, + 'resume_only_model': resume_only_model, **kwargs} + ) + response.raise_for_status() + return TrainingProgressResponse(**response.json()).result + def apply_patch(self, patch_cls: str, **kwargs) -> None: """Apply a patch to the model.""" response = http_post( diff --git a/cookbook/client/twinkle/self_host/self_cognition.py b/cookbook/client/twinkle/self_host/self_cognition.py index c4e14d30..5d7fa666 100644 --- a/cookbook/client/twinkle/self_host/self_cognition.py +++ b/cookbook/client/twinkle/self_host/self_cognition.py @@ -99,16 +99,19 @@ def train(): # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint + start_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] # Step 7: Run the training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for cur_step, batch in enumerate(dataloader, start=start_step + 1): # Forward pass + backward pass (computes gradients) model.forward_backward(inputs=batch) @@ -125,13 +128,17 @@ def train(): # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) - if step % 2 == 0: + if cur_step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 9: Upload the checkpoint to ModelScope Hub diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py index 4ea2e13e..650cf67b 100644 --- a/cookbook/megatron/tp.py +++ b/cookbook/megatron/tp.py @@ -1,81 +1,113 @@ -import os +from pathlib import Path + from peft import LoraConfig from tqdm import tqdm import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle import DeviceMesh, get_device_placement, get_logger from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.model import MegatronModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, tp=pp=dp=2 -device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2) -# use torchrun mode 
-twinkle.initialize(mode='local', global_device_mesh=device_mesh) logger = get_logger() +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' +DP_SIZE = 2 +TP_SIZE = 2 +PP_SIZE = 2 +BATCH_SIZE = 16 +LEARNING_RATE = 1e-4 +LOG_INTERVAL = 5 +EVAL_INTERVAL = 20 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 + +OUTPUT_DIR = './output/megatron_tp' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +device_mesh = DeviceMesh.from_sizes(dp_size=DP_SIZE, tp_size=TP_SIZE, pp_size=PP_SIZE) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=16) - for step, batch in tqdm(enumerate(dataloader)): + return dataset + + +def save_checkpoint(model: MegatronModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): model.forward_only(inputs=batch) - metrics = model.calculate_metric(is_training=False) - return metrics + return model.calculate_metric(is_training=False) def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() - # Global batch size = 1, dp_size = 1 - dataloader = DataLoader(dataset=dataset, batch_size=16) - # Use a MegatronModel - model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B') + dataset = build_dataset(TRAIN_SAMPLES) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) + + model = MegatronModel(model_id=MODEL_ID) lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') - # Add a lora to model, with name `default` # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='default', lr=1e-4) - # Add LRScheduler for lora `default` + model.add_adapter_to_model(ADAPTER_NAME, lora_config) + model.set_optimizer(optimizer_cls='default', lr=LEARNING_RATE) model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader)) + + start_step = 0 + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), 
resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] + logger.info(get_device_placement()) - # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 - # lora: 10G * 8 - # full: 40G * 8 - for step, batch in enumerate(dataloader): - # Do forward and backward + + best_loss = float('inf') + + for step, batch in enumerate(dataloader, start=start_step): model.forward_backward(inputs=batch) - # Step model.clip_grad_and_step() - if step % 5 == 0: - # Print metric + if step % LOG_INTERVAL == 0: metric = model.calculate_metric(is_training=True) logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 20 == 0: - metrics = eval(model) + if step > 0 and step % EVAL_INTERVAL == 0: + metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') metrics['step'] = step - if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') - loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{step}', dataloader) + best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) if __name__ == '__main__': diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index 7b6bd2a8..450906c5 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -1,3 +1,5 @@ +from pathlib import Path + from peft import LoraConfig from tqdm import tqdm @@ -8,77 +10,116 @@ from twinkle.model import TransformersModel from twinkle.preprocessor import SelfCognitionProcessor -# Construct a device_mesh, fsdp_size=2, dp=4 -device_mesh = DeviceMesh.from_sizes(fsdp_size=2, dp_size=4) +logger = get_logger() + +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' +DATASET_ID = 'ms://swift/self-cognition' +TEMPLATE_NAME = 'Qwen3_5Template' +MODEL_NAME = 'twinkle大模型' +MODEL_AUTHOR = 'ModelScope社区' +FSDP_SIZE = 2 +DP_SIZE = 4 +BATCH_SIZE = 8 +LEARNING_RATE = 1e-4 +GRADIENT_ACCUMULATION_STEPS = 2 +LOG_INTERVAL = 20 +EVAL_INTERVAL = 40 +EVAL_SAMPLES = 100 +TRAIN_SAMPLES = 1000 + +OUTPUT_DIR = './output/fsdp2' +RESUME_FROM_CHECKPOINT = None +RESUME_ONLY_MODEL = False +IGNORE_DATA_SKIP = False +ADAPTER_NAME = 'default' + +# Construct a device_mesh +device_mesh = DeviceMesh.from_sizes(fsdp_size=FSDP_SIZE, dp_size=DP_SIZE) # use torchrun mode twinkle.initialize(mode='local', global_device_mesh=device_mesh) -logger = get_logger() - -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) +def build_dataset(num_samples: int) -> Dataset: + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(num_samples))) + dataset.set_template(TEMPLATE_NAME, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor(MODEL_NAME, MODEL_AUTHOR)) dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=8) - for step, batch in tqdm(enumerate(dataloader)): + return dataset + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + model.save( + checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + 
consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def evaluate(model): + dataloader = DataLoader(dataset=build_dataset(EVAL_SAMPLES), batch_size=BATCH_SIZE) + for batch in tqdm(dataloader): model.forward_only(inputs=batch) model.calculate_loss() - metrics = model.calculate_metric(is_training=False) - return metrics + return model.calculate_metric(is_training=False) def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B') - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() + dataset = build_dataset(TRAIN_SAMPLES) # Global batch size = 8, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model = TransformersModel(model_id=MODEL_ID) model.model._no_split_modules = {'Qwen3_5DecoderLayer'} lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) + model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS) # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) + model.set_optimizer(optimizer_cls='AdamW', lr=LEARNING_RATE) # Add LRScheduler for lora `default` model.set_lr_scheduler( scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + logger.info(get_device_placement()) # Print the training config logger.info(model.get_train_configs()) logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 + optimizer_group = model.optimizer_group[ADAPTER_NAME] + best_loss = float('inf') # lora: 8G * 8 # full: 18G * 8 - for step, batch in enumerate(dataloader): + for batch in dataloader: # Do forward and backward model.forward_backward(inputs=batch) # Step model.clip_grad_and_step() - if step % 20 == 0: + cur_step = optimizer_group.cur_step + if cur_step % LOG_INTERVAL == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: - metrics = eval(model) + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + if cur_step > 0 and cur_step % EVAL_INTERVAL == 0: + metrics = evaluate(model) logger.info(f'Eval metric: {metrics}') - metrics['step'] = step - if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') - loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') + metrics['step'] = cur_step + current_loss = float(metrics['loss']) + if current_loss < best_loss: + save_checkpoint(model, f'checkpoint-{cur_step}', dataloader) 
+ best_loss = current_loss + save_checkpoint(model, 'last-checkpoint', dataloader) if __name__ == '__main__': diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md index f10b0351..a9005a55 100644 --- a/docs/source_en/Components/Model/TransformersModel.md +++ b/docs/source_en/Components/Model/TransformersModel.md @@ -50,3 +50,17 @@ for data in dataloader: model.forward_backward(...) model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## Checkpoint and Resume + +`TransformersModel.save()` can save either weights only or a resumable training checkpoint. + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` saves weights together with optimizer, scheduler, scaler, RNG, and `trainer_state.json`. +- `model.resume_from_checkpoint(checkpoint_dir)` restores full training state (weights, optimizer, scheduler, scaler, RNG) and returns `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`. +- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` loads weights only and returns progress metadata without restoring optimizer state. +- `dataloader.resume_from_checkpoint(consumed_train_samples)` skips already-consumed samples. +- `dataloader.get_state()` returns `{'consumed_train_samples': int}` — the dataloader automatically tracks consumed samples, so you don't need to maintain a counter manually. + +For full-parameter training, restore model weights by constructing `TransformersModel` with the checkpoint path as `model_id`, for example `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`, and then call `resume_from_checkpoint(...)` to restore optimizer state and training progress. + +For end-to-end resume logic, including dataloader skipping, refer to `cookbook/transformers/fsdp2.py`. diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 906f0db0..185b91ac 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -230,6 +230,64 @@ When running, you need to launch training like this: CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### Resume from Checkpoint + +The training loops above can be extended to support checkpoint resumption. For a complete example, refer to `cookbook/transformers/fsdp2.py`. + +**Saving a Checkpoint** + +```python +model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name=ADAPTER_NAME, + save_optimizer=True, # Store optimizer state + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], # Persist training progress +) +``` + +> `DataLoader` automatically tracks consumed samples internally — call `dataloader.get_state()` to retrieve the current count. 
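+
+A sketch of where these saves typically sit in the training loop, mirroring `cookbook/transformers/fsdp2.py` (`evaluate`, `EVAL_INTERVAL`, and `ADAPTER_NAME` are illustrative names taken from that example):
+
+```python
+best_loss = float('inf')
+for step, batch in enumerate(dataloader):
+    model.forward_backward(inputs=batch)
+    model.clip_grad_and_step()
+    if step > 0 and step % EVAL_INTERVAL == 0:
+        metrics = evaluate(model)
+        current_loss = float(metrics['loss'])
+        if current_loss < best_loss:
+            # Keep the best checkpoint so far, including optimizer state and progress
+            model.save(f'checkpoint-{step}', output_dir='./output/fsdp2', adapter_name=ADAPTER_NAME,
+                       save_optimizer=True,
+                       consumed_train_samples=dataloader.get_state()['consumed_train_samples'])
+            best_loss = current_loss
+# Always keep a rolling checkpoint of the final state for resumption
+model.save('last-checkpoint', output_dir='./output/fsdp2', adapter_name=ADAPTER_NAME,
+           save_optimizer=True,
+           consumed_train_samples=dataloader.get_state()['consumed_train_samples'])
+```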
+
+**Resuming Training**
+
+```python
+from pathlib import Path
+
+RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint'
+RESUME_ONLY_MODEL = False  # True: weights only, skip optimizer/scheduler restoration
+IGNORE_DATA_SKIP = False  # True: do not skip consumed samples from trainer_state.json
+
+if RESUME_FROM_CHECKPOINT:
+    checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve())
+    progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL)
+    if not IGNORE_DATA_SKIP:
+        dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
+```
+
+How the two flags combine:
+
+| `RESUME_ONLY_MODEL` | `IGNORE_DATA_SKIP` | Effect |
+|---|---|---|
+| `False` (default) | `False` (default) | Full resume: restore weights + optimizer + scheduler + RNG, skip consumed data |
+| `True` | `False` | Weights only, but still skip consumed data (restart optimization from scratch) |
+| `True` | `True` | Weights only, restart the dataset from the beginning |
+
+**LoRA / Adapter vs Full-Parameter Training**
+
+The flow above uses LoRA as the default example. For full-parameter training, the only difference is in `TransformersModel` initialization — use the checkpoint path as `model_id` instead of the base model ID:
+
+```python
+# LoRA / adapter: base model loaded from hub, checkpoint contains only adapter weights + training state
+model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B')
+progress = model.resume_from_checkpoint(resume_path)
+
+# Full-parameter: model weights are saved entirely in the checkpoint — use it directly as model_id
+model = TransformersModel(model_id=resume_path)
+progress = model.resume_from_checkpoint(resume_path)
+```
+
+> All subsequent calls to `resume_from_checkpoint` and `dataloader.resume_from_checkpoint` are identical in both cases.
+
 ### Ray Training

 [Ray](https://github.com/ray-project/ray) is a commonly used scheduling middleware framework for multi-machine model training and inference scenarios. It provides additional optimizations for multi-model, multi-device execution and resource management, and supports integration with Kubernetes systems for production deployment. These characteristics make it particularly suitable for complex training scenarios such as RL and GKD.
@@ -413,6 +471,8 @@ python train.py

 A major feature of Twinkle is support for multi-tenant mixed training. Specifically, multiple users can use a single base model for LoRA training, which can greatly reduce server-side deployment costs.

+Checkpoint resumption is also supported in client-server training. The recommended flow is to call `model.resume_from_checkpoint(resume_path)` to restore weights and optimizer state, then call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip consumed data. See [Twinkle-Client](./Server%20and%20Client/Twinkle-Client.md) and [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py).
+
 Suppose we start a service using eight GPUs.
First, we need to start the Ray cluster: ```shell diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md index 1b3a0f6c..af53f5ab 100644 --- a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md @@ -133,29 +133,36 @@ model.set_optimizer('Adam', lr=1e-4) # model.set_lr_scheduler('LinearLR') # Step 5: Resume training (optional) +start_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] # Step 6: Training loop logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for cur_step, batch in enumerate(dataloader, start=start_step + 1): # Forward propagation + backward propagation model.forward_backward(inputs=batch) - # Gradient clipping + optimizer update (equivalent to clip_grad_norm / step / zero_grad / lr_step) + # Gradient clipping + optimizer update (equivalent to calling clip_grad_norm / step / zero_grad / lr_step in sequence) model.clip_grad_and_step() - # Log metrics every 2 steps (aligned with gradient_accumulation_steps) - if step % 2 == 0: + # Print metric every 2 steps (aligned with gradient_accumulation_steps) + if cur_step % 2 == 0: metric = model.calculate_metric(is_training=True) - logger.info(f'Epoch {epoch}, step {step}/{len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: Save checkpoint - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 8: Upload to ModelScope Hub (optional) @@ -168,6 +175,14 @@ for epoch in range(3): # ) ``` +For checkpoint resumption, the recommended client-side flow is: + +1. Query the server for an existing checkpoint path with `client.list_checkpoints(...)` or `client.get_latest_checkpoint_path(...)`. +2. Call `model.resume_from_checkpoint(resume_path)` to restore weights, optimizer, scheduler, RNG, and progress metadata. +3. Call `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` to skip already-consumed samples. + +This matches the end-to-end example in `cookbook/client/twinkle/self_host/self_cognition.py`. 
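+
+A minimal sketch of this flow (the exact signature of `get_latest_checkpoint_path` may differ; it is assumed here to return the checkpoint path, or `None` when no checkpoint exists yet):
+
+```python
+# Step 1: ask the server for an existing checkpoint (assumed to return None if absent)
+resume_path = client.get_latest_checkpoint_path()
+
+start_step = 0
+if resume_path:
+    # Step 2: restore weights, optimizer, scheduler, RNG, and progress metadata
+    progress = model.resume_from_checkpoint(resume_path)
+    # Step 3: skip samples that were already consumed before the interruption
+    dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
+    start_step = progress['cur_step']
+```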
+ ## Differences with Megatron Backend When using the Megatron backend, the main differences in client code: diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 710022bb..3f19af5b 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -231,6 +231,64 @@ if __name__ == '__main__': CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train.py ``` +### 断点续训 + +上面的训练循环可以扩展为支持断点续训。完整示例可直接参考 `cookbook/transformers/fsdp2.py`。 + +**保存检查点** + +```python +model.save( + checkpoint_name, + output_dir='./output/fsdp2', + adapter_name=ADAPTER_NAME, + save_optimizer=True, # 保存优化器状态 + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], # 落盘训练进度 +) +``` + +> `DataLoader` 内部自动追踪已消费样本数,通过 `dataloader.get_state()` 获取。 + +**恢复训练** + +```python +from pathlib import Path + +RESUME_FROM_CHECKPOINT = './output/fsdp2/last-checkpoint' +RESUME_ONLY_MODEL = False # True: 仅恢复权重,不恢复优化器/调度器等训练状态 +IGNORE_DATA_SKIP = False # True: 不从 trainer_state.json 跳过已消费数据 + +if RESUME_FROM_CHECKPOINT: + checkpoint_path = str(Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()) + progress = model.resume_from_checkpoint(checkpoint_path, resume_only_model=RESUME_ONLY_MODEL) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) +``` + +两个开关的组合效果: + +| `RESUME_ONLY_MODEL` | `IGNORE_DATA_SKIP` | 效果 | +|---|---|---| +| `False`(默认) | `False`(默认) | 完整续训:恢复权重 + 优化器 + 调度器 + RNG,并跳过已消费数据 | +| `True` | `False` | 仅恢复权重,但仍跳过已消费数据(适合沿用权重、重新开始优化) | +| `True` | `True` | 仅恢复权重,从数据集开头重新训练 | + +**LoRA / adapter vs 全参训练** + +上述流程默认以 LoRA 为例。全参训练的恢复仅有一处不同——`TransformersModel` 初始化时,`model_id` 需要用 checkpoint 路径替代 base model ID: + +```python +# LoRA / adapter:base model 从 hub 加载,checkpoint 仅含 adapter 权重和训练状态 +model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') +progress = model.resume_from_checkpoint(resume_path) + +# 全参:模型权重已整体保存到 checkpoint,直接将其作为 model_id +model = TransformersModel(model_id=resume_path) +progress = model.resume_from_checkpoint(resume_path) +``` + +> 二者后续的 `resume_from_checkpoint` 及 `dataloader.resume_from_checkpoint` 调用完全一致。 + ### Ray训练 [Ray](https://github.com/ray-project/ray)是多机模型训练和推理场景中常用的调度中间件框架。它针对多模型、多设备的执行和资源管理进行了额外优化, @@ -412,6 +470,7 @@ python train.py ``` ### 远程训练 +client-server 训练场景同样支持断点续训。推荐流程是调用 `model.resume_from_checkpoint(resume_path)` 恢复权重和优化器状态,再调用 `dataloader.resume_from_checkpoint(progress['consumed_train_samples'])` 跳过已消费数据。详细示例可参考 [Twinkle客户端](./服务端和客户端/Twinkle客户端.md) 和 [self_cognition.py](../../../cookbook/client/twinkle/self_host/self_cognition.py)。 Twinkle 的一大特色是支持多租户用户混合训练。具体来说,多个用户可以使用一个基模进行 LoRA 训练,这样可以极大减小服务端部署成本。 diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" index 9548ed35..967479b9 100644 --- 
"a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" @@ -133,16 +133,19 @@ model.set_optimizer('Adam', lr=1e-4) # model.set_lr_scheduler('LinearLR') # Step 5: 恢复训练(可选) +start_step = 0 if resume_path: - logger.info(f'Resuming training from {resume_path}') - model.load(resume_path, load_optimizer=True) + logger.info(f'Resuming from checkpoint {resume_path}') + progress = model.resume_from_checkpoint(resume_path) + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + start_step = progress['cur_step'] # Step 6: 训练循环 logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') - for step, batch in enumerate(dataloader): + for cur_step, batch in enumerate(dataloader, start=start_step + 1): # 前向传播 + 反向传播 model.forward_backward(inputs=batch) @@ -150,12 +153,16 @@ for epoch in range(3): model.clip_grad_and_step() # 每 2 步打印一次指标(与 gradient_accumulation_steps 对齐) - if step % 2 == 0: + if cur_step % 2 == 0: metric = model.calculate_metric(is_training=True) - logger.info(f'Epoch {epoch}, step {step}/{len(dataloader)}, metric: {metric.result}') + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric.result}') # Step 7: 保存检查点 - twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) + twinkle_path = model.save( + name=f'twinkle-epoch-{epoch}', + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) logger.info(f'Saved checkpoint: {twinkle_path}') # Step 8: 上传到 ModelScope Hub(可选) @@ -168,6 +175,14 @@ for epoch in range(3): # ) ``` +Twinkle Client 场景下,推荐的断点续训流程是: + +1. 先通过 `client.list_checkpoints(...)` 或 `client.get_latest_checkpoint_path(...)` 获取已有 checkpoint 路径。 +2. 调用 `model.resume_from_checkpoint(resume_path)` 恢复权重、优化器、调度器、随机数状态和训练进度元数据。 +3. 使用返回结果中的 `consumed_train_samples` 调用 `dataloader.resume_from_checkpoint(...)`,跳过已经训练过的数据。 + +完整示例可直接参考 `cookbook/client/twinkle/self_host/self_cognition.py`。 + ## Megatron 后端的差异 使用 Megatron 后端时,客户端代码的主要差异: diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" index b7b9cf0f..cd0f16ad 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" @@ -50,3 +50,17 @@ for data in dataloader: model.forward_backward(...) 
model.clip_grad_and_step(..., gradient_accumulation_steps=16) ``` + +## 检查点保存与续训 + +`TransformersModel.save()` 既可以只保存权重,也可以保存可续训的训练检查点。 + +- `model.save(name, save_optimizer=True, consumed_train_samples=...)` 保存权重、优化器、调度器、scaler、RNG 状态和 `trainer_state.json`。 +- `model.resume_from_checkpoint(checkpoint_dir)` 恢复完整训练状态(权重、优化器、调度器、scaler、RNG),返回 `{'cur_step', 'consumed_train_samples', 'gradient_accumulation_steps'}`。 +- `model.resume_from_checkpoint(checkpoint_dir, resume_only_model=True)` 仅加载权重并返回进度元数据,不恢复优化器状态。 +- `dataloader.resume_from_checkpoint(consumed_train_samples)` 跳过已消费的样本。 +- `dataloader.get_state()` 返回 `{'consumed_train_samples': int}` — DataLoader 会自动追踪已消费样本数,无需手动维护计数器。 + +对于全参训练,恢复模型权重时需要在创建 `TransformersModel` 时直接把 checkpoint 路径传给 `model_id`,例如 `TransformersModel(model_id='./output/fsdp2/last-checkpoint')`,随后再调用 `resume_from_checkpoint(...)` 恢复优化器和训练进度。 + +如果需要完整的断点续训流程,包括 dataloader 跳过已消费数据的逻辑,建议直接参考 `cookbook/transformers/fsdp2.py`。 diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py index b3ce4f0f..c392d56c 100644 --- a/src/twinkle/dataloader/dataloader.py +++ b/src/twinkle/dataloader/dataloader.py @@ -1,4 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import copy +import os +import warnings from functools import partial from typing import Callable, Optional, Type, Union @@ -51,6 +54,11 @@ def __init__(self, self.dataloader_params['batch_size'] = batch_size self.device_mesh = device_mesh self.processor: Optional[InputProcessor] = None + self._skip_samples = 0 + self._consumed_train_samples = 0 + self._base_batch_sampler = None + self._base_sampler = None + self._retry_sampler_seed = self._resolve_retry_sampler_seed() self._set_work_init_fn() def _set_work_init_fn(self): @@ -60,6 +68,17 @@ def _set_work_init_fn(self): num_workers=num_workers, rank=self.device_mesh.data_rank if self.device_mesh else 0) + @staticmethod + def _resolve_retry_sampler_seed() -> int: + env_seed = os.environ.get('TWINKLE_SEED') + if env_seed is not None: + return int(env_seed) + try: + from twinkle.infra import _seed + return int(_seed) + except Exception: + return 42 + @remote_function() def __len__(self): self._lazy_init_dataloader() @@ -97,7 +116,9 @@ def _lazy_init_dataloader(self): if not isinstance(self.dataset, IterableDataset): self.dataloader.__initialized = False - self._repeat_sample_and_shard() + self._base_batch_sampler = self.dataloader.batch_sampler + self._base_sampler = self.dataloader.sampler + self._rebuild_sampler_stack() self.dataloader.__initialized = True @remote_function() @@ -114,13 +135,57 @@ def __iter__(self): self.batch_size, self.device_mesh, max_retries=self.max_retries) - return _iter - - def _repeat_sample_and_shard(self): - if self.dataloader.batch_sampler is not None and hasattr(self.dataloader.batch_sampler, 'sampler'): - self.dataloader.batch_sampler.sampler = RetrySampler( - self.dataloader.batch_sampler.sampler, self.dataset, max_retries=self.max_retries) - self.dataloader.batch_sampler = DeviceMeshSampler(self.dataloader.batch_sampler, self.device_mesh, - self.min_batch_size) - elif self.dataloader.sampler is not None: - self.dataloader.sampler = RetrySampler(self.dataloader.sampler, self.dataset, max_retries=self.max_retries) + return self._tracking_iter(_iter) + + def _tracking_iter(self, inner): + for batch in inner: + self._consumed_train_samples += self.batch_size + yield batch + + @remote_function() + def skip_consumed_samples(self, consumed_train_samples: int) -> None: + from torch.utils.data 
import IterableDataset + + if isinstance(self.dataset, IterableDataset): + warnings.warn('IterableDataset does not support consumed-data skipping; continuing without skipping.') + self._skip_samples = 0 + return + + self._skip_samples = max(int(consumed_train_samples), 0) + self._consumed_train_samples = self._skip_samples + if self.dataloader is not None: + self.dataloader.__initialized = False + self._rebuild_sampler_stack() + self.dataloader.__initialized = True + + @remote_function() + def resume_from_checkpoint(self, consumed_train_samples, **kwargs): + self.skip_consumed_samples(consumed_train_samples) + + @remote_function() + def get_state(self) -> dict: + return {'consumed_train_samples': self._consumed_train_samples} + + def _rebuild_sampler_stack(self): + if self._base_batch_sampler is not None and hasattr(self._base_batch_sampler, 'sampler'): + batch_sampler = copy.copy(self._base_batch_sampler) + batch_sampler.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + seed=self._retry_sampler_seed, + ) + self.dataloader.batch_sampler = DeviceMeshSampler( + batch_sampler, + self.device_mesh, + self.min_batch_size, + skip_samples=self._skip_samples, + ) + elif self._base_sampler is not None: + self.dataloader.sampler = RetrySampler( + self._base_sampler, + self.dataset, + max_retries=self.max_retries, + skip_samples=self._skip_samples, + seed=self._retry_sampler_seed, + ) diff --git a/src/twinkle/dataloader/device_mesh_sampler.py b/src/twinkle/dataloader/device_mesh_sampler.py index 955b85cd..1f649de3 100644 --- a/src/twinkle/dataloader/device_mesh_sampler.py +++ b/src/twinkle/dataloader/device_mesh_sampler.py @@ -12,15 +12,28 @@ class DeviceMeshSampler(BatchSampler): device_mesh: The device mesh. """ - def __init__(self, original_sampler: BatchSampler, device_mesh: DeviceMesh, min_batch_size: int = None): + def __init__(self, + original_sampler: BatchSampler, + device_mesh: DeviceMesh, + min_batch_size: int = None, + skip_samples: int = 0): self.original_sampler = original_sampler self.device_mesh = device_mesh self.min_batch_size = min_batch_size + self.skip_samples = skip_samples if self.min_batch_size is None and self.device_mesh is not None: self.min_batch_size = self.device_mesh.data_world_size def __iter__(self): + skipped = 0 for batch in self.original_sampler: + if skipped < self.skip_samples: + if skipped + len(batch) <= self.skip_samples: + skipped += len(batch) + continue + batch = batch[self.skip_samples - skipped:] + skipped = self.skip_samples + if not self.device_mesh: yield batch else: diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py index 6fe84f04..4d8c92e0 100644 --- a/src/twinkle/dataloader/retry_sampler.py +++ b/src/twinkle/dataloader/retry_sampler.py @@ -14,13 +14,22 @@ class RetrySampler(Sampler): max_retries: The maximum number of retries. 
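+        skip_samples: The number of leading valid samples to skip, used when resuming from a checkpoint.
+        seed: The seed for the fallback random permutation used to top up the epoch after failed samples.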
""" - def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20): + def __init__(self, + original_sampler: Sampler, + dataset: Dataset, + max_retries=20, + skip_samples: int = 0, + seed: int = 42): self.original_sampler = original_sampler self.dataset = dataset self.max_retries = max_retries + self.skip_samples = skip_samples + self.seed = int(seed) def __iter__(self): - total = 0 + emitted = 0 + seen_valid = 0 + target_total = max(len(self.dataset) - self.skip_samples, 0) for idx in self.original_sampler: for _ in range(self.max_retries): try: @@ -29,23 +38,25 @@ def __iter__(self): data = self.dataset[idx] if not data: continue + seen_valid += 1 + if seen_valid <= self.skip_samples: + break yield idx - total += 1 + emitted += 1 break except Exception: # noqa import traceback traceback.print_exc() continue else: - raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.') + raise RuntimeError(f'Max retries exceeded: {self.max_retries}, no valid data found.') - origin_dataset_len = len(self.dataset) - if total >= origin_dataset_len: + if emitted >= target_total: return - for idx in np.random.RandomState().permutation(len(self.dataset)).tolist(): - if total >= origin_dataset_len: - raise StopIteration + for idx in np.random.RandomState(self.seed).permutation(len(self.dataset)).tolist(): + if emitted >= target_total: + return for _ in range(self.max_retries): try: # Skip None values and raises @@ -53,7 +64,7 @@ def __iter__(self): if not data: continue yield idx - total += 1 + emitted += 1 break except Exception: # noqa import traceback diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 1a08b23a..4df53e99 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -87,6 +87,14 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs) -> None: def get_state_dict(self, **kwargs) -> Dict[str, Any]: ... + @abstractmethod + def resume_from_checkpoint(self, + checkpoint_dir: str, + *, + resume_only_model: bool = False, + **kwargs) -> Dict[str, Any]: + ... + @abstractmethod def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None: ... diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 68b6f39a..f61e66ab 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -828,6 +828,17 @@ def save(self, optimizer_config=optimizer_config, **kwargs, ) + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + with open(state_path, 'w') as f: + json.dump(trainer_state, f, indent=2) # Final synchronization to ensure all ranks complete save. if dist.is_initialized(): @@ -849,19 +860,18 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): ``no_load_rng``, etc.). 
""" resume = kwargs.pop('load_optimizer', False) - if output_dir is None and not resume: - if os.path.exists(name): - checkpoint_dir = name - else: - # load from hub - token = kwargs.pop('token', None) - checkpoint_dir = HubOperation.download_model(name, token=token) - else: - if output_dir is None: - output_dir = 'output' + if output_dir is not None: checkpoint_dir = os.path.join(output_dir, name) + elif os.path.exists(name): + checkpoint_dir = name + elif not resume: + # load from hub + token = kwargs.pop('token', None) + checkpoint_dir = HubOperation.download_model(name, token=token) + else: + checkpoint_dir = os.path.join('output', name) - adapter_name = kwargs.get('adapter_name', self._get_default_group()) + adapter_name = kwargs.pop('adapter_name', self._get_default_group()) if resume: self._load_mcore_optimizer( @@ -881,6 +891,22 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): if dist.is_initialized(): dist.barrier() + @remote_function(dispatch='all') + def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.pop('adapter_name', self._get_default_group()) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path) as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, adapter_name=adapter_name, **kwargs) + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } + @staticmethod def _get_rng_state() -> 'ShardedObject': from megatron.core import parallel_state as mpu @@ -1091,7 +1117,8 @@ def _load_mcore_optimizer( # Restore optimizer + LR scheduler. if not no_load_optim and optimizer is not None and 'optimizer' in state_dict: - optimizer.load_state_dict(state_dict['optimizer']) + with torch.no_grad(): + optimizer.load_state_dict(state_dict['optimizer']) if (opt_param_scheduler is not None and 'opt_param_scheduler' in state_dict): opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'], ) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 04c7cba5..140a1373 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -1,6 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import json +import numpy as np import os +import random import re +import torch import torch.distributed as dist import torch.nn as nn from contextlib import contextmanager @@ -203,10 +207,71 @@ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], self._check_adapter_valid(kwargs.get('adapter_name')) super().set_lr_scheduler(scheduler_cls, **kwargs) + @staticmethod + def _rank_local_optimizer_path(checkpoint_dir: str) -> str: + rank = dist.get_rank() if dist.is_initialized() else 0 + return os.path.join(checkpoint_dir, f'optimizer_rank_{rank}.pt') + + @staticmethod + def _save_local_training_rng_state(): + from megatron.core import tensor_parallel + + rng_state = { + 'random_rng_state': random.getstate(), + 'np_rng_state': np.random.get_state(), + 'torch_rng_state': torch.get_rng_state(), + } + if torch.cuda.is_available(): + rng_state['cuda_rng_state'] = torch.cuda.get_rng_state() + rng_state['rng_tracker_states'] = tensor_parallel.get_cuda_rng_tracker().get_states() + return rng_state + + @staticmethod + def _load_local_training_rng_state(rng_state): + from megatron.core import tensor_parallel + + random.setstate(rng_state['random_rng_state']) + np.random.set_state(rng_state['np_rng_state']) + torch.set_rng_state(rng_state['torch_rng_state']) + if 'cuda_rng_state' in rng_state and torch.cuda.is_available(): + torch.cuda.set_rng_state(rng_state['cuda_rng_state']) + tensor_parallel.get_cuda_rng_tracker().set_states(rng_state['rng_tracker_states']) + + def _save_multi_lora_optimizer(self, checkpoint_dir: str, optimizer_config, **kwargs): + os.makedirs(checkpoint_dir, exist_ok=True) + state_dict = { + 'checkpoint_version': 1, + 'iteration': optimizer_config.cur_step, + 'rng_state': self._save_local_training_rng_state(), + } + if optimizer_config.optimizer is not None: + state_dict['optimizer'] = optimizer_config.optimizer.state_dict() + if optimizer_config.lr_scheduler is not None: + state_dict['opt_param_scheduler'] = optimizer_config.lr_scheduler.state_dict() + + torch.save(state_dict, self._rank_local_optimizer_path(checkpoint_dir)) + + def _load_multi_lora_optimizer(self, checkpoint_dir: str, adapter_name: str = '', **kwargs): + no_load_optim = kwargs.pop('no_load_optim', False) + no_load_rng = kwargs.pop('no_load_rng', False) + optimizer_config = self.optimizer_group.get(adapter_name) + state_dict = torch.load(self._rank_local_optimizer_path(checkpoint_dir), map_location='cpu', weights_only=False) + + if not no_load_optim and optimizer_config is not None: + if optimizer_config.optimizer is not None and 'optimizer' in state_dict: + optimizer_config.optimizer.load_state_dict(state_dict['optimizer']) + if optimizer_config.lr_scheduler is not None and 'opt_param_scheduler' in state_dict: + optimizer_config.lr_scheduler.load_state_dict(state_dict['opt_param_scheduler']) + if not no_load_rng and 'rng_state' in state_dict: + self._load_local_training_rng_state(state_dict['rng_state']) + if optimizer_config is not None and 'iteration' in state_dict: + optimizer_config.cur_step = state_dict['iteration'] + @remote_function(dispatch='all', collect='first', sync=True) def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): - self._check_adapter_valid(kwargs.get('adapter_name')) - optimizer_config = self.optimizer_group[kwargs.get('adapter_name')] + adapter_name = kwargs.pop('adapter_name', None) + self._check_adapter_valid(adapter_name) + optimizer_config = self.optimizer_group[adapter_name] if optimizer_config.cur_step % interval != 0: return @@ 
-215,8 +280,9 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): if output_dir is None: output_dir = 'output' checkpoint_dir = os.path.join(output_dir, name) + save_optimizer = kwargs.pop('save_optimizer', False) - with self.multi_adapter.save_context(kwargs.get('adapter_name')) as real_adapter_name: + with self.multi_adapter.save_context(adapter_name) as real_adapter_name: save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron' # Use partial to bind adapter_name to save_lora_converter lora_converter = partial(self.multi_adapter.save_lora_converter, adapter_name=real_adapter_name) @@ -228,7 +294,25 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): else: self._save_megatron_format(checkpoint_dir, real_adapter_name, lora_converter=lora_converter) - self._save_tokenizer(checkpoint_dir, adapter_name=kwargs.get('adapter_name')) + self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) + if save_optimizer: + with self.optimizer_context(real_adapter_name): + self._save_multi_lora_optimizer( + checkpoint_dir, + optimizer_config=optimizer_config, + **kwargs, + ) + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + } + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + with open(os.path.join(checkpoint_dir, 'trainer_state.json'), 'w') as f: + json.dump(trainer_state, f, indent=2) # Final synchronization to ensure all ranks complete save if dist.is_initialized(): dist.barrier() @@ -237,25 +321,55 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): @remote_function(dispatch='all') def load(self, name: str, output_dir: Optional[str] = None, **kwargs): + load_optimizer = kwargs.pop('load_optimizer', False) + adapter_name = kwargs.pop('adapter_name', None) if output_dir is None: - # load from hub - token = kwargs.pop('token', None) - checkpoint_dir = HubOperation.download_model(name, token=token) + if os.path.exists(name): + checkpoint_dir = name + else: + token = kwargs.pop('token', None) + checkpoint_dir = HubOperation.download_model(name, token=token) else: checkpoint_dir = os.path.join(output_dir, name) bridge = self.strategy.bridge - with self.multi_adapter.save_context(kwargs.get('adapter_name')) as adapter_name: + with self.multi_adapter.save_context(adapter_name) as real_adapter_name: model = self.strategy.unwrap_model(self.model) bridge.load_weights( model, checkpoint_dir, peft_format=True, - adapter_name=adapter_name, + adapter_name=real_adapter_name, converter=self.multi_adapter.load_lora_converter) + if load_optimizer: + with self.optimizer_context(real_adapter_name): + self._load_multi_lora_optimizer(checkpoint_dir, adapter_name=adapter_name, **kwargs) + if dist.is_initialized(): dist.barrier() + @remote_function(dispatch='all', collect='first', sync=True) + def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs): + adapter_name = kwargs.pop('adapter_name', None) + self._check_adapter_valid(adapter_name) + + trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json') + with open(trainer_state_path) as f: + trainer_state = json.load(f) + + self.load(checkpoint_dir, load_optimizer=not resume_only_model, adapter_name=adapter_name, **kwargs) + + optimizer_config = 
self.optimizer_group.get(adapter_name) + if not resume_only_model and optimizer_config is not None: + optimizer_config.cur_step = trainer_state['cur_step'] + optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps'] + + return { + 'cur_step': trainer_state['cur_step'], + 'consumed_train_samples': trainer_state['consumed_train_samples'], + 'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'], + } + @remote_function(execute='first') def get_state_dict(self, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index c8f39bec..cedd7af6 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -235,15 +235,17 @@ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): return super().save(name, output_dir, interval, **kwargs) @remote_function() - def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs): + def load(self, name: str, output_dir: Optional[str] = None, **kwargs): adapter_name = kwargs.get('adapter_name') self._check_adapter_valid(adapter_name) with self.multi_adapter.save_context(kwargs.get('adapter_name')): load_optimizer = kwargs.get('load_optimizer', False) if output_dir is None: - # load from hub - token = kwargs.pop('token', None) - checkpoint_dir = HubOperation.download_model(name, token=token) + if os.path.exists(name): + checkpoint_dir = name + else: + token = kwargs.pop('token', None) + checkpoint_dir = HubOperation.download_model(name, token=token) else: checkpoint_dir = os.path.join(output_dir, name) model = self.strategy.unwrap_model(self.model) @@ -253,7 +255,7 @@ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **k self.multi_adapter.set_state_dict(adapter_name, adapter_weights) if load_optimizer: - self._load_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._restore_training_state(checkpoint_dir, adapter_name=adapter_name) @remote_function() def set_grad_scaler(self, **kwargs): diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 89b497e2..8d31291a 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -124,6 +124,57 @@ def wrap_model(self, model, *args): def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) + def _get_fsdp_plugin(self): + state = self.accelerator.state + return state.fsdp_plugin if hasattr(state, 'fsdp_plugin') else None + + def _prepare_fsdp2_sd_options(self): + fsdp_plugin = self._get_fsdp_plugin() + if fsdp_plugin is None or fsdp_plugin.fsdp_version != 2: + return None + + from torch.distributed.checkpoint.state_dict import StateDictOptions + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + return StateDictOptions( + full_state_dict=fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT, + cpu_offload=getattr(fsdp_plugin.state_dict_config, 'offload_to_cpu', False), + broadcast_from_rank0=getattr(fsdp_plugin.state_dict_config, 'rank0_only', False), + ) + + def needs_wrapped_optimizer_state(self) -> bool: + fsdp_plugin = self._get_fsdp_plugin() + return fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2 + + def save_optimizer_checkpoint(self, model, 
optimizer, output_path: str): + import torch + fsdp_plugin = self._get_fsdp_plugin() + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict(model, optimizer, options=self._prepare_fsdp2_sd_options()) + if self.accelerator.process_index == 0: + torch.save(optim_state, output_path) + return + + if self.accelerator.process_index == 0: + torch.save(optimizer.state_dict(), output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + import torch + fsdp_plugin = self._get_fsdp_plugin() + if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2: + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + + optim_state = None + rank0_only = getattr(fsdp_plugin.optim_state_dict_config, 'rank0_only', False) + if self.accelerator.process_index == 0 or not rank0_only: + optim_state = torch.load(input_path, weights_only=True) + set_optimizer_state_dict(model, optimizer, optim_state, options=self._prepare_fsdp2_sd_options()) + return + + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) + def get_full_state_dict(self, model) -> dict: """Collect full state dict.""" from twinkle.utils import torch_util diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 48a1da85..ad675006 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -151,6 +151,53 @@ def wrap_model(self, model, optimizer=None): return model, optimizer + def _prepare_optimizer_state_dict_options(self, *, for_load: bool): + from torch.distributed.checkpoint.state_dict import StateDictOptions + + return StateDictOptions( + full_state_dict=True, + cpu_offload=not for_load, + broadcast_from_rank0=for_load, + ) + + def needs_wrapped_optimizer_state(self) -> bool: + return self.device_mesh is not None + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + import torch + if not self.needs_wrapped_optimizer_state(): + if Platform.is_master(): + torch.save(optimizer.state_dict(), output_path) + return + + from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict + + optim_state = get_optimizer_state_dict( + model, + optimizer, + options=self._prepare_optimizer_state_dict_options(for_load=False), + ) + if Platform.is_master(): + torch.save(optim_state, output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + import torch + if not self.needs_wrapped_optimizer_state(): + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) + return + + from torch.distributed.checkpoint.state_dict import set_optimizer_state_dict + + optim_state = {} + if Platform.is_master(): + optim_state = torch.load(input_path, map_location='cpu', weights_only=True) + set_optimizer_state_dict( + model, + optimizer, + optim_state, + options=self._prepare_optimizer_state_dict_options(for_load=True), + ) + def get_ep_clip_kwargs(self, model) -> Dict[str, Any]: """Return EP-aware kwargs for normalize_and_clip_grad_norm.""" model = self.unwrap_model(model) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 3c6f1a4e..e29d416a 100644 --- 
a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -949,3 +949,14 @@ def wrap_model(self, model, optimizer=None): def unwrap_model(self, model): return model + + def needs_wrapped_optimizer_state(self) -> bool: + return False + + def save_optimizer_checkpoint(self, model, optimizer, output_path: str): + from twinkle.utils.platforms import Platform + if Platform.is_master(): + torch.save(optimizer.state_dict(), output_path) + + def load_optimizer_checkpoint(self, model, optimizer, input_path: str): + optimizer.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False)) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 281e3f3e..b3d1b69e 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -2,7 +2,9 @@ import asyncio import contextlib import json +import numpy as np import os +import random import re import threading import torch @@ -868,22 +870,56 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name) if kwargs.get('save_optimizer', False): - self._save_optimizer(checkpoint_dir, adapter_name=adapter_name) + self._save_training_state( + checkpoint_dir, + adapter_name=adapter_name, + consumed_train_samples=kwargs.get('consumed_train_samples', 0), + ) return checkpoint_dir def _save_optimizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] + optimizer = optimizer_config.optimizer + lr_scheduler = optimizer_config.lr_scheduler + if optimizer is not None: + optimizer_path = os.path.join(output_dir, 'optimizer.pt') + self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path) if Platform.is_master(): - optimizer = optimizer_config.optimizer - lr_scheduler = optimizer_config.lr_scheduler - if optimizer is not None: - torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt')) if lr_scheduler is not None: torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt')) + def _save_training_state(self, output_dir, **kwargs): + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + self._save_optimizer(output_dir, adapter_name=adapter_name) + + optimizer_config = self.optimizer_group[adapter_name] + + if not Platform.is_master(): + return + + trainer_state = { + 'checkpoint_version': 1, + 'cur_step': optimizer_config.cur_step, + 'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps, + 'consumed_train_samples': kwargs.get('consumed_train_samples', 0), + } + with open(os.path.join(output_dir, 'trainer_state.json'), 'w', encoding='utf-8') as f: + json.dump(trainer_state, f) + + if optimizer_config.scaler is not None: + torch.save( + { + 'scaler_state_dict': optimizer_config.scaler.state_dict(), + 'scaler_has_nan': optimizer_config.scaler_has_nan, + }, + os.path.join(output_dir, 'scaler.pt'), + ) + + torch.save(self._get_training_rng_state(), os.path.join(output_dir, 'rng_state.pt')) + def _save_tokenizer(self, output_dir, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] @@ -927,10 +963,11 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): model_sd = model.state_dict() 
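+        # Resolve the adapter-scoped key (model_key) only to look up the target parameter;
+        # converted_weights intentionally keeps the original checkpoint key unchanged.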
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 281e3f3e..b3d1b69e 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -2,7 +2,9 @@
 import asyncio
 import contextlib
 import json
+import numpy as np
 import os
+import random
 import re
 import threading
 import torch
@@ -868,22 +870,56 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
         self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)
 
         if kwargs.get('save_optimizer', False):
-            self._save_optimizer(checkpoint_dir, adapter_name=adapter_name)
+            self._save_training_state(
+                checkpoint_dir,
+                adapter_name=adapter_name,
+                consumed_train_samples=kwargs.get('consumed_train_samples', 0),
+            )
 
         return checkpoint_dir
 
     def _save_optimizer(self, output_dir, **kwargs):
         adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
         optimizer_config = self.optimizer_group[adapter_name]
+        optimizer = optimizer_config.optimizer
+        lr_scheduler = optimizer_config.lr_scheduler
+        if optimizer is not None:
+            optimizer_path = os.path.join(output_dir, 'optimizer.pt')
+            self.strategy.save_optimizer_checkpoint(self.model, optimizer, optimizer_path)
         if Platform.is_master():
-            optimizer = optimizer_config.optimizer
-            lr_scheduler = optimizer_config.lr_scheduler
-            if optimizer is not None:
-                torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
             if lr_scheduler is not None:
                 torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
 
+    def _save_training_state(self, output_dir, **kwargs):
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        self._save_optimizer(output_dir, adapter_name=adapter_name)
+
+        optimizer_config = self.optimizer_group[adapter_name]
+
+        if not Platform.is_master():
+            return
+
+        trainer_state = {
+            'checkpoint_version': 1,
+            'cur_step': optimizer_config.cur_step,
+            'gradient_accumulation_steps': optimizer_config.gradient_accumulation_steps,
+            'consumed_train_samples': kwargs.get('consumed_train_samples', 0),
+        }
+        with open(os.path.join(output_dir, 'trainer_state.json'), 'w', encoding='utf-8') as f:
+            json.dump(trainer_state, f)
+
+        if optimizer_config.scaler is not None:
+            torch.save(
+                {
+                    'scaler_state_dict': optimizer_config.scaler.state_dict(),
+                    'scaler_has_nan': optimizer_config.scaler_has_nan,
+                },
+                os.path.join(output_dir, 'scaler.pt'),
+            )
+
+        torch.save(self._get_training_rng_state(), os.path.join(output_dir, 'rng_state.pt'))
+
     def _save_tokenizer(self, output_dir, **kwargs):
         adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
         optimizer_config = self.optimizer_group[adapter_name]
@@ -927,10 +963,11 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
         model_sd = model.state_dict()
         converted_weights = {}
         for key, value in adapter_weights.items():
-            if f'.{adapter_name}.weight' not in key:
-                key = key.replace('.weight', f'.{adapter_name}.weight')
-            if key in model_sd:
-                param = model_sd[key]
+            model_key = key
+            if f'.{adapter_name}.weight' not in model_key:
+                model_key = model_key.replace('.weight', f'.{adapter_name}.weight')
+            if model_key in model_sd:
+                param = model_sd[model_key]
                 if isinstance(param, DTensor) and not isinstance(value, DTensor):
                     value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
                 converted_weights[key] = value
@@ -949,18 +986,26 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
 
     def _load_optimizer(self, checkpoint_dir, **kwargs):
         adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        strict = kwargs.pop('strict', False)
         # assume optimizer and lr_scheduler are created
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
         scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
+        if strict and not os.path.exists(optimizer_path):
+            raise FileNotFoundError(optimizer_path)
+        if strict and optimizer_config.lr_scheduler is not None and not os.path.exists(scheduler_path):
+            logger.warning(
+                f'Missing scheduler checkpoint {scheduler_path}; resuming without restoring lr scheduler state.')
+
         if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None:
-            state_dict = torch.load(optimizer_path, map_location='cpu')
-            optimizer_config.optimizer.load_state_dict(state_dict)
+            if self.strategy.needs_wrapped_optimizer_state() and not self._model_wrapped:
+                self._lazy_wrap_model()
+            self.strategy.load_optimizer_checkpoint(self.model, optimizer_config.optimizer, optimizer_path)
         if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None:
-            state_dict = torch.load(scheduler_path, map_location='cpu')
+            state_dict = torch.load(scheduler_path, map_location='cpu', weights_only=True)
             optimizer_config.lr_scheduler.load_state_dict(state_dict)
 
     def _ensure_lora_dtype(self, model):
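A note on the weights_only asymmetry introduced here: torch.load(weights_only=True) restricts unpickling to tensors and simple containers, which is exactly what the scheduler and scaler payloads contain, while the RNG snapshot carries numpy and Python interpreter state that the restricted unpickler can reject on recent torch releases. A standalone sketch (the /tmp paths are illustrative):

import numpy as np
import torch

# Tensors, scalars and lists survive the restricted unpickler:
torch.save({'last_epoch': 3, 'base_lrs': [1e-4]}, '/tmp/scheduler.pt')
torch.load('/tmp/scheduler.pt', weights_only=True)

# The RNG snapshot contains numpy/Python state, so full pickle is required
# (hence weights_only=False when _load_rng_state reads it back below):
torch.save({'numpy_rng_state': np.random.get_state()}, '/tmp/rng_state.pt')
torch.load('/tmp/rng_state.pt', weights_only=False)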
@@ -979,6 +1024,87 @@ def _ensure_lora_dtype(self, model):
                 if 'lora_' in name.lower() and param.dtype != base_dtype:
                     param.data = param.data.to(base_dtype)
 
+    def _load_scaler_state(self, scaler_path, **kwargs):
+        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        optimizer_config = self.optimizer_group[adapter_name]
+        if optimizer_config.scaler is None:
+            raise ValueError(f'Grad scaler is not configured for adapter {adapter_name!r}')
+
+        scaler_state = torch.load(scaler_path, map_location='cpu', weights_only=True)
+        optimizer_config.scaler.load_state_dict(scaler_state['scaler_state_dict'])
+        optimizer_config.scaler_has_nan = scaler_state.get('scaler_has_nan', False)
+
+    def _get_training_rng_state(self):
+        state = {
+            'python_rng_state': random.getstate(),
+            'numpy_rng_state': np.random.get_state(),
+            'torch_rng_state': torch.get_rng_state(),
+        }
+
+        device_prefix = Platform.device_prefix()
+        device_module = getattr(torch, device_prefix, None)
+        if device_module and hasattr(device_module, 'is_available') and device_module.is_available():
+            state['device_type'] = device_prefix
+            state['device_rng_state'] = device_module.get_rng_state()
+        else:
+            state['device_type'] = 'cpu'
+            state['device_rng_state'] = None
+        return state
+
+    def _load_rng_state(self, rng_path):
+        rng_state = torch.load(rng_path, map_location='cpu', weights_only=False)
+        random.setstate(rng_state['python_rng_state'])
+        np.random.set_state(rng_state['numpy_rng_state'])
+        torch.set_rng_state(rng_state['torch_rng_state'])
+
+        device_type = rng_state.get('device_type')
+        device_rng_state = rng_state.get('device_rng_state')
+        if device_type != 'cpu' and device_rng_state is not None:
+            device_module = getattr(torch, device_type, None)
+            if device_module and hasattr(device_module, 'is_available') and device_module.is_available():
+                device_module.set_rng_state(device_rng_state)
+
+    def _restore_training_state(self, checkpoint_dir, *, adapter_name=''):
+        trainer_state_path = os.path.join(checkpoint_dir, 'trainer_state.json')
+        with open(trainer_state_path) as f:
+            trainer_state = json.load(f)
+
+        adapter_name = adapter_name or self._get_default_group()
+        optimizer_config = self.optimizer_group[adapter_name]
+        self._load_optimizer(checkpoint_dir, adapter_name=adapter_name)
+        scaler_path = os.path.join(checkpoint_dir, 'scaler.pt')
+        if os.path.exists(scaler_path) and optimizer_config.scaler is not None:
+            self._load_scaler_state(scaler_path, adapter_name=adapter_name)
+        rng_path = os.path.join(checkpoint_dir, 'rng_state.pt')
+        if os.path.exists(rng_path):
+            self._load_rng_state(rng_path)
+        optimizer_config.cur_step = trainer_state['cur_step']
+        optimizer_config.gradient_accumulation_steps = trainer_state['gradient_accumulation_steps']
+
+        return trainer_state
+
+    @remote_function()
+    def resume_from_checkpoint(self, checkpoint_dir, *, resume_only_model=False, **kwargs):
+        adapter_name = kwargs.get('adapter_name', '')
+
+        has_adapter = (
+            os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.safetensors'))
+            or os.path.exists(os.path.join(checkpoint_dir, 'adapter_model.bin')))
+        if has_adapter:
+            self.load(checkpoint_dir, adapter_name=adapter_name)
+
+        if not resume_only_model:
+            trainer_state = self._restore_training_state(checkpoint_dir, adapter_name=adapter_name)
+        else:
+            with open(os.path.join(checkpoint_dir, 'trainer_state.json')) as f:
+                trainer_state = json.load(f)
+
+        return {
+            'cur_step': trainer_state['cur_step'],
+            'consumed_train_samples': trainer_state['consumed_train_samples'],
+            'gradient_accumulation_steps': trainer_state['gradient_accumulation_steps'],
+        }
+
     @remote_function(collect='first')
     def get_state_dict(self, **kwargs):
         return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group()))
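For orientation, the on-disk layout that resume_from_checkpoint consumes looks like this (file names are the ones written by _save_training_state above; the adapter file depends on whether a LoRA adapter is being trained):

# checkpoint_dir/
#   adapter_model.safetensors   # adapter weights; presence triggers self.load()
#   optimizer.pt                # written via strategy.save_optimizer_checkpoint
#   scheduler.pt                # lr-scheduler state (master rank only)
#   scaler.pt                   # grad-scaler state, only if a scaler is configured
#   rng_state.pt                # python/numpy/torch(+device) RNG snapshot
#   trainer_state.json          # e.g. {"checkpoint_version": 1, "cur_step": 120,
#                               #       "gradient_accumulation_steps": 2,
#                               #       "consumed_train_samples": 1920}
#
# The numbers in the trainer_state example are made up. resume_from_checkpoint
# returns the three counters either way; resume_only_model=True skips the
# optimizer/scaler/RNG restore but still reads trainer_state.json.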
diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py
index 7f566840..b9d3e295 100644
--- a/src/twinkle/server/model/twinkle_handlers.py
+++ b/src/twinkle/server/model/twinkle_handlers.py
@@ -12,6 +12,7 @@
 import torch
 import traceback
 from fastapi import Depends, FastAPI, HTTPException, Request
+from pathlib import Path
 from peft import LoraConfig
 from typing import TYPE_CHECKING, Any, Callable
 
@@ -391,6 +392,31 @@ async def _task():
 
         await run_task(self.schedule_task_and_wait(_task, model_id=adapter_name, token=token, task_type='load'))
 
+    @app.post('/twinkle/resume_from_checkpoint', response_model=types.TrainingProgressResponse)
+    async def resume_from_checkpoint(
+        request: Request,
+        body: types.ResumeFromCheckpointRequest,
+        self: ModelManagement = Depends(self_fn),
+    ) -> types.TrainingProgressResponse:
+        token = await self._on_request_start(request)
+        adapter_name = _get_twinkle_adapter_name(request, body.adapter_name)
+
+        async def _task():
+            self.assert_resource_exists(adapter_name)
+            checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle')
+            resolved = checkpoint_manager.resolve_load_path(body.name)
+            checkpoint_dir = (
+                Path(resolved.checkpoint_dir, resolved.checkpoint_name).as_posix()
+                if resolved.checkpoint_dir else body.name)
+            ret = self.model.resume_from_checkpoint(
+                checkpoint_dir,
+                resume_only_model=body.resume_only_model,
+                adapter_name=adapter_name,
+            )
+            return {'result': ret}
+
+        return await run_task(self.schedule_task_and_wait(_task, task_type='resume'))
+
     @app.post('/twinkle/upload_to_hub', response_model=types.UploadToHubResponse)
     async def upload_to_hub(
         request: Request,
diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py
index 0a067ddd..4164940d 100644
--- a/src/twinkle_client/dataloader/dataloader.py
+++ b/src/twinkle_client/dataloader/dataloader.py
@@ -82,4 +82,37 @@ def __next__(self):
         )
         response.raise_for_status()
         return response.json()["result"]
+
+    def skip_consumed_samples(self, consumed_train_samples: int):
+        response = http_post(
+            url=f'{self.server_url}/call',
+            json_data={
+                'processor_id': self.processor_id,
+                'function': 'skip_consumed_samples',
+                'consumed_train_samples': consumed_train_samples,
+            })
+        response.raise_for_status()
+        return response.json()["result"]
+
+    def resume_from_checkpoint(self, consumed_train_samples, **kwargs):
+        response = http_post(
+            url=f'{self.server_url}/call',
+            json_data={
+                'processor_id': self.processor_id,
+                'function': 'resume_from_checkpoint',
+                'consumed_train_samples': consumed_train_samples,
+                **kwargs,
+            })
+        response.raise_for_status()
+        return response.json()["result"]
+
+    def get_state(self):
+        response = http_post(
+            url=f'{self.server_url}/call',
+            json_data={
+                'processor_id': self.processor_id,
+                'function': 'get_state',
+            })
+        response.raise_for_status()
+        return response.json()["result"]
\ No newline at end of file
diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py
index c2509f42..76e573d0 100644
--- a/src/twinkle_client/model/multi_lora_transformers.py
+++ b/src/twinkle_client/model/multi_lora_transformers.py
@@ -20,6 +20,7 @@
     GetStateDictResponse,
     GetTrainConfigsResponse,
     SaveResponse,
+    TrainingProgressResponse,
 )
 
 
@@ -189,6 +190,15 @@ def load(self, name: str, **kwargs) -> None:
         )
         response.raise_for_status()
 
+    def resume_from_checkpoint(self, name: str, *, resume_only_model: bool = False, **kwargs) -> Dict[str, Any]:
+        response = http_post(
+            url=f'{self.server_url}/resume_from_checkpoint',
+            json_data={'name': name, 'adapter_name': self.adapter_name,
+                       'resume_only_model': resume_only_model, **kwargs}
+        )
+        response.raise_for_status()
+        return TrainingProgressResponse(**response.json()).result
+
     def apply_patch(self, patch_cls: str, **kwargs) -> None:
         """Apply a patch to the model."""
         response = http_post(
diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py
index 51541dba..d58bb00c 100644
--- a/src/twinkle_client/types/__init__.py
+++ b/src/twinkle_client/types/__init__.py
@@ -28,6 +28,7 @@
     LrStepResponse,
     ModelResult,
     OkResponse,
+    ResumeFromCheckpointRequest,
     SaveRequest,
     SaveResponse,
     SetLossRequest,
@@ -41,6 +42,7 @@
     SetTemplateRequest,
     SetTemplateResponse,
     StepResponse,
+    TrainingProgressResponse,
     UploadToHubRequest,
     UploadToHubResponse,
     UploadStatusResponse,
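On the wire, the new route is a plain JSON POST. A hypothetical call with requests (the URL, adapter and checkpoint names are placeholders; the /twinkle prefix matches the route registered in twinkle_handlers.py above):

import requests

resp = requests.post(
    'http://localhost:8000/twinkle/resume_from_checkpoint',
    json={
        'name': 'twinkle-epoch-0',    # resolved server-side by the checkpoint manager
        'adapter_name': 'default',
        'resume_only_model': False,   # False: optimizer/scaler/RNG state restored too
    },
)
resp.raise_for_status()
progress = resp.json()['result']
# {'cur_step': ..., 'consumed_train_samples': ..., 'gradient_accumulation_steps': ...}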
diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py
index e2add1ce..fd14d91f 100644
--- a/src/twinkle_client/types/model.py
+++ b/src/twinkle_client/types/model.py
@@ -97,6 +97,16 @@ class Config:
         extra = 'allow'
 
 
+class ResumeFromCheckpointRequest(BaseModel):
+    """Request for /resume_from_checkpoint endpoint."""
+    name: str
+    adapter_name: str = ''
+    resume_only_model: bool = False
+
+    class Config:
+        extra = 'allow'
+
+
 class AddAdapterRequest(BaseModel):
     adapter_name: str
     config: str
@@ -220,6 +230,11 @@ class SaveResponse(BaseModel):
     checkpoint_dir: Optional[str] = None
 
 
+class TrainingProgressResponse(BaseModel):
+    """Response for /resume_from_checkpoint endpoint."""
+    result: Dict[str, Any]
+
+
 # --- Void responses (return None → OkResponse) ---
 
 class BackwardResponse(OkResponse):
diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py
index 79bf78ad..2da0a4f8 100644
--- a/tests/dataloader/test_dataloader.py
+++ b/tests/dataloader/test_dataloader.py
@@ -1,18 +1,31 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import concurrent.futures
 import numpy as np
 import os
 import pytest
 from pathlib import Path
+from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import IterableDataset as TorchIterableDataset
+from unittest.mock import MagicMock
 
 import twinkle
+import twinkle.hub.hub as _hub_module
 from twinkle import DeviceMesh
 from twinkle.data_format import Message
 from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.dataset import Dataset, DatasetMeta, IterableDataset
 from twinkle.processor import InputProcessor
 
 twinkle.initialize(mode='local')
 
+
+@pytest.fixture(autouse=True)
+def _disable_process_pool(monkeypatch):
+    mock_executor = MagicMock()
+    mock_executor.submit.side_effect = RuntimeError('Process pool is disabled in this test environment.')
+    monkeypatch.setattr(_hub_module, '_executor', mock_executor)
+
+
 TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data'
 
 SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
@@ -22,6 +35,44 @@ def convert_to_messages(example):
     return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]}
 
 
+def _build_resume_rows():
+    return [
+        {
+            'text': 'Hello world'
+        },
+        {
+            'text': 'Test data'
+        },
+        {
+            'text': 'Another example'
+        },
+        {
+            'text': 'Sample text'
+        },
+    ]
+
+
+class _InMemoryDataset(TorchDataset):
+
+    def __init__(self, rows):
+        self.rows = rows
+
+    def __len__(self):
+        return len(self.rows)
+
+    def __getitem__(self, idx):
+        return self.rows[idx]
+
+
+class _InMemoryIterableDataset(TorchIterableDataset):
+
+    def __init__(self, rows):
+        self.rows = rows
+
+    def __iter__(self):
+        return iter(self.rows)
+
+
 class TestDataLoaderBasic:
 
     def test_dataloader_basic(self):
@@ -157,3 +208,25 @@ def test_retry_sampler_length(self):
 
         total_samples = sum(len(batch) for batch in dataloader)
         assert total_samples == original_len
+
+
+class TestResumeSkip:
+
+    def test_dataloader_skip_consumed_samples_for_map_style_dataset(self):
+        dataset = _InMemoryDataset(_build_resume_rows())
+        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)
+
+        dataloader.skip_consumed_samples(2)
+        batches = list(dataloader)
+
+        texts = [item['text'] for batch in batches for item in batch]
+        assert texts[0] == 'Another example'
+
+    def test_dataloader_warns_when_skip_requested_for_iterable_dataset(self, recwarn):
+        dataset = _InMemoryIterableDataset(_build_resume_rows())
+        dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)
+
+        dataloader.skip_consumed_samples(2)
+        next(iter(dataloader))
+
+        assert 'does not support consumed-data skipping' in str(recwarn.pop(UserWarning).message)
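The map-style test pins down the intended semantics of skip_consumed_samples: drop exactly N already-consumed rows from the front of the epoch, then batch as usual. A reference reimplementation of that behavior (a sketch, not the twinkle DataLoader's actual code):

from itertools import islice


def iter_resumed(dataset, consumed_train_samples, batch_size):
    # Skip the first N indices, then batch the remainder in order.
    indices = islice(range(len(dataset)), consumed_train_samples, None)
    batch = []
    for idx in indices:
        batch.append(dataset[idx])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# With the four rows above and consumed_train_samples=2, the first yielded
# batch starts at 'Another example', exactly what the test asserts.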
diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py
index b8438207..3b7a4ebc 100644
--- a/tests/dataloader/test_sampler.py
+++ b/tests/dataloader/test_sampler.py
@@ -1,15 +1,28 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
 import os
 import pytest
 from pathlib import Path
+from torch.utils.data import Dataset as TorchDataset
 from torch.utils.data import RandomSampler, SequentialSampler
+from unittest.mock import MagicMock
 
 import twinkle
+import twinkle.hub.hub as _hub_module
+from twinkle import DeviceMesh
 from twinkle.dataloader import DataLoader
 from twinkle.dataset import Dataset, DatasetMeta
 
 twinkle.initialize(mode='local')
 
+
+@pytest.fixture(autouse=True)
+def _disable_process_pool(monkeypatch):
+    mock_executor = MagicMock()
+    mock_executor.submit.side_effect = RuntimeError('Process pool is disabled in this test environment.')
+    monkeypatch.setattr(_hub_module, '_executor', mock_executor)
+
+
 TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data'
@@ -162,3 +175,42 @@ def test_sequential_vs_random_order(self):
 
         different = seq_texts != rand_texts
         assert different or len(seq_texts) == 1
+
+
+class TestResumeSkipSamplerOrdering:
+
+    def test_sequential_sampler_skip_happens_before_device_mesh_slice(self):
+
+        class _InMemoryDataset(TorchDataset):
+
+            def __init__(self, rows):
+                self.rows = rows
+
+            def __len__(self):
+                return len(self.rows)
+
+            def __getitem__(self, idx):
+                return self.rows[idx]
+
+        dataset = _InMemoryDataset([
+            {
+                'text': 'Hello world'
+            },
+            {
+                'text': 'Test data'
+            },
+            {
+                'text': 'Another example'
+            },
+            {
+                'text': 'Sample text'
+            },
+        ])
+        sampler = SequentialSampler(dataset)
+        device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', ))
+        dataloader = DataLoader(dataset=dataset, batch_size=4, sampler=sampler, device_mesh=device_mesh, num_workers=0)
+
+        dataloader.skip_consumed_samples(2)
+        first_batch = list(dataloader)[0]
+
+        assert first_batch[0]['text'] == 'Another example'
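The ordering test deserves a gloss: skipping must happen before the device-mesh slice, otherwise each dp rank would discard consumed samples from its own shard and over-skip. An illustrative check in pure Python (the interleaved [r::dp_size] sharding is an assumption made for the sketch, not taken from the DataLoader):

rows = ['Hello world', 'Test data', 'Another example', 'Sample text']
consumed, dp_size = 2, 2

skip_then_shard = [rows[consumed:][r::dp_size] for r in range(dp_size)]
# [['Another example'], ['Sample text']] -- rank 0 resumes at the right row.

shard_then_skip = [rows[r::dp_size][consumed:] for r in range(dp_size)]
# [[], []] -- both ranks would run dry after skipping locally.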