15 changes: 3 additions & 12 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -220,17 +220,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
super().save_lr_scheduler(lr_scheduler, checkpoint)


class GeminiModel(ModelWrapper):

def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
super().__init__(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)

def unwrap(self):
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
return self.module


class GeminiOptimizer(OptimizerWrapper):

def __init__(self,
@@ -393,7 +382,9 @@ def configure(
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
# TODO(ver217): remove this line
model._colo_zero_stage = 3

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
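
Note: with the GeminiModel wrapper removed, configure() hands the module straight to GeminiDDP, so the object the plugin works with is the GeminiDDP instance itself. A minimal sketch of the new wrapping step, assuming a launched ColossalAI distributed context; the import path and config key below are illustrative assumptions, not taken from this diff:

import torch.nn as nn
from colossalai.zero import GeminiDDP  # assumed import path, may differ by version

gemini_config = dict(placement_policy='cpu')   # illustrative option only
module = nn.Linear(8, 8)
model = GeminiDDP(module, **gemini_config, verbose=False)
model._colo_zero_stage = 3  # temporary flag, kept until the TODO above is resolved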
68 changes: 22 additions & 46 deletions colossalai/tensor/colo_parameter.py
@@ -3,9 +3,15 @@
import torch

from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec

from .colo_tensor import _convert_output

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}


def is_no_hook_op(func) -> bool:
return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS


def filter_colo_parameters(*args, **kwargs):
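
The new is_no_hook_op helper skips the parameter op hooks for every magic method except the whitelisted ones. A minimal standalone check of that rule (only torch required; it mirrors the helper added above):

import torch

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}

def is_no_hook_op(func) -> bool:
    # Dunder ops bypass the hooks unless explicitly whitelisted.
    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS

print(is_no_hook_op(torch.Tensor.__getitem__))  # False -> indexing a param still runs hooks
print(is_no_hook_op(torch.Tensor.__repr__))     # True  -> non-whitelisted dunders skip hooks
print(is_no_hook_op(torch.add))                 # False -> ordinary ops still run hooks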
@@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

"""

def __new__(cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> 'ColoParameter':
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)

def __init__(self,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> None:
ColoTensor.__init__(self, data, spec)
self._type = TensorType.MODEL
# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []

@property
def shared_param_modules(self):
return self._shared_param_modules

@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor

def __repr__(self):
return super(ColoParameter, self).__repr__()

@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
if ColoParamOpHookManager.has_hook():
if not func.__name__.startswith('__'):
if kwargs is None:
kwargs = {}
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return ret
if kwargs is None:
kwargs = {}
if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return _convert_output(ret, func)
return super().__torch_function__(func, types, args, kwargs)
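
The reworked __torch_function__ folds the hook check into one early test and converts the output before returning. A simplified, self-contained mock of that dispatch order (pre_op, then the real op, then post_op); the classes and names here are plain-Python stand-ins for illustration, not colossalai's API, and is_no_hook_op is reduced to a bare dunder check:

import torch

class MockHookManager:
    # Stand-in for ColoParamOpHookManager, used only to trace the call order.
    @staticmethod
    def has_hook() -> bool:
        return True

    @staticmethod
    def pre_op(params, *args):
        print(f"pre_op on {len(params)} param(s)")   # the real hooks may gather chunks here
        return list(args)

    @staticmethod
    def post_op(params, ret):
        print("post_op")                             # the real hooks may release chunks here
        return ret

def call_with_hooks(func, params, *args):
    # Mirrors the branch structure above: hooks only for non-dunder ops on ColoParameters.
    if MockHookManager.has_hook() and not func.__name__.startswith('__'):
        if params:
            args = MockHookManager.pre_op(params, *args)
            ret = func(*args)
            return MockHookManager.post_op(params, ret)
    return func(*args)

p = torch.nn.Parameter(torch.ones(2))
print(call_with_hooks(torch.add, [p], p, 1))            # pre_op / post_op run around the add
print(call_with_hooks(torch.Tensor.__repr__, [p], p))   # dunder op: hooks are skipped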

def __deepcopy__(self, memo):
@@ -96,9 +74,7 @@ def __deepcopy__(self, memo):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data,
self.requires_grad,
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
tensor = ColoParameter(data, self.requires_grad)
memo[id(self)] = tensor
return tensor
