-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[booster] implement Gemini plugin #3352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
2ec8f87
[booster] add gemini plugin
ver217 da4614e
[booster] update docstr
ver217 828e76a
[booster] gemini plugin add coloparam convertor
ver217 b7af7e4
[booster] fix coloparam convertor
ver217 739b47d
[booster] fix gemini plugin device
ver217 8e76bb8
[booster] add gemini plugin test
ver217 2af5a8b
[booster] gemini plugin ignore sync bn
ver217 cbd6b9d
[booster] skip some model
ver217 d87fc9d
[booster] skip some model
ver217 9980438
[booster] modify test world size
ver217 2d90c6a
[booster] modify test world size
ver217 ea15a35
[booster] skip test
ver217 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from .gemini_plugin import GeminiPlugin | ||
| from .plugin_base import Plugin | ||
| from .torch_ddp_plugin import TorchDDPPlugin | ||
|
|
||
| __all__ = ['Plugin', 'TorchDDPPlugin'] | ||
| __all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,338 @@ | ||
| import random | ||
| import warnings | ||
| from typing import Callable, List, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| import torch.distributed as dist | ||
| import torch.nn as nn | ||
| from torch import Tensor | ||
| from torch.optim import Optimizer | ||
| from torch.optim.lr_scheduler import _LRScheduler as LRScheduler | ||
| from torch.utils.data import DataLoader | ||
| from torch.utils.data.distributed import DistributedSampler | ||
|
|
||
| from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO | ||
| from colossalai.cluster import DistCoordinator | ||
| from colossalai.gemini.memory_tracer import MemStats | ||
| from colossalai.interface import ModelWrapper, OptimizerWrapper | ||
| from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper | ||
| from colossalai.tensor.colo_parameter import ColoParameter | ||
| from colossalai.utils import get_current_device | ||
| from colossalai.utils.model.colo_init_context import _convert_to_coloparam | ||
|
|
||
| from .plugin_base import Plugin | ||
|
|
||
| __all__ = ['GeminiPlugin'] | ||
|
|
||
|
|
||
| def convert_to_colo_param(module: nn.Module) -> None: | ||
| """Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini. | ||
|
|
||
| Args: | ||
| module (nn.Module): Module to be converted. | ||
| """ | ||
| converted_modules = set() # handle shared modules | ||
| converted_params = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference | ||
|
|
||
| def convert_recursively(m: nn.Module): | ||
| for child in m.children(): | ||
| if child not in converted_modules: | ||
| converted_modules.add(child) | ||
| convert_recursively(child) | ||
|
|
||
| for name, p in m.named_parameters(recurse=False): | ||
| assert not isinstance(p, ColoParameter) | ||
| if p in converted_params: | ||
| target = converted_params[p] | ||
| else: | ||
| target = _convert_to_coloparam(p, p.device, p.dtype) | ||
| converted_params[p] = target | ||
| setattr(m, name, target) | ||
| target.shared_param_modules.append(m) | ||
|
|
||
| convert_recursively(module) | ||
|
|
||
| # optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak | ||
| module._converted_params = converted_params | ||
|
|
||
|
|
||
| def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None: | ||
| """Replace param in optimizer's group with converted ColoParameter. | ||
|
|
||
| Args: | ||
| optimizer (Optimizer): Optimizer to be replaced. | ||
| converted_params (dict): Mapping between (torch.Tensor, ColoTensor). | ||
| """ | ||
| for group in optimizer.param_groups: | ||
| for i, p in enumerate(group['params']): | ||
| if p in converted_params: | ||
| group['params'][i] = converted_params[p] | ||
|
|
||
|
|
||
| class GeminiCheckpointIO(GeneralCheckpointIO): | ||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__() | ||
| self.coordinator = DistCoordinator() | ||
|
|
||
| def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): | ||
| """ | ||
| Load model from checkpoint with automatic unwrapping. | ||
| """ | ||
| # the model should be unwrapped in self.load_model via ModelWrapper.unwrap | ||
| return super().load_unsharded_model(model, checkpoint, strict=strict) | ||
|
|
||
| def save_unsharded_model(self, model: GeminiDDP, checkpoint: str): | ||
| """ | ||
| Save model to checkpoint but only on master process. | ||
| """ | ||
| # the model should be unwrapped in self.load_model via ModelWrapper.unwrap | ||
| # as there is communication when get state dict, this must be called on all processes | ||
| state_dict = model.state_dict(only_rank_0=True) | ||
| if self.coordinator.is_master(): | ||
| self.save_checkpoint(state_dict, checkpoint) | ||
|
|
||
| def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): | ||
| """ | ||
| Save optimizer to checkpoint but only on master process. | ||
| """ | ||
| # TODO(ver217): optimizer state dict is sharded | ||
| super().save_unsharded_optimizer(optimizer, checkpoint) | ||
|
|
||
| def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): | ||
| """ | ||
| Save model to checkpoint but only on master process. | ||
| """ | ||
| if self.coordinator.is_master(): | ||
| super().save_lr_scheduler(lr_scheduler, checkpoint) | ||
|
|
||
|
|
||
| class GeminiModel(ModelWrapper): | ||
|
|
||
| def __init__(self, module: nn.Module, gemini_config: dict) -> None: | ||
| super().__init__(module) | ||
| # TODO(ver217): only support Gemini now | ||
| convert_to_colo_param(module) | ||
| self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config) | ||
|
|
||
| def unwrap(self): | ||
| # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model | ||
| return self.module | ||
|
|
||
|
|
||
| class GeminiOptimizer(OptimizerWrapper): | ||
|
|
||
| def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None: | ||
| replace_param_in_group(optimizer, module.module._converted_params) | ||
| del module.module._converted_params | ||
| optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs) | ||
| super().__init__(optimizer) | ||
|
|
||
| def backward(self, loss: Tensor, *args, **kwargs): | ||
| self.optim.backward(loss) | ||
|
|
||
| def clip_grad_by_norm(self, | ||
| max_norm: Union[float, int], | ||
| norm_type: Union[float, int] = 2, | ||
| error_if_nonfinite: bool = False, | ||
| *args, | ||
| **kwargs) -> Tensor: | ||
| warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') | ||
|
|
||
| def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: | ||
| raise NotImplementedError('Gemini does not support clip_grad_by_value') | ||
|
|
||
|
|
||
| class GeminiPlugin(Plugin): | ||
| """ | ||
| Plugin for Gemini. | ||
|
|
||
| Example: | ||
| >>> from colossalai.booster import Booster | ||
| >>> from colossalai.booster.plugin import GeminiPlugin | ||
| >>> | ||
| >>> model, train_dataset, optimizer, criterion = ... | ||
| >>> plugin = GeminiPlugin() | ||
|
|
||
| >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) | ||
| >>> booster = Booster(plugin=plugin) | ||
| >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) | ||
|
|
||
| Args: | ||
| device (torch.device): device to place the model. | ||
| placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". | ||
| pin_memory (bool, optional): use pin memory on CPU. Defaults to False. | ||
| force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. | ||
| strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. | ||
| search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. | ||
| hidden_dim (int, optional): the hidden dimension of DNN. | ||
| Users can provide this argument to speed up searching. | ||
| If users do not know this argument before training, it is ok. We will use a default value 1024. | ||
| min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. | ||
| If the aggregate size of parameters is still samller than the minimum chunk size, | ||
| all parameters will be compacted into one small chunk. | ||
| memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. | ||
| gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) | ||
| which will be used when using hybrid CPU optimizer. | ||
| This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". | ||
| Defaults to 0.0. | ||
| initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. | ||
| min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. | ||
| growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. | ||
| backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. | ||
| growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. | ||
| hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. | ||
| max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. | ||
| max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do | ||
| clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. | ||
| norm_type (float, optional): norm_type used for `clip_grad_norm`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| device: Optional[torch.device] = None, | ||
| placement_policy: str = "cpu", | ||
| pin_memory: bool = False, | ||
| force_outputs_fp32: bool = False, | ||
| strict_ddp_mode: bool = False, | ||
| search_range_mb: int = 32, | ||
| hidden_dim: Optional[int] = None, | ||
| min_chunk_size_mb: float = 32, | ||
| memstats: Optional[MemStats] = None, | ||
| gpu_margin_mem_ratio: float = 0.0, | ||
| initial_scale: float = 2**32, | ||
| min_scale: float = 1, | ||
| growth_factor: float = 2, | ||
| backoff_factor: float = 0.5, | ||
| growth_interval: int = 1000, | ||
| hysteresis: int = 2, | ||
| max_scale: float = 2**32, | ||
| max_norm: float = 0.0, | ||
| norm_type: float = 2.0, | ||
| ) -> None: | ||
|
|
||
| assert dist.is_initialized( | ||
| ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' | ||
| self.rank = dist.get_rank() | ||
| self.world_size = dist.get_world_size() | ||
| self.gemini_config = dict( | ||
| device=(device or get_current_device()), | ||
| placement_policy=placement_policy, | ||
| pin_memory=pin_memory, | ||
| force_outputs_fp32=force_outputs_fp32, | ||
| strict_ddp_mode=strict_ddp_mode, | ||
| search_range_mb=search_range_mb, | ||
| hidden_dim=hidden_dim, | ||
| min_chunk_size_mb=min_chunk_size_mb, | ||
| memstats=memstats, | ||
| ) | ||
| self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) | ||
| self.optim_kwargs = dict(initial_scale=initial_scale, | ||
| growth_factor=growth_factor, | ||
| backoff_factor=backoff_factor, | ||
| growth_interval=growth_interval, | ||
| hysteresis=hysteresis, | ||
| min_scale=min_scale, | ||
| max_scale=max_scale, | ||
| max_norm=max_norm, | ||
| norm_type=norm_type) | ||
|
|
||
| def support_no_sync(self) -> bool: | ||
| return False | ||
|
|
||
| def control_precision(self) -> bool: | ||
| return True | ||
|
|
||
| def supported_precisions(self) -> List[str]: | ||
| return ['fp16'] | ||
|
|
||
| def control_device(self) -> bool: | ||
| return True | ||
|
|
||
| def supported_devices(self) -> List[str]: | ||
| return ['cuda'] | ||
|
|
||
| def prepare_train_dataloader(self, | ||
| dataset, | ||
| batch_size, | ||
| shuffle=False, | ||
| seed=1024, | ||
| drop_last=False, | ||
| pin_memory=False, | ||
| num_workers=0, | ||
| **kwargs): | ||
| r""" | ||
| Prepare a dataloader for distributed training. The dataloader will be wrapped by | ||
| `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. | ||
|
|
||
| Note: | ||
| 1. Evaluation datasets should not be passed to this function. | ||
|
|
||
| Args: | ||
| dataset (`torch.utils.data.Dataset`): The dataset to be loaded. | ||
| shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. | ||
| seed (int, optional): Random worker seed for sampling, defaults to 1024. | ||
| add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. | ||
| drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size | ||
| is not divisible by the batch size. If False and the size of dataset is not divisible by | ||
| the batch size, then the last batch will be smaller, defaults to False. | ||
| pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. | ||
| num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. | ||
| kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in | ||
| `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_. | ||
|
|
||
| Returns: | ||
| :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. | ||
| """ | ||
| _kwargs = kwargs.copy() | ||
| sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) | ||
|
|
||
| # Deterministic dataloader | ||
| def seed_worker(worker_id): | ||
| worker_seed = seed | ||
| np.random.seed(worker_seed) | ||
| torch.manual_seed(worker_seed) | ||
| random.seed(worker_seed) | ||
|
|
||
| return DataLoader(dataset, | ||
| batch_size=batch_size, | ||
| sampler=sampler, | ||
| worker_init_fn=seed_worker, | ||
| drop_last=drop_last, | ||
| pin_memory=pin_memory, | ||
| num_workers=num_workers, | ||
| **_kwargs) | ||
|
|
||
| def configure( | ||
| self, | ||
| model: nn.Module, | ||
| optimizer: Optimizer, | ||
| criterion: Callable = None, | ||
| dataloader: DataLoader = None, | ||
| lr_scheduler: LRScheduler = None, | ||
| ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: | ||
|
|
||
| if not isinstance(model, ModelWrapper): | ||
| # convert model to sync bn | ||
| # FIXME(ver217): gemini does not support sync bn | ||
| # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16. | ||
| # This inconsistency of dtype will cause the error. | ||
| # We have two possible solutions: | ||
| # 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks. | ||
| # 2. patch sync bn or write a new on. This is relatively easy, but we need to test it. | ||
| # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) | ||
|
|
||
| # wrap the model with Gemini | ||
| model = GeminiModel(model, self.gemini_config) | ||
|
|
||
| if not isinstance(optimizer, OptimizerWrapper): | ||
| optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs) | ||
|
|
||
| return model, optimizer, criterion, dataloader, lr_scheduler | ||
|
|
||
| def control_checkpoint_io(self) -> bool: | ||
| return True | ||
|
|
||
| def get_checkpoint_io(self) -> CheckpointIO: | ||
| return GeminiCheckpointIO() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.