Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch.utils.data.distributed import DistributedSampler

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.colo_parameter import ColoParameter
Expand Down Expand Up @@ -83,22 +84,22 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_model(self, model: GeminiDDP, checkpoint: str):
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
# as there is communication when getting the state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
self.save_checkpoint(state_dict, checkpoint)
save_state_dict(state_dict, checkpoint, use_safetensors)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded
super().save_unsharded_optimizer(optimizer, checkpoint)
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Expand Down
6 changes: 3 additions & 3 deletions colossalai/booster/plugin/torch_ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool =
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint)
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint)
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Expand Down
6 changes: 3 additions & 3 deletions examples/tutorial/new_api/torch_ddp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

## 🚀 Quick Start

This example provides a training script and and evaluation script. The training script provides a an example of training ResNet on CIFAR10 dataset from scratch.
This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch.

- Training Arguments
- `-r, `--resume`: resume from checkpoint file path
- `-r`, `--resume`: resume from checkpoint file path
- `-c`, `--checkpoint`: the folder to save checkpoints
- `-i`, `--interval`: epoch interval to save checkpoints
- `-f`, `--fp16`: use fp16
Expand Down Expand Up @@ -41,4 +41,4 @@ Expected accuracy performance will be:
| --------- | ------------------------ | --------------------- | --------------------- |
| ResNet-18 | 85.85% | 85.03% | 85.12% |

**Note: the baseline is a adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**