From d58206c7491948197fa469650c88db5ab608613c Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 19 Dec 2022 18:17:48 +0800 Subject: [PATCH 1/5] [utils] lazy init. --- colossalai/fx/profiler/tensor.py | 32 ++- colossalai/utils/model/experimental.py | 345 +++++++++++++++++++++++++ 2 files changed, 360 insertions(+), 17 deletions(-) create mode 100644 colossalai/utils/model/experimental.py diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 43165305f010..767f7938dc81 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -1,6 +1,4 @@ import uuid -from copy import deepcopy -from typing import Optional import torch from torch.types import _bool, _device, _dtype @@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor): _tensor: torch.Tensor - __slots__ = ['_tensor'] - @staticmethod def __new__(cls, elem, fake_device=None): # Avoid multiple wrapping @@ -47,7 +43,7 @@ def __new__(cls, elem, fake_device=None): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=fake_device if fake_device is not None else elem.device, + device=fake_device if fake_device is not None else torch.device('cpu'), requires_grad=elem.requires_grad) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. @@ -59,8 +55,8 @@ def __new__(cls, elem, fake_device=None): def __repr__(self): if self.grad_fn: - return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})" - return f"MetaTensor({self._tensor}, fake_device='{self.device}')" + return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" + return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): @@ -118,16 +114,18 @@ def to(self, *args, **kwargs) -> torch.Tensor: MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan') """ # this imitates c++ function in the way of @overload - device = None - for arg in args: - if isinstance(arg, str) or isinstance(arg, _device): - device = arg - if 'device' in kwargs: - device = kwargs['device'] - result = super().to(*args, **kwargs) - if device is not None: - result = MetaTensor(result, fake_device=device) - return result + fake_device = None + + def replace(x): + nonlocal fake_device + if isinstance(x, str) or isinstance(x, _device): + fake_device = x + return 'meta' + return x + + elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) + print(fake_device) + return MetaTensor(elem, fake_device=fake_device) def cpu(self, *args, **kwargs): if self.device.type == 'cpu': diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py new file mode 100644 index 000000000000..6a3b97a57ee1 --- /dev/null +++ b/colossalai/utils/model/experimental.py @@ -0,0 +1,345 @@ +import contextlib +import copy +import pprint +from typing import Callable, List + +import torch +import torch.nn as nn +from torch.types import _bool, _device, _dtype +from torch.utils._pytree import tree_map + +from colossalai.fx.profiler import MetaTensor +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.utils.model.utils import substitute_init_recursively + +init = torch.nn.init + +# reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/parameter.py#L73 +_TorchTensorMethods = [ + 
torch.Tensor.half, + torch.Tensor.float, + torch.Tensor.double, + torch.Tensor.char, + torch.Tensor.short, + torch.Tensor.int, + torch.Tensor.long, + torch.Tensor.cuda, + torch.Tensor.cpu, +] + +# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html +_TorchFactoryFunc = [ + "arange", + "empty", + "eye", + "full", + "linspace", + "logspace", + "ones", + "rand", + "randn", + "randint", + "randperm", + "zeros", +] + + +def init_from_spec(cls, t: torch.Tensor, spec: ColoTensorSpec): + if cls == ColoTensor: + return cls.from_torch_tensor(t, spec) + else: + return cls.from_torch_tensor(t, t.requires_grad, spec) + + +class UninitializedTensor(torch.Tensor): + + _cls_to_become = ColoTensor + _repr = True + + @staticmethod + def __new__(cls, func, *args, dtype=None, device=None, **kwargs): + elem = func(*args, dtype=dtype, device='meta', **kwargs) + r = torch.Tensor._make_wrapper_subclass(cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=device if device is not None else torch.device('cpu'), + requires_grad=elem.requires_grad) + r._data = MetaTensor(elem, fake_device=device) + return r + + def __init__(self, func, *args, dtype=None, device=None, **kwargs): + self._factory_fn = (func, args, {'dtype': dtype, 'device': device, **kwargs}) # (func, args, kwargs) + self._cached_fn = list() # (func, args, kwargs) + self._spec = ColoTensorSpec(pg=None, dist_attr=None, compute_attr=None) # Default Spec + + def __repr__(self): + if self._repr: + self.__class__._repr = False + s = f'UninitializedTensor: {self._factory_fn}\n'\ + f'_data: {self._data}\n'\ + f'cached_fn: {pprint.pformat(self._cached_fn)}\n'\ + f'spec: {self._spec}' + self.__class__._repr = True + return s + else: + return 'UninitializedTensor(...)' + + def materialize(self): + func, args, kwargs = self._factory_fn + t = func(*args, **kwargs) + + # apply cached_fn + t = self._apply_cache(t) + + # apply spec + return init_from_spec(self._cls_to_become, t, self._spec) + + # TODO(super-dainiu): device seems incorrect + def traceable(self): + func, args, kwargs = self._factory_fn + t = MetaTensor(func(*args, **{**kwargs, 'device': 'meta'}), fake_device=kwargs['device']) + return self._apply_cache(t) + + def _apply_cache(self, t): + # apply cached methods + # super-dainiu: support methods for single Tensor only + replace = lambda x: t if isinstance(x, UninitializedTensor) else x + + for (func, args, kwargs) in self._cached_fn: + o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) + t = o if isinstance(o, torch.Tensor) else t # if func returns non-Tensor, discard the value + return t + + # cache everything + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + t = None + + if func in _TorchTensorMethods: + t: UninitializedTensor = args[0].clone() + t._cached_fn.append((func, (t,) + args[1:], kwargs)) + t._data = func(t._data, *args[1:], **kwargs) + if isinstance(t._data, MetaTensor): + return t + else: + return t._data + + def unwrap(t_): + nonlocal t + if isinstance(t_, UninitializedTensor): + t = t_.clone() + t._cached_fn.append((func, args, kwargs)) + t_ = t_._data + return t_ + + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + t._data = func(*args, **kwargs) + + if isinstance(t._data, MetaTensor): + return t + else: + return t._data + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + pass + + def to(self, *args, **kwargs) -> 
"UninitializedTensor": + t = self.clone() + t._cached_fn.append((torch.Tensor.to, (t,) + args, kwargs)) + t._data = t._data.to(*args, **kwargs) + return t + + def clone(self) -> "UninitializedTensor": + func, args, kwargs = self._factory_fn + t = UninitializedTensor(func, *args, **kwargs) + t._cached_fn = [x for x in self._cached_fn] + t._spec = copy.deepcopy(self._spec) + return t + + @property + def spec(self) -> ColoTensorSpec: + return self._spec + + @spec.setter + def spec(self, other: ColoTensorSpec): + self._spec = other + + +# TODO: Not correct +class UninitializedParameter(UninitializedTensor, nn.Parameter): + + _cls_to_become = ColoParameter + + @staticmethod + def __new__(cls, elem=None, requires_grad=True): + if elem is None: + elem = UninitializedTensor(torch.empty, 0) + if type(elem) is UninitializedTensor or type(elem) is UninitializedParameter: + # For ease of BC maintenance, keep this path for standard Tensor. + # Eventually (tm), we should change the behavior for standard Tensor to match. + r = torch.Tensor._make_wrapper_subclass(cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=requires_grad) + r._data = elem._data + return r + raise RuntimeError(f"Creating an `UninitializedParameter` with `Tensor.subclasses` of " + f"`{type(elem).__name__}` is unexpected. Should be one of `UninitializedParameter`" + f", `UnintializedTensor`, or None.") + + def __init__(self, elem=None, requires_grad=True): + self._factory_fn = elem._factory_fn + self._cached_fn = elem._cached_fn + self._spec = elem._spec + + def clone(self): + return UninitializedParameter(super().clone(), requires_grad=self.requires_grad) + + __torch_function__ = torch._C._disabled_torch_function_impl + + +class LazyInitContext(): + """ + A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor + initialization functions for lazy initialization + + Note: + This API is only experimental and subject to future changes. + + Usage: + with LazyInitContext() as ctx: + model = nn.Linear(10, 10) + model.weight.zero_() + + # make sure the weight is a meta tensor + assert model.weight.is_meta + + # initialize weights + ctx.lazy_init_parameters(model) + + # make sure the weight is not a meta tensor + # and initialized correctly + assert not model.weight.is_meta and torch.all(model.weight == 0) + + Args: + to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This + argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. + extra_torch_tensor_func (List[str]): extra torch tensor functions related + to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. + """ + + def __init__(self): + self.overrides = {} + self._orig_nn_param = torch.nn.Parameter + + def __enter__(self): + + def wrap_factory_method(target): + # factory functions (eg. torch.empty()) + def wrapper(*args, **kwargs): + return UninitializedTensor(target, *args, **kwargs) + + return wrapper, target + + def wrap_factory_like_method(orig_target, target): + # factory_like functions (eg. 
torch.empty_like()) + def wrapper(*args, **kwargs): + orig_t = args[0] + return UninitializedTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) + + return wrapper, target + + self.overrides = { + target: wrap_factory_method(getattr(torch, target)) + for target in _TorchFactoryFunc + if callable(getattr(torch, target, None)) + } + + self.overrides.update({ + target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) + for target in _TorchFactoryFunc + if callable(getattr(torch, target + '_like', None)) + }) + + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, wrapper) + + # cannot monkey patch nn.Parameter because it is a class (????) + setattr(torch.nn.parameter, 'Parameter', UninitializedParameter) + + def __exit__(self, exc_type, exc_val, exc_tb): + for name, (wrapper, orig) in self.overrides.items(): + setattr(torch, name, orig) + setattr(torch.nn.parameter, 'Parameter', self._orig_nn_param) + + @staticmethod + def materialize(module: torch.nn.Module): + """Materialize and shard ``nn.Module`` -- Initialize all ``nn.Parameter`` as ``ColoParameter`` + + Args: + module (torch.nn.Module): LazyInit Module + """ + + @torch.no_grad() + def init_recursively(module: nn.Module): + # recursively initialize the module + for mod in module.children(): + init_recursively(mod) + + # initialize tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + setattr(module, name, param.materialize()) + + for name, buf in module.named_buffers(recurse=False): + setattr(module, name, buf.materialize()) + + init_recursively(module) + return module + + @staticmethod + @contextlib.contextmanager + def traceable(module: torch.nn.Module): + """Enable ``ColoTracer`` -- Initialize all ``nn.Parameters`` as ``MetaTensor`` + + Args: + module (torch.nn.Module): LazyInit Module + """ + orig_val = dict() + + def init_recursively(module: nn.Module): + # recursively initialize the module + for mod in module.children(): + init_recursively(mod) + + # initialize tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + setattr(module, name, param.traceable()) + orig_val[(module, name)] = param + + for name, buf in module.named_buffers(recurse=False): + setattr(module, name, buf.traceable()) + orig_val[(module, name)] = buf + + init_recursively(module) + + yield + + # restore original values + for (module, name), val in orig_val.items(): + setattr(module, name, val) + + # Things to hack: + # 1. torch.Tensor factory function (DONE) + # 2. nn.Parameter + # 3. init (DONE) From 1abc18792ab36a7bbfdd8fb4f04418633f8a138e Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Mon, 19 Dec 2022 18:23:14 +0800 Subject: [PATCH 2/5] [utils] remove description. --- colossalai/utils/model/experimental.py | 28 -------------------------- 1 file changed, 28 deletions(-) diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index 6a3b97a57ee1..1905b7eb588a 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -210,34 +210,6 @@ def clone(self): class LazyInitContext(): - """ - A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor - initialization functions for lazy initialization - - Note: - This API is only experimental and subject to future changes. 
- - Usage: - with LazyInitContext() as ctx: - model = nn.Linear(10, 10) - model.weight.zero_() - - # make sure the weight is a meta tensor - assert model.weight.is_meta - - # initialize weights - ctx.lazy_init_parameters(model) - - # make sure the weight is not a meta tensor - # and initialized correctly - assert not model.weight.is_meta and torch.all(model.weight == 0) - - Args: - to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This - argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet. - extra_torch_tensor_func (List[str]): extra torch tensor functions related - to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. - """ def __init__(self): self.overrides = {} From ebb96de402221bb9f727ba0cb780eac13655bc29 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Tue, 20 Dec 2022 22:55:38 +0800 Subject: [PATCH 3/5] [utils] complete. --- colossalai/fx/profiler/tensor.py | 15 ++- colossalai/utils/model/experimental.py | 149 ++++++++----------------- 2 files changed, 56 insertions(+), 108 deletions(-) diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 767f7938dc81..7606f17cf9d5 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -72,13 +72,13 @@ def unwrap(x): x = x.to(torch.device('meta')) return x + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + if 'device' in kwargs: fake_device = kwargs['device'] kwargs['device'] = torch.device('meta') - args = tree_map(unwrap, args) - kwargs = tree_map(unwrap, kwargs) - # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) @@ -124,7 +124,6 @@ def replace(x): return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) - print(fake_device) return MetaTensor(elem, fake_device=fake_device) def cpu(self, *args, **kwargs): @@ -132,7 +131,7 @@ def cpu(self, *args, **kwargs): return self.to(*args, **kwargs) return self.to(*args, device='cpu', **kwargs) - def cuda(self, *args, **kwargs): - if self.device.type == 'cuda': - return self.to(*args, **kwargs) - return self.to(*args, device='cuda', **kwargs) + def cuda(self, device=None, non_blocking=False): + if device is not None: + return self.to(device=device, non_blocking=non_blocking) + return self.to(device='cuda:0', non_blocking=non_blocking) diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index 1905b7eb588a..a1fd5e338ef5 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -1,31 +1,14 @@ import contextlib import copy import pprint -from typing import Callable, List +from typing import Callable, List, Union import torch import torch.nn as nn -from torch.types import _bool, _device, _dtype from torch.utils._pytree import tree_map from colossalai.fx.profiler import MetaTensor from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec -from colossalai.utils.model.utils import substitute_init_recursively - -init = torch.nn.init - -# reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/parameter.py#L73 -_TorchTensorMethods = [ - torch.Tensor.half, - torch.Tensor.float, - torch.Tensor.double, - torch.Tensor.char, - torch.Tensor.short, - torch.Tensor.int, - torch.Tensor.long, - torch.Tensor.cuda, - torch.Tensor.cpu, -] # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _TorchFactoryFunc = [ @@ 
-41,19 +24,12 @@ "randint", "randperm", "zeros", + "tensor", ] -def init_from_spec(cls, t: torch.Tensor, spec: ColoTensorSpec): - if cls == ColoTensor: - return cls.from_torch_tensor(t, spec) - else: - return cls.from_torch_tensor(t, t.requires_grad, spec) - - class UninitializedTensor(torch.Tensor): - _cls_to_become = ColoTensor _repr = True @staticmethod @@ -67,7 +43,7 @@ def __new__(cls, func, *args, dtype=None, device=None, **kwargs): layout=elem.layout, device=device if device is not None else torch.device('cpu'), requires_grad=elem.requires_grad) - r._data = MetaTensor(elem, fake_device=device) + r._meta_data = MetaTensor(elem, fake_device=device) return r def __init__(self, func, *args, dtype=None, device=None, **kwargs): @@ -79,7 +55,7 @@ def __repr__(self): if self._repr: self.__class__._repr = False s = f'UninitializedTensor: {self._factory_fn}\n'\ - f'_data: {self._data}\n'\ + f'meta_data: {self._meta_data}\n'\ f'cached_fn: {pprint.pformat(self._cached_fn)}\n'\ f'spec: {self._spec}' self.__class__._repr = True @@ -87,7 +63,7 @@ def __repr__(self): else: return 'UninitializedTensor(...)' - def materialize(self): + def materialize(self) -> Union[ColoParameter, ColoTensor]: func, args, kwargs = self._factory_fn t = func(*args, **kwargs) @@ -95,22 +71,37 @@ def materialize(self): t = self._apply_cache(t) # apply spec - return init_from_spec(self._cls_to_become, t, self._spec) + if isinstance(self, nn.Parameter): + return ColoParameter.from_torch_tensor(t, t.requires_grad, self._spec) + else: + return ColoTensor.from_torch_tensor(t, self._spec) - # TODO(super-dainiu): device seems incorrect - def traceable(self): + def traceable(self) -> MetaTensor: func, args, kwargs = self._factory_fn t = MetaTensor(func(*args, **{**kwargs, 'device': 'meta'}), fake_device=kwargs['device']) - return self._apply_cache(t) + if isinstance(self, nn.Parameter): + return nn.Parameter(self._apply_cache(t), requires_grad=self.requires_grad) + else: + return self._apply_cache(t) - def _apply_cache(self, t): + def _apply_cache(self, t) -> torch.Tensor: # apply cached methods # super-dainiu: support methods for single Tensor only replace = lambda x: t if isinstance(x, UninitializedTensor) else x + packed = None for (func, args, kwargs) in self._cached_fn: - o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) - t = o if isinstance(o, torch.Tensor) else t # if func returns non-Tensor, discard the value + if func == torch.Tensor.requires_grad_: + packed = func, args, kwargs # requires grad should be set at last + else: + o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) + t = o if isinstance(o, torch.Tensor) else t # if func returns non-Tensor, discard the value + + # super-dainiu: set requires_grad after all inplace-ops are done + if packed is not None: + func, args, kwargs = packed + func(*tree_map(replace, args), **tree_map(replace, kwargs)) + return t # cache everything @@ -120,31 +111,32 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} t = None - if func in _TorchTensorMethods: + unwrap = lambda t: t._meta_data if isinstance(t, UninitializedTensor) else t + + if isinstance(func, torch._C.ScriptMethod): t: UninitializedTensor = args[0].clone() t._cached_fn.append((func, (t,) + args[1:], kwargs)) - t._data = func(t._data, *args[1:], **kwargs) - if isinstance(t._data, MetaTensor): - return t - else: - return t._data + t._meta_data = getattr(t._meta_data, func.name)(*tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs)) + + else: - def unwrap(t_): - 
nonlocal t - if isinstance(t_, UninitializedTensor): - t = t_.clone() - t._cached_fn.append((func, args, kwargs)) - t_ = t_._data - return t_ + def unwrap(t_): + nonlocal t + if isinstance(t_, UninitializedTensor): + t = t_ if (func.__name__.endswith('_') + or func.__name__ == "__set__") and not (func.__name__.endswith('__')) else t_.clone() + t._cached_fn.append((func, args, kwargs)) + t_ = t_._meta_data + return t_ - args = tree_map(unwrap, args) - kwargs = tree_map(unwrap, kwargs) - t._data = func(*args, **kwargs) + args = tree_map(unwrap, args) + kwargs = tree_map(unwrap, kwargs) + t._meta_data = func(*args, **kwargs) - if isinstance(t._data, MetaTensor): + if isinstance(t._meta_data, MetaTensor): return t else: - return t._data + return t._meta_data @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): @@ -153,7 +145,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def to(self, *args, **kwargs) -> "UninitializedTensor": t = self.clone() t._cached_fn.append((torch.Tensor.to, (t,) + args, kwargs)) - t._data = t._data.to(*args, **kwargs) + t._meta_data = t._meta_data.to(*args, **kwargs) return t def clone(self) -> "UninitializedTensor": @@ -171,42 +163,8 @@ def spec(self) -> ColoTensorSpec: def spec(self, other: ColoTensorSpec): self._spec = other - -# TODO: Not correct -class UninitializedParameter(UninitializedTensor, nn.Parameter): - - _cls_to_become = ColoParameter - - @staticmethod - def __new__(cls, elem=None, requires_grad=True): - if elem is None: - elem = UninitializedTensor(torch.empty, 0) - if type(elem) is UninitializedTensor or type(elem) is UninitializedParameter: - # For ease of BC maintenance, keep this path for standard Tensor. - # Eventually (tm), we should change the behavior for standard Tensor to match. - r = torch.Tensor._make_wrapper_subclass(cls, - elem.size(), - strides=elem.stride(), - storage_offset=elem.storage_offset(), - dtype=elem.dtype, - layout=elem.layout, - device=elem.device, - requires_grad=requires_grad) - r._data = elem._data - return r - raise RuntimeError(f"Creating an `UninitializedParameter` with `Tensor.subclasses` of " - f"`{type(elem).__name__}` is unexpected. Should be one of `UninitializedParameter`" - f", `UnintializedTensor`, or None.") - - def __init__(self, elem=None, requires_grad=True): - self._factory_fn = elem._factory_fn - self._cached_fn = elem._cached_fn - self._spec = elem._spec - - def clone(self): - return UninitializedParameter(super().clone(), requires_grad=self.requires_grad) - - __torch_function__ = torch._C._disabled_torch_function_impl + def detach(self): + return self.clone() class LazyInitContext(): @@ -247,13 +205,9 @@ def wrapper(*args, **kwargs): for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, wrapper) - # cannot monkey patch nn.Parameter because it is a class (????) - setattr(torch.nn.parameter, 'Parameter', UninitializedParameter) - def __exit__(self, exc_type, exc_val, exc_tb): for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, orig) - setattr(torch.nn.parameter, 'Parameter', self._orig_nn_param) @staticmethod def materialize(module: torch.nn.Module): @@ -310,8 +264,3 @@ def init_recursively(module: nn.Module): # restore original values for (module, name), val in orig_val.items(): setattr(module, name, val) - - # Things to hack: - # 1. torch.Tensor factory function (DONE) - # 2. nn.Parameter - # 3. 
init (DONE) From 42d72e1a854477c4b148759564c7bbb13c383691 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 23 Dec 2022 19:38:49 +0800 Subject: [PATCH 4/5] [utils] finalize. --- colossalai/utils/model/experimental.py | 324 +++++++++++++++++++------ 1 file changed, 249 insertions(+), 75 deletions(-) diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index a1fd5e338ef5..72d747b2d473 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -1,17 +1,20 @@ import contextlib import copy +import gc import pprint -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import torch import torch.nn as nn from torch.utils._pytree import tree_map +from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.profiler import MetaTensor -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html -_TorchFactoryFunc = [ +_TorchFactoryMethod = [ "arange", "empty", "eye", @@ -27,10 +30,52 @@ "tensor", ] - -class UninitializedTensor(torch.Tensor): +orig_empty = torch.empty # avoid override + +scm = ShapeConsistencyManager() + + +class LazyTensor(torch.Tensor): + """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf). + + Usage: + 1. Use ``LazyTensor`` instead of ``torch.Tensor``. + >>> x = LazyTensor(torch.zeros, 2, 3) + >>> x += 1 + >>> y = x * x + >>> y = y.cuda().half() + >>> y[0, 0] = 0 + >>> y = y.materialize() # materialize the tensor + >>> print(y) + tensor([[0., 1., 1.], + [1., 1., 1.]], device='cuda:0', dtype=torch.float16) + + 2. Generate ``MetaTensor`` from ``LazyTensor`` + >>> x = LazyTensor(torch.zeros, 2, 3) + >>> x.reshape(3, 2) + >>> x = x.traceable() # generate ``MetaTensor`` + >>> print(x) + MetaTensor(..., size=(3, 2), device=cpu, dtype=torch.float32) + + 3. Use ``LazyTensor`` to generate sharded ``nn.Parameter``. + >>> x = LazyTensor(torch.zeros, 2, 3) + >>> x.spec = ... # some ``ShardingSpec`` + >>> x.distribute() # distribute the tensor according to the ``ShardingSpec`` + + Warnings: + 1. Cases that ``LazyTensor`` can't deal with. + >>> x = LazyTensor(torch.ones, 2, 3) + >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion + + 2. ``LazyTensor.materialize()`` can't be called multiple times. 
+ >>> x = LazyTensor(torch.ones, 2, 3) + >>> x.materialize() + >>> x.materialize() # this is disallowed + """ _repr = True + _meta_data: Optional[MetaTensor] = None # shape, dtype, device + _cached_data: Optional[torch.Tensor] = None # materialized data @staticmethod def __new__(cls, func, *args, dtype=None, device=None, **kwargs): @@ -47,138 +92,243 @@ def __new__(cls, func, *args, dtype=None, device=None, **kwargs): return r def __init__(self, func, *args, dtype=None, device=None, **kwargs): - self._factory_fn = (func, args, {'dtype': dtype, 'device': device, **kwargs}) # (func, args, kwargs) - self._cached_fn = list() # (func, args, kwargs) - self._spec = ColoTensorSpec(pg=None, dist_attr=None, compute_attr=None) # Default Spec + self._factory_method = (func, args, {'dtype': dtype, 'device': device, **kwargs}) # (func, args, kwargs) + self._cached_buffer = list() # (func, args, kwargs) + self._spec = None + self._data = self def __repr__(self): if self._repr: + # avoid recursive representation self.__class__._repr = False - s = f'UninitializedTensor: {self._factory_fn}\n'\ - f'meta_data: {self._meta_data}\n'\ - f'cached_fn: {pprint.pformat(self._cached_fn)}\n'\ + s = f'LazyTensor(..., size={tuple(self._meta_data.shape)}, device={self._meta_data.device}, dtype={self._meta_data.dtype})\n'\ + f'factory method: {self._factory_method}\n'\ + f'cached: {pprint.pformat(self._cached_buffer) if self._cached_data is None else self._cached_data}\n'\ f'spec: {self._spec}' self.__class__._repr = True return s else: - return 'UninitializedTensor(...)' + return 'LazyTensor(...)' - def materialize(self) -> Union[ColoParameter, ColoTensor]: - func, args, kwargs = self._factory_fn - t = func(*args, **kwargs) + def materialize(self) -> torch.Tensor: + """Materialize the ``LazyTensor`` to ``torch.Tensor``. - # apply cached_fn - t = self._apply_cache(t) + Warnings: + Calling ``self.materialize()`` will clear all cached sequence and factory method, + because we don't allow materialize the same ``LazyTensor`` twice. + This is mentioned in the paper: https://arxiv.org/pdf/2102.13267.pdf (Part 4.3). - # apply spec + Returns: + torch.Tensor: The materialized tensor. + """ + target = self._data._realize_cached_data() if isinstance(self, nn.Parameter): - return ColoParameter.from_torch_tensor(t, t.requires_grad, self._spec) - else: - return ColoTensor.from_torch_tensor(t, self._spec) + target = nn.Parameter(target, requires_grad=self.requires_grad) + self._clear_all() + return target def traceable(self) -> MetaTensor: - func, args, kwargs = self._factory_fn - t = MetaTensor(func(*args, **{**kwargs, 'device': 'meta'}), fake_device=kwargs['device']) + """Generate ``MetaTensor`` from ``LazyTensor``. (Mostly for tracing) + + Returns: + MetaTensor: The generated ``MetaTensor``. + """ if isinstance(self, nn.Parameter): - return nn.Parameter(self._apply_cache(t), requires_grad=self.requires_grad) + return nn.Parameter(self._meta_data, requires_grad=self.requires_grad) else: - return self._apply_cache(t) + return self._meta_data + + def distribute(self) -> torch.Tensor: + """Distribute the ``LazyTensor`` according to the ``ShardingSpec``. + + Returns: + torch.Tensor: The sharded tensor. 
+ """ + if self._spec is None: + raise RuntimeError('ShardingSpec is not set for\n{self}') + spec, device_mesh = self._spec, self._spec.device_mesh + t = self.materialize() + + # TODO(some man): better not be coupled with auto-parallel + t.data = scm.apply_for_autoparallel_runtime(t.data, ShardingSpec(device_mesh, t.shape, {}), + spec).detach().clone() + return t + + def _realize_cached_data(self) -> torch.Tensor: + # self._cached_data should be generated after the first call of this function + if self._cached_data is None: + if self._factory_method is not None: + # apply factory method + func, args, kwargs = self._factory_method + + # apply cached sequence + self._cached_data = self._apply_cache_buffer(func(*args, **kwargs)) + else: + # apply cached sequence only + self._cached_data = self._apply_cache_buffer() + return self._cached_data - def _apply_cache(self, t) -> torch.Tensor: - # apply cached methods + def _apply_cache_buffer(self, target=None) -> torch.Tensor: + # dump all cached sequence # super-dainiu: support methods for single Tensor only - replace = lambda x: t if isinstance(x, UninitializedTensor) else x + def replace(x): + if x is self: + return target + elif isinstance(x, LazyTensor): + return x._realize_cached_data() + return x + packed = None - for (func, args, kwargs) in self._cached_fn: + for (func, args, kwargs) in self._cached_buffer: if func == torch.Tensor.requires_grad_: packed = func, args, kwargs # requires grad should be set at last else: o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) - t = o if isinstance(o, torch.Tensor) else t # if func returns non-Tensor, discard the value + target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value # super-dainiu: set requires_grad after all inplace-ops are done if packed is not None: func, args, kwargs = packed func(*tree_map(replace, args), **tree_map(replace, kwargs)) - return t + return target - # cache everything + # clear all means: + # 1. clear factory method + # 2. clear cached sequence + # 3. 
clear cached data + def _clear_all(self): + self._cached_data = None + self._cached_buffer = None + self._data = None + gc.collect() # avoid memory leak + + # cache everything with __torch_function__ @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - t = None - - unwrap = lambda t: t._meta_data if isinstance(t, UninitializedTensor) else t + target = None if isinstance(func, torch._C.ScriptMethod): - t: UninitializedTensor = args[0].clone() - t._cached_fn.append((func, (t,) + args[1:], kwargs)) - t._meta_data = getattr(t._meta_data, func.name)(*tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs)) + + def unwrap(x): + if isinstance(x, LazyTensor): + return x._meta_data + return x + + target: LazyTensor = args[0].clone() + target._cached_buffer.append((func, args, kwargs)) + target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]), + **tree_map(unwrap, kwargs)) else: - def unwrap(t_): - nonlocal t - if isinstance(t_, UninitializedTensor): - t = t_ if (func.__name__.endswith('_') - or func.__name__ == "__set__") and not (func.__name__.endswith('__')) else t_.clone() - t._cached_fn.append((func, args, kwargs)) - t_ = t_._meta_data - return t_ + def unwrap(x): + nonlocal target + if isinstance(x, LazyTensor): + target = x if (func.__name__.endswith('_') and not (func.__name__.endswith('__')) + or func.__name__ == "__setitem__") else x.clone() + target._cached_buffer.append((func, args, kwargs)) + return x._meta_data + return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - t._meta_data = func(*args, **kwargs) + o = func(*args, **kwargs) - if isinstance(t._meta_data, MetaTensor): - return t + if isinstance(o, MetaTensor): + target._meta_data = o + return target else: - return t._meta_data + return o @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - pass + pass # skip - def to(self, *args, **kwargs) -> "UninitializedTensor": - t = self.clone() - t._cached_fn.append((torch.Tensor.to, (t,) + args, kwargs)) - t._meta_data = t._meta_data.to(*args, **kwargs) - return t + def clone(self) -> "LazyTensor": + """Create a new ``LazyTensor`` with same cached sequence and factory method. - def clone(self) -> "UninitializedTensor": - func, args, kwargs = self._factory_fn - t = UninitializedTensor(func, *args, **kwargs) - t._cached_fn = [x for x in self._cached_fn] - t._spec = copy.deepcopy(self._spec) - return t + Returns: + LazyTensor: the new ``LazyTensor`` + """ + target = LazyTensor(orig_empty, 0, dtype=self._meta_data.dtype, device=self._meta_data.device) + target._factory_method = None + target._cached_buffer = list() + target._meta_data = self._meta_data.clone() + target._cached_data = self._cached_data.clone() if self._cached_data is not None else None + target._spec = copy.deepcopy(self._spec) + return target + + def detach(self) -> "LazyTensor": + target = self.clone() + target._cached_buffer.append((torch.Tensor.detach_, (self,), {})) + return target @property - def spec(self) -> ColoTensorSpec: + def spec(self) -> ShardingSpec: return self._spec @spec.setter - def spec(self, other: ColoTensorSpec): + def spec(self, other: ShardingSpec): self._spec = other - def detach(self): - return self.clone() + @property + def data(self) -> "LazyTensor": + return self._data.detach() + + @data.setter + def data(self, other: "LazyTensor") -> "LazyTensor": + """This avoid the following infinite recursion, which is very common in ``nn.Module`` initialization. 
+ + Usage: + >>> a = LazyTensor(torch.empty, 0, dtype=torch.float32, device='cpu') + >>> b = a.cuda() + >>> a.data = b + """ + self._data = other class LazyInitContext(): + """Context manager for lazy initialization. Enables initializing the model without allocating real memory. + + Usage: + 1. The model is initialized, but no real memory is allocated. + >>> ctx = LazyInitContext() + >>> with ctx: + >>> model = MyModel().cuda() + + 2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated. + >>> with ctx.traceable(model): + >>> gm = symbolic_trace(model, meta_args=meta_args) + >>> # Solve the execution strategy and apply the strategy to the model + >>> strategy = StrategyAndSpec() + + 3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device) + >>> model = ctx.materialize(model) + + 3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario) + >>> model = apply_strategy_to_all_params(model, strategy) + >>> model = ctx.distribute(model) + + Warnings: + This API is still experimental and further modifications can be made to it. + For example: + 1. Quantization strategies can be applied before allocating real memory. + 2. Lazy initialization seems slower than normal initialization. + """ def __init__(self): self.overrides = {} - self._orig_nn_param = torch.nn.Parameter def __enter__(self): def wrap_factory_method(target): # factory functions (eg. torch.empty()) def wrapper(*args, **kwargs): - return UninitializedTensor(target, *args, **kwargs) + return LazyTensor(target, *args, **kwargs) return wrapper, target @@ -186,19 +336,19 @@ def wrap_factory_like_method(orig_target, target): # factory_like functions (eg. torch.empty_like()) def wrapper(*args, **kwargs): orig_t = args[0] - return UninitializedTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) + return LazyTensor(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) return wrapper, target self.overrides = { target: wrap_factory_method(getattr(torch, target)) - for target in _TorchFactoryFunc + for target in _TorchFactoryMethod if callable(getattr(torch, target, None)) } self.overrides.update({ target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) - for target in _TorchFactoryFunc + for target in _TorchFactoryMethod if callable(getattr(torch, target + '_like', None)) }) @@ -211,10 +361,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): @staticmethod def materialize(module: torch.nn.Module): - """Materialize and shard ``nn.Module`` -- Initialize all ``nn.Parameter`` as ``ColoParameter`` + """Initialize all ``nn.Parameter`` from ``LazyTensor``. Args: - module (torch.nn.Module): LazyInit Module + module (torch.nn.Module): Target ``nn.Module`` """ @torch.no_grad() @@ -233,13 +383,37 @@ def init_recursively(module: nn.Module): init_recursively(module) return module + @staticmethod + def distribute(module: torch.nn.Module): + """Initialize and shard all ``nn.Parameter`` from ``LazyTensor``. 
+ + Args: + module (torch.nn.Module): Sharded target ``nn.Module`` + """ + + @torch.no_grad() + def init_recursively(module: nn.Module): + # recursively initialize the module + for mod in module.children(): + init_recursively(mod) + + # initialize tensors directly attached to the current module + for name, param in module.named_parameters(recurse=False): + setattr(module, name, param.distribute()) + + for name, buf in module.named_buffers(recurse=False): + setattr(module, name, buf.distribute()) + + init_recursively(module) + return module + @staticmethod @contextlib.contextmanager def traceable(module: torch.nn.Module): - """Enable ``ColoTracer`` -- Initialize all ``nn.Parameters`` as ``MetaTensor`` + """Initialize all ``nn.Parameters`` as ``MetaTensor``. This enables ``ColoTracer`` with control flow. Args: - module (torch.nn.Module): LazyInit Module + module (torch.nn.Module): Traceable ``nn.Module`` with ``MetaTensor`` as parameters. """ orig_val = dict() From f364a66f9780d0f991f44f464e3f9619b6a6eb11 Mon Sep 17 00:00:00 2001 From: dainiu <19307110036@fudan.edu.cn> Date: Fri, 23 Dec 2022 20:10:45 +0800 Subject: [PATCH 5/5] [utils] fix names. --- colossalai/utils/model/experimental.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py index 72d747b2d473..8291227b7ba2 100644 --- a/colossalai/utils/model/experimental.py +++ b/colossalai/utils/model/experimental.py @@ -147,12 +147,12 @@ def distribute(self) -> torch.Tensor: if self._spec is None: raise RuntimeError('ShardingSpec is not set for\n{self}') spec, device_mesh = self._spec, self._spec.device_mesh - t = self.materialize() + target = self.materialize() # TODO(some man): better not be coupled with auto-parallel - t.data = scm.apply_for_autoparallel_runtime(t.data, ShardingSpec(device_mesh, t.shape, {}), - spec).detach().clone() - return t + target.data = scm.apply_for_autoparallel_runtime(target.data, ShardingSpec(device_mesh, target.shape, {}), + spec).detach().clone() + return target def _realize_cached_data(self) -> torch.Tensor: # self._cached_data should be generated after the first call of this function
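
The sketch below shows how the lazy-initialization API introduced by this series (colossalai/utils/model/experimental.py) is intended to be exercised end to end, following the usage notes in the ``LazyInitContext`` docstring added in patch 4. It is an illustrative example, not part of the patches: ``MyModel`` and its ``fc`` layer are stand-ins for any ``nn.Module``, and only the single-device ``materialize`` path is shown; the ``distribute`` path additionally requires a ``ShardingSpec`` to be attached to each parameter beforehand.

    import torch
    import torch.nn as nn

    from colossalai.utils.model.experimental import LazyInitContext


    class MyModel(nn.Module):    # placeholder module, not part of the patch

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 10)

        def forward(self, x):
            return self.fc(x)


    ctx = LazyInitContext()
    with ctx:
        # factory calls such as torch.empty() made inside nn.Linear now return
        # LazyTensor, so no real weight memory is allocated at this point
        model = MyModel()

    # replay the cached initialization and allocate real memory on a single device
    model = LazyInitContext.materialize(model)
    out = model(torch.randn(2, 10))

For tracing instead of materialization, the docstring suggests wrapping the symbolic trace in ``with ctx.traceable(model):`` so that parameters are temporarily swapped for ``MetaTensor`` and restored afterwards.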
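
To make the replay mechanism behind ``_realize_cached_data()`` / ``_apply_cache_buffer()`` easier to follow, here is a minimal, framework-free sketch of the same idea: record ``(func, args, kwargs)`` instead of executing eagerly, then replay the recorded sequence once ``materialize()`` is called. ``TinyLazy`` is purely illustrative and omits everything the real ``LazyTensor`` handles via ``__torch_function__`` (automatic interception, in-place-op aliasing, ``requires_grad_`` ordering, sharding specs).

    import torch


    class TinyLazy:

        def __init__(self, factory, *args, **kwargs):
            self._factory = (factory, args, kwargs)
            self._cached = []    # list of (func, args, kwargs) to replay later

        def record(self, func, *args, **kwargs):
            # cache the call instead of executing it; stay lazy
            self._cached.append((func, args, kwargs))
            return self

        def materialize(self):
            factory, args, kwargs = self._factory
            value = factory(*args, **kwargs)
            for func, f_args, f_kwargs in self._cached:
                # substitute the realized tensor for the lazy placeholder,
                # analogous to the replace() closure in _apply_cache_buffer()
                f_args = tuple(value if a is self else a for a in f_args)
                out = func(*f_args, **f_kwargs)
                # if the cached call returns a non-Tensor, keep the old value
                value = out if isinstance(out, torch.Tensor) else value
            return value


    x = TinyLazy(torch.zeros, 2, 3)
    x.record(torch.Tensor.add_, x, 1)    # cached, not executed yet
    x.record(torch.Tensor.mul_, x, 4)
    print(x.materialize())               # tensor of 4.0 with shape (2, 3)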