2 changes: 1 addition & 1 deletion colossalai/booster/booster.py
@@ -196,7 +196,7 @@ def save_model(self,
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
"""Load optimizer from checkpoint.
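As a usage note (not part of the diff): with the keyword-argument forwarding above, a call through the Booster API might look like the minimal sketch below; `booster` and `model` are assumed to come from an already-launched run and `booster.boost(...)`.

```python
# Minimal sketch, not from this PR: saving a single-file checkpoint through
# the Booster API. Assumes `booster` is a colossalai.booster.Booster and
# `model` was returned by booster.boost(...) in a launched distributed run.
booster.save_model(model, "model.pt", shard=False, size_per_shard=1024)
```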
207 changes: 69 additions & 138 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,42 +1,31 @@
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):

if version.parse(torch.__version__) >= version.parse('1.12.0'):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._init_utils import ProcessGroupType
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

@@ -51,102 +40,71 @@ def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()

def __set_model_optim_state(
self,
model,
state_dict_type,
state_dict_config,
optim_state_dict_config,
):
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)

def load_sharded_model(self, model: nn.Module, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def save_sharded_model(self, model: nn.Module, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)

def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def load_unsharded_model(self, model: nn.Module, checkpoint: str):
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Load model from checkpoint with automatic unwrapping.
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
full_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

model.load_state_dict(full_state_dict)

def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Load Optimizer from checkpoint with automatic unwrapping.
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, FSDPOptimizerWrapper)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

optim.load_state_dict(optim_full_state_dict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str):
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
raise NotImplementedError("Sharded model checkpoint is not supported yet.")

def load_sharded_model(self,
model: nn.Module,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True):
"""
Load model from sharded checkpoint.
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
model_state_dict = model.state_dict()
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
model_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(model_state_dict, checkpoint)

def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
FullOptimStateDictConfig(rank0_only=True))
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(optim_state_dict, checkpoint)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
"""
Load optimizer from sharded checkpoint.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
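For reference (not part of the diff), the new unsharded save/load paths above follow the standard PyTorch FSDP full-state-dict pattern; a standalone sketch using only `torch.distributed.fsdp` APIs (torch >= 1.12) is:

```python
# Standalone sketch of the FSDP full-state-dict pattern used by
# save_unsharded_model / load_unsharded_model above (torch >= 1.12).
# Assumes the default process group is already initialized.
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def gather_full_model_state(fsdp_model: FSDP) -> dict:
    # Gather the full (unsharded) state dict, offloaded to CPU and kept only
    # on rank 0, mirroring save_unsharded_model.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()


def load_full_model_state(fsdp_model: FSDP, full_state: dict) -> None:
    # Load a saved full state dict back into the FSDP module, as
    # load_unsharded_model does after reading the checkpoint from disk.
    fsdp_model.load_state_dict(full_state)
```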


class TorchFSDPModel(ModelWrapper):
@@ -156,7 +114,17 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
self.module = FSDP(module, *args, **kwargs)

def unwrap(self):
return self.module.module
return self.module


class FSDPOptimizerWrapper(OptimizerWrapper):

def __init__(self, optimizer: Optimizer, model: nn.Module):
self.model = model
super().__init__(optimizer)

def unwrap_model(self) -> nn.Module:
return self.model
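The wrapper above exists because FSDP's optimizer-state APIs need the wrapped model, not just the optimizer; a standalone sketch (not from the PR) of the flow that `unwrap_model()` enables:

```python
# Standalone sketch of the optimizer-state flow behind
# save_unsharded_optimizer / load_unsharded_optimizer (torch >= 1.12).
# Assumes `fsdp_model` is FSDP-wrapped, `optimizer` was built over its
# parameters, and the process group is already initialized.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def consolidate_optim_state(fsdp_model, optimizer) -> dict:
    # Gather the full optimizer state on rank 0 (other ranks receive an
    # empty dict), as save_unsharded_optimizer does.
    return FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)


def restore_optim_state(fsdp_model, optimizer, full_osd: dict) -> None:
    # Re-shard the full optimizer state for the local rank before loading,
    # as load_unsharded_optimizer does.
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, fsdp_model)
    optimizer.load_state_dict(sharded_osd)
```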


class TorchFSDPPlugin(DPPluginBase):
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
if version.parse(torch.__version__) >= version.parse('1.12.0'):

def __init__(
self,
@@ -191,7 +158,6 @@ def __init__(
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
):
super().__init__()
@@ -203,42 +169,7 @@ def __init__(
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):

def __init__(
self,
process_group: ProcessGroupType = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_parameters=ignored_parameters)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

@@ -269,14 +200,14 @@ def configure(
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

model = model.cuda()
# wrap the model with PyTorch FSDP
model = TorchFSDPModel(model, **self.fsdp_kwargs)
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
if not isinstance(optimizer, FSDPOptimizerWrapper):
optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)

return model, optimizer, criterion, dataloader, lr_scheduler
return fsdp_model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
return True
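End-to-end (not part of the diff): with device placement moved into `configure()` and the optimizer rebuilt over the FSDP parameters, a training-setup sketch could look like this, assuming `colossalai.launch` has already set up the process group:

```python
# End-to-end sketch, not from this PR. Assumes a launched distributed run
# (process group initialized, CUDA device available); names and paths are
# illustrative.
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin  # assumed export path

plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# configure() moves the model to the current CUDA device, wraps it with FSDP,
# re-creates the optimizer over the FSDP parameters, and returns an
# FSDPOptimizerWrapper that keeps a handle to the wrapped model.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# Unsharded checkpoints: full states are gathered on rank 0 before writing.
booster.save_model(model, "fsdp_model.pt", shard=False)
booster.save_optimizer(optimizer, "fsdp_optim.pt", shard=False)
```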