[booster] support torch fsdp plugin in booster #3697
Merged
New file (the FSDP plugin implementation, added under `colossalai/booster/plugin/`, 285 lines):

```python
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
        torch.__version__) < version.parse('2.0.0'):
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        MixedPrecision,
        ShardingStrategy,
    )
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp._init_utils import ProcessGroupType
    from torch.distributed.fsdp.api import (
        BackwardPrefetch,
        CPUOffload,
        FullOptimStateDictConfig,
        FullStateDictConfig,
        MixedPrecision,
        ShardingStrategy,
        StateDictType,
    )
    from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
    raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .dp_plugin_base import DPPluginBase

__all__ = ['TorchFSDPPlugin']


class TorchFSDPCheckpointIO(GeneralCheckpointIO):

    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()

    def __set_model_optim_state(
        self,
        model,
        state_dict_type,
        state_dict_config=None,
        optim_state_dict_config=None,
    ):
        return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)

    def load_sharded_model(self, model: nn.Module, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded optimizer checkpoint is not supported yet.")

    def save_sharded_model(self, model: nn.Module, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded optimizer checkpoint is not supported yet.")

    def load_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Load a full (unsharded) model state dict from checkpoint.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            full_state_dict = self.load_state_dict(checkpoint)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            full_state_dict = self.load_state_dict(checkpoint)
            # configure FSDP so that the subsequent load is treated as a full (unsharded) state dict
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

        model.load_state_dict(full_state_dict)

    def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
        """
        Load a full (unsharded) optimizer state dict from checkpoint.
        """
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            optim_full_state_dict = self.load_state_dict(checkpoint)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            optim_full_state_dict = self.load_state_dict(checkpoint)
            # convert the full optimizer state dict into one loadable by the sharded optimizer
            optim_full_state_dict = FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

        optim.load_state_dict(optim_full_state_dict)

    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Save model to checkpoint but only on master process.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
                model_state_dict = model.state_dict()
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
            model_state_dict = model.state_dict()
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")
        self.save_checkpoint(model_state_dict, checkpoint)

    def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        """
        Save optimizer to checkpoint but only on master process.
        """
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            self.__set_model_optim_state(model,
                                         StateDictType.FULL_STATE_DICT,
                                         optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True))
            optim_state_dict = FSDP.optim_state_dict(model, optimizer)
        else:
            raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")
        self.save_checkpoint(optim_state_dict, checkpoint)


class TorchFSDPModel(ModelWrapper):

    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        self.module = FSDP(module, *args, **kwargs)

    def unwrap(self):
        return self.module.module


class TorchFSDPPlugin(DPPluginBase):
    """
    Plugin for PyTorch FSDP.

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import TorchFSDPPlugin
        >>>
        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = TorchFSDPPlugin()
        >>>
        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    Args:
        See https://pytorch.org/docs/stable/fsdp.html for details.
    """

    if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
            torch.__version__) < version.parse('2.0.0'):

        def __init__(
            self,
            process_group: Optional[ProcessGroup] = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Callable] = None,
            backward_prefetch: Optional[BackwardPrefetch] = None,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            device_id: Optional[Union[int, torch.device]] = None,
            sync_module_states: bool = False,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(process_group=process_group,
                                    sharding_strategy=sharding_strategy,
                                    cpu_offload=cpu_offload,
                                    auto_wrap_policy=auto_wrap_policy,
                                    backward_prefetch=backward_prefetch,
                                    mixed_precision=mixed_precision,
                                    ignored_modules=ignored_modules,
                                    param_init_fn=param_init_fn,
                                    device_id=device_id,
                                    sync_module_states=sync_module_states)
    elif version.parse(torch.__version__) >= version.parse('2.0.0'):

        def __init__(
            self,
            process_group: ProcessGroupType = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
            backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            device_id: Optional[Union[int, torch.device]] = None,
            sync_module_states: bool = False,
            forward_prefetch: bool = False,
            limit_all_gathers: bool = False,
            use_orig_params: bool = False,
            ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(process_group=process_group,
                                    sharding_strategy=sharding_strategy,
                                    cpu_offload=cpu_offload,
                                    auto_wrap_policy=auto_wrap_policy,
                                    backward_prefetch=backward_prefetch,
                                    mixed_precision=mixed_precision,
                                    ignored_modules=ignored_modules,
                                    param_init_fn=param_init_fn,
                                    device_id=device_id,
                                    sync_module_states=sync_module_states,
                                    forward_prefetch=forward_prefetch,
                                    limit_all_gathers=limit_all_gathers,
                                    use_orig_params=use_orig_params,
                                    ignored_parameters=ignored_parameters)
    else:
        raise RuntimeError("FSDP is not supported on torch versions below 1.12.0.")

    def support_no_sync(self) -> bool:
        return False

    def no_sync(self, model: nn.Module) -> Iterator[None]:
        raise NotImplementedError("Torch FSDP no_sync is not supported yet.")

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return ['fp16', 'bf16']

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def configure(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable = None,
        dataloader: DataLoader = None,
        lr_scheduler: LRScheduler = None,
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
        model = model.cuda()
        # wrap the model with PyTorch FSDP
        model = TorchFSDPModel(model, **self.fsdp_kwargs)

        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return TorchFSDPCheckpointIO()
```
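To make the intended workflow concrete, here is a minimal end-to-end sketch of using the plugin: it assumes a distributed environment launched with `torchrun` and at least one CUDA device, and the toy model, hyperparameters, and checkpoint file names are placeholders, not part of the PR.

```python
import torch
import torch.nn as nn
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

# Initialize the distributed environment from torchrun's env vars.
colossalai.launch_from_torch(config=dict())

# Placeholder model and training objects.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda logits: logits.mean()

plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)
# boost() moves the model to CUDA, wraps it in FSDP via TorchFSDPModel,
# and wraps the optimizer in OptimizerWrapper.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# One training step.
x = torch.randn(8, 32, device='cuda')
loss = criterion(model(x))
booster.backward(loss, optimizer)
optimizer.step()

# Only unsharded (full) checkpoints are supported by TorchFSDPCheckpointIO;
# the sharded variants above raise NotImplementedError.
booster.save_model(model, 'model.pt')
booster.save_optimizer(optimizer, 'optimizer.pt')
```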
New file (unit test for the plugin, 64 lines):

```python
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from colossalai.booster.plugin import TorchFSDPPlugin

from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)
    model = model_fn()
    optimizer = SGD(model.parameters(), lr=1e-3)
    criterion = lambda x: x.mean()
    data = data_gen_fn()

    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    assert isinstance(model.module, FSDP)
    assert isinstance(optimizer, OptimizerWrapper)

    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()


def check_torch_fsdp_plugin():
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        if 'diffusers' in name:
            continue
        run_fn(model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_fsdp_plugin()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
    spawn(run_dist, 2)
```
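Since `spawn(run_dist, 2)` forks its own two worker processes, the test needs no external launcher: invoking it from a plain Python process (or via pytest) on a node with two CUDA devices is enough. A direct-run entry point in this style is an assumption here, not part of the diff:

```python
# Hypothetical direct-run entry point (assumes >= 2 CUDA devices and torch >= 1.12);
# the test spawns its own workers, so a single `python` invocation suffices.
if __name__ == '__main__':
    test_torch_fsdp_plugin()
```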