Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions colossalai/booster/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
from .torch_fsdp_plugin import TorchFSDPPlugin
__all__.append('TorchFSDPPlugin')
285 changes: 285 additions & 0 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import (
BackwardPrefetch,
CPUOffload,
MixedPrecision,
ShardingStrategy,
)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._init_utils import ProcessGroupType
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .dp_plugin_base import DPPluginBase

__all__ = ['TorchFSDPPlugin']


class TorchFSDPCheckpointIO(GeneralCheckpointIO):

def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()

def __set_model_optim_state(
self,
model,
state_dict_type,
state_dict_config,
optim_state_dict_config,
):
return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)

def load_sharded_model(self, model: nn.Module, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def save_sharded_model(self, model: nn.Module, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):

# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

def load_unsharded_model(self, model: nn.Module, checkpoint: str):
Comment thread
wukong1992 marked this conversation as resolved.
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
full_state_dict = self.load_state_dict(checkpoint)
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
full_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

model.load_state_dict(full_state_dict)

def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
"""
Load Optimizer from checkpoint with automatic unwrapping.
"""

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
optim_full_state_dict = self.load_state_dict(checkpoint)
FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

optim.load_state_dict(optim_full_state_dict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
model_state_dict = model.state_dict()
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
model_state_dict = model.state_dict()
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(model_state_dict, checkpoint)

def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
"""
Save optimizer to checkpoint but only on master process.
"""

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):
optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
FullOptimStateDictConfig(rank0_only=True))
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
self.save_checkpoint(optim_state_dict, checkpoint)


class TorchFSDPModel(ModelWrapper):

def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)

def unwrap(self):
return self.module.module


class TorchFSDPPlugin(DPPluginBase):
"""
Plugin for PyTorch FSDP.

Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import TorchFSDPPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = TorchFSDPPlugin()

>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

Args:
See https://pytorch.org/docs/stable/fsdp.html for details.
"""

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
torch.__version__) < version.parse('2.0.0'):

def __init__(
self,
process_group: Optional[ProcessGroup] = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Callable] = None,
backward_prefetch: Optional[BackwardPrefetch] = None,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states)
elif version.parse(torch.__version__) >= version.parse('2.0.0'):

def __init__(
self,
process_group: ProcessGroupType = None,
sharding_strategy: Optional[ShardingStrategy] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
mixed_precision: Optional[MixedPrecision] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
limit_all_gathers: bool = False,
use_orig_params: bool = False,
ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
):
super().__init__()
self.fsdp_kwargs = dict(process_group=process_group,
sharding_strategy=sharding_strategy,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=backward_prefetch,
mixed_precision=mixed_precision,
ignored_modules=ignored_modules,
param_init_fn=param_init_fn,
device_id=device_id,
sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params,
ignored_parameters=ignored_parameters)
else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

def support_no_sync(self) -> bool:
False

def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

def control_precision(self) -> bool:
return True

def supported_precisions(self) -> List[str]:
return ['fp16', 'bf16']

def control_device(self) -> bool:
return True

def supported_devices(self) -> List[str]:
return ['cuda']

def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

model = model.cuda()
# wrap the model with PyTorch FSDP
model = TorchFSDPModel(model, **self.fsdp_kwargs)

if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)

return model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
return True

def get_checkpoint_io(self) -> CheckpointIO:
return TorchFSDPCheckpointIO()
4 changes: 2 additions & 2 deletions colossalai/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def test_something():
"""
# check version
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
assert torch_version.major >= 1

# only torch >= 1.8 has ProcessRaisedException
if torch_version.minor >= 8:
if torch_version >= version.parse("1.8.0"):
exception = torch.multiprocessing.ProcessRaisedException
else:
exception = Exception
Expand Down
64 changes: 64 additions & 0 deletions tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse('1.12.0'):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.booster.plugin import TorchFSDPPlugin

from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()

data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}

model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

assert isinstance(model.module, FSDP)
assert isinstance(optimizer, OptimizerWrapper)

output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])

booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()


def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
if 'diffusers' in name:
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_torch_fsdp_plugin()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
spawn(run_dist, 2)