From b789fc70812a33b5e8955a61a10c69755d78bb8a Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 19 Sep 2023 19:19:41 +0800
Subject: [PATCH 1/5] [lazy] support _like methods and clamp

---
 .isort.cfg                      |   1 +
 colossalai/lazy/construction.py |  84 +++++++++++++++++++++
 colossalai/lazy/lazy_init.py    | 126 ++++++++++++++++++++++++--------
 tests/test_lazy/test_models.py  |   8 +-
 4 files changed, 182 insertions(+), 37 deletions(-)
 create mode 100644 colossalai/lazy/construction.py

diff --git a/.isort.cfg b/.isort.cfg
index 4f881c8b3dda..ccbf575fdbfa 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -4,3 +4,4 @@ multi_line_output=3
 include_trailing_comma = true
 ignore_comments = true
 profile = black
+honor_noqa = true
diff --git a/colossalai/lazy/construction.py b/colossalai/lazy/construction.py
new file mode 100644
index 000000000000..c95eeb195726
--- /dev/null
+++ b/colossalai/lazy/construction.py
@@ -0,0 +1,84 @@
+from contextlib import contextmanager
+from typing import Callable, Dict, Tuple
+
+import torch
+
+__all__ = [
+    "_LEGACY_TENSOR_CONSTRUCTOR",
+    "_NO_META_FACTORY",
+    "_NORMAL_FACTORY",
+    "ConstructorManager",
+]
+
+# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
+_NORMAL_FACTORY = [
+    "arange",
+    "full",
+    "empty",
+    "linspace",
+    "logspace",
+    "ones",
+    "rand",
+    "randn",
+    "randint",
+    "randperm",
+    "zeros",
+    "tensor",
+]
+
+# factory functions that do not support the meta tensor backend
+_NO_META_FACTORY = [
+    "eye",
+]
+
+_LEGACY_TENSOR_CONSTRUCTOR = {
+    "FloatTensor": torch.float,
+    "DoubleTensor": torch.double,
+    "HalfTensor": torch.half,
+    "BFloat16Tensor": torch.bfloat16,
+    "ByteTensor": torch.uint8,
+    "CharTensor": torch.int8,
+    "ShortTensor": torch.short,
+    "IntTensor": torch.int,
+    "LongTensor": torch.long,
+    "BoolTensor": torch.bool,
+}
+
+
+class ConstructorManager:
+    # function name: (new, old)
+    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
+    changed: bool = False
+
+    @staticmethod
+    def apply(overwrites: Dict[Callable, Callable]):
+        ConstructorManager.overwrites.clear()
+        ConstructorManager.overwrites.update(overwrites)
+        ConstructorManager.redo()
+
+    @staticmethod
+    def undo():
+        assert ConstructorManager.changed, "No constructor change to undo"
+        for name, (new, old) in ConstructorManager.overwrites.items():
+            setattr(torch, name, old)
+        ConstructorManager.changed = False
+
+    @staticmethod
+    def redo():
+        assert not ConstructorManager.changed, "Constructor already changed"
+        for name, (new, old) in ConstructorManager.overwrites.items():
+            setattr(torch, name, new)
+        ConstructorManager.changed = True
+
+    @staticmethod
+    @contextmanager
+    def disable():
+        ConstructorManager.undo()
+        yield
+        ConstructorManager.redo()
+
+    @staticmethod
+    def clear():
+        if ConstructorManager.changed:
+            ConstructorManager.undo()
+        ConstructorManager.overwrites.clear()
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index ebaf2e1600fc..5f5329adf383 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -4,15 +4,19 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from packaging import version
 from torch import Tensor
 from torch.nn import Parameter
 from torch.utils._pytree import tree_map
 
-from colossalai._analyzer._subclasses import MetaTensor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.d_tensor import distribute_tensor
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 
+from .construction import ConstructorManager
+
+import colossalai._analyzer._subclasses._meta_registration  # noqa
+
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
     "arange",
     "full",
     "empty",
     "linspace",
     "logspace",
     "ones",
     "rand",
     "randn",
     "randint",
     "randperm",
     "zeros",
     "tensor",
 ]
@@ -54,7 +58,21 @@
     "BoolTensor": torch.bool,
 }
 
-_EMPTY_DATA = torch.empty(0)
+# These ops take at least one lazy tensor argument and may also take a scalar argument;
+# the scalar value should be converted to a meta tensor
+# this is a hack for torch 2.0
+_EXPAND_SCALAR_OPS = [
+    "where",
+    "clamp",
+    "clamp_min",
+    "clamp_max",
+    "clamp_",
+    "clamp_min_",
+    "clamp_max_",
+]
+_old_tensor_factory = torch.tensor
+
+_EMPTY_DATA = torch.empty(1)
 
 
 class _MyTensor(Tensor):
@@ -145,34 +163,49 @@ class LazyTensor(torch.Tensor):
     """
 
     _repr = True
-    _meta_data: Optional[MetaTensor] = None  # shape, dtype, device
+    _meta_data: Optional[torch.Tensor] = None  # shape, dtype, device
     _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
 
     default_device: Optional[torch.device] = None
+    _device: torch.device  # fake device of meta tensor
 
     @staticmethod
    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
+        # tips for torch 2.0:
+        # torch 2.0 disables torch dispatch for subclasses of tensor,
+        # so MetaTensor cannot be used;
+        # now a lazy tensor contains device injection and a meta tensor
         if concrete_data is not None:
             # some ops don't support meta backend and should have concrete data
             elem = concrete_data
         else:
             if meta_data is None:
-                device = kwargs.get("device", "cpu")
-                elem = func(*args, **{**kwargs, "device": "meta"})
-                meta_data = MetaTensor(elem, device=device)
-                elem = meta_data._tensor
+                with ConstructorManager.disable():
+                    meta_data = func(*args, **{**kwargs, "device": "meta"})
+            elem = meta_data
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
         r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
         r._meta_data = meta_data
+        # TODO: test this
+        # if isinstance(r, nn.Parameter):
+        #     r = nn.Parameter(r)
         return r
 
     def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
+        self._device = torch.device(kwargs.get("device", None) or "cpu")
         if func.__name__ in _NORMAL_FACTORY:
             kwargs = {**kwargs, "device": LazyTensor.default_device}
         self._factory_method = (func, args, kwargs)  # (func, args, kwargs)
         self._op_buffer = []  # (func, args, kwargs, replace)
         self._materialized_data: Optional[torch.Tensor] = concrete_data  # materialized data
 
+    @property
+    def device(self) -> torch.device:
+        return self._materialized_data.device if self._materialized_data is not None else self._device
+
+    def __repr__(self):
+        return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
+
     def materialize(self) -> torch.Tensor:
         """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
@@ -303,22 +336,25 @@ def unwrap(x):
                 meta = x._meta_data if is_change_meta_op else x._meta_data.data
                 meta_to_lazy[meta] = t
                 return meta
+            elif version.parse(torch.__version__) >= version.parse("2.0.0") and func.__name__ in _EXPAND_SCALAR_OPS:
+                return _old_tensor_factory(x, device="meta")
             return x
 
         def wrap(y, i=None):
-            if isinstance(y, MetaTensor):
-                if y in meta_to_lazy:
-                    # inplace op, just return origin lazy tensor
-                    return meta_to_lazy[y]
+            if isinstance(y, torch.Tensor):
+                if y.is_meta:
+                    if y in meta_to_lazy:
+                        # inplace op, just return origin lazy tensor
+                        return meta_to_lazy[y]
+                    else:
+                        # out of place op, create new lazy tensor
+                        fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+                        fn.__name__ = func.__name__
+                        lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
+                        return lazy_y
                 else:
-                    # out of place op, create new lazy tensor
-                    fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
-                    fn.__name__ = func.__name__
-                    lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
-                    return lazy_y
-            elif type(y) is Tensor:
-                # for early materialized tensor
-                return LazyTensor(lambda: None, concrete_data=y)
+                    # for early materialized tensor
+                    return LazyTensor(lambda: None, concrete_data=y)
             return y
 
         cls._pre_op_fn()
@@ -327,9 +363,36 @@ def wrap(y, i=None):
             return type(o)(wrap(y, i=i) for i, y in enumerate(o))
         return wrap(o)
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        pass  # skip
+    def to(self, *args, **kwargs) -> torch.Tensor:
+        if self._materialized_data is not None:
+            return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs))
+
+        device = None
+
+        def replace(x):
+            nonlocal device
+            if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool):
+                device = x
+                return torch.device("meta")
+            return x
+
+        meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))
+
+        def factory_fn(*a, **kw):
+            new_tensor = self.materialize() if type(self) is LazyTensor else self
+            return new_tensor.to(*args, **kwargs)
+
+        return LazyTensor(factory_fn, meta_data=meta_data, device=device)
+
+    def cpu(self, *args, **kwargs):
+        if self.device.type == "cpu":
+            return self.to(*args, **kwargs)
+        return self.to(*args, device="cpu", **kwargs)
+
+    def cuda(self, device=None, non_blocking=False):
+        if device is not None:
+            return self.to(device=device, non_blocking=non_blocking)
+        return self.to(device="cuda:0", non_blocking=non_blocking)
 
     def clone(self) -> "LazyTensor":
         def factory_fn():
@@ -455,7 +518,6 @@ def __init__(
         default_device: Optional[Union[torch.device, str, int]] = None,
     ):
         assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
-        self.overrides = {}
         self.tensor_cls = tensor_cls
         self.old_default_device = LazyTensor.default_device
         self.default_device = default_device
@@ -478,7 +540,9 @@ def wrap_factory_like_method(orig_target, target):
             # factory_like functions (eg. torch.empty_like())
             def wrapper(*args, **kwargs):
                 orig_t = args[0]
-                return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
+                return self.tensor_cls(
+                    orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
+                )
 
             return wrapper, target
 
@@ -513,13 +577,13 @@ def wrapper(*args, **kwargs):
 
             return wrapper, target
 
-        self.overrides = {
+        overrides = {
             target: wrap_factory_method(getattr(torch, target))
             for target in _NORMAL_FACTORY
             if callable(getattr(torch, target, None))
         }
 
-        self.overrides.update(
+        overrides.update(
            {
                 target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
                 for target in _NORMAL_FACTORY
             }
         )
 
-        self.overrides.update(
+        overrides.update(
             {
                 target: wrap_legacy_constructor(getattr(torch, target), dtype)
                 for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
             }
         )
 
-        self.overrides.update(
+        overrides.update(
             {
                 target: wrap_no_meta_factory(getattr(torch, target))
                 for target in _NO_META_FACTORY
             }
         )
 
-        for name, (wrapper, orig) in self.overrides.items():
-            setattr(torch, name, wrapper)
+        ConstructorManager.apply(overrides)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.tensor_cls.default_device = self.old_default_device
         LazyInitContext._replaced = False
-        for name, (wrapper, orig) in self.overrides.items():
-            setattr(torch, name, orig)
+        ConstructorManager.clear()
 
     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py
index 978cf06b55a0..a1b5763d4cd8 100644
--- a/tests/test_lazy/test_models.py
+++ b/tests/test_lazy/test_models.py
@@ -11,14 +11,12 @@ def test_torchvision_models_lazy_init(subset, default_device):
     sub_model_zoo = model_zoo.get_sub_registry(subset)
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
-        if (
-            name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base")
-            or name.startswith("transformers_llama")
-            or name.startswith(("transformers_vit", "transformers_blip2"))
+        if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
+            ("transformers_vit", "transformers_blip2")
         ):
             continue
         check_lazy_init(entry, verbose=True, default_device=default_device)
 
 
 if __name__ == "__main__":
-    test_torchvision_models_lazy_init("torchvision")
+    test_torchvision_models_lazy_init("transformers", "cpu")
From a520d781b530de3c017c0168231bc1806a849bd3 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 20 Sep 2023 11:07:13 +0800
Subject: [PATCH 2/5] [lazy] pass transformers models

---
 colossalai/lazy/construction.py |  7 +++++--
 colossalai/lazy/lazy_init.py    | 13 ++++++++++---
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/colossalai/lazy/construction.py b/colossalai/lazy/construction.py
index c95eeb195726..6764eaf774ab 100644
--- a/colossalai/lazy/construction.py
+++ b/colossalai/lazy/construction.py
@@ -73,9 +73,12 @@ def redo():
     @staticmethod
     @contextmanager
     def disable():
-        ConstructorManager.undo()
+        enabled = ConstructorManager.changed
+        if enabled:
+            ConstructorManager.undo()
         yield
-        ConstructorManager.redo()
+        if enabled:
+            ConstructorManager.redo()
 
     @staticmethod
     def clear():
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 5f5329adf383..cccc0583f35f 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -72,7 +72,7 @@
 ]
 _old_tensor_factory = torch.tensor
 
-_EMPTY_DATA = torch.empty(1)
+_EMPTY_DATA = torch.empty(0)
 
 
 class _MyTensor(Tensor):
@@ -181,6 +181,7 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
         else:
             if meta_data is None:
                 with ConstructorManager.disable():
+                    # disable lazy tensor creation in inner ops; this is a hack for torch 2.0
                     meta_data = func(*args, **{**kwargs, "device": "meta"})
             elem = meta_data
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
@@ -336,7 +337,11 @@ def unwrap(x):
                 meta = x._meta_data if is_change_meta_op else x._meta_data.data
                 meta_to_lazy[meta] = t
                 return meta
-            elif version.parse(torch.__version__) >= version.parse("2.0.0") and func.__name__ in _EXPAND_SCALAR_OPS:
+            elif (
+                version.parse(torch.__version__) >= version.parse("2.0.0")
+                and func.__name__ in _EXPAND_SCALAR_OPS
+                and not isinstance(x, torch.Tensor)
+            ):
                 return _old_tensor_factory(x, device="meta")
             return x
 
@@ -358,7 +363,9 @@ def wrap(y, i=None):
             return y
 
         cls._pre_op_fn()
-        o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+        with ConstructorManager.disable():
+            # disable lazy tensor creation in inner ops; this is a hack for torch 2.0
+            o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
         if isinstance(o, (tuple, list)):
             return type(o)(wrap(y, i=i) for i, y in enumerate(o))
         return wrap(o)
From 041a9763060682b817daf179c2ce063e35eeb52f Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 20 Sep 2023 15:47:36 +0800
Subject: [PATCH 3/5] [lazy] fix device move and requires grad

---
 colossalai/lazy/lazy_init.py | 64 +++++++++++++++++-------------------
 tests/test_lazy/test_ops.py  | 64 ++++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 33 deletions(-)
 create mode 100644 tests/test_lazy/test_ops.py

diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index cccc0583f35f..0041823d9373 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -45,6 +45,9 @@
 # These ops cannot be unwrapped using .data
 _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
 
+# These ops are not related to tensor values and should not be rerun
+_NO_RERUN_OPS = ["requires_grad_", "__get__", "numel", "size", "dim"]
+
 _LEGACY_TENSOR_CONSTRUCTOR = {
     "FloatTensor": torch.float,
     "DoubleTensor": torch.double,
@@ -103,7 +106,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
     return tensor.data.tolist()
 
 
-def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
+def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad: bool) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.
     The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters,
     which is common in huggingface models.
@@ -122,7 +125,7 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
         # to fit UninitializedParameter
         delattr(tensor, "_is_param")
     tensor.data = target
-    tensor.requires_grad = target.requires_grad
+    tensor.requires_grad = requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
     tensor.tolist = MethodType(_data_tolist, tensor)
@@ -213,9 +216,11 @@ def materialize(self) -> torch.Tensor:
         Returns:
             torch.Tensor: The materialized tensor (self).
         """
+        # requires_grad attr is mounted on meta tensor, it should be copied back
+        requires_grad = self.requires_grad
         target = self._materialize_data()
         self.clean()
-        return _convert_cls(self, target)
+        return _convert_cls(self, target, requires_grad)
 
     def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
         """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
@@ -276,12 +281,9 @@ def replace(x):
         packed = None
 
         for func, args, kwargs in self._op_buffer:
-            if func == torch.Tensor.requires_grad_:
-                packed = func, args, kwargs  # requires grad should be set at last
-            else:
-                self._pre_op_fn()
-                o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
-                target = o if isinstance(o, torch.Tensor) else target  # if func returns non-Tensor, discard the value
+            self._pre_op_fn()
+            o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
+            target = o if isinstance(o, torch.Tensor) else target  # if func returns non-Tensor, discard the value
 
         # super-dainiu: set requires_grad after all inplace-ops are done
         if packed is not None:
@@ -333,7 +335,8 @@ def unwrap(x):
                     # for early materialized tensor, use its materialized data directly
                     return x._materialized_data if is_change_meta_op else x._materialized_data.data
                 t = x if is_inplace else x.clone()
-                t._op_buffer.append((func, args, kwargs))
+                if func.__name__ not in _NO_RERUN_OPS:
+                    t._op_buffer.append((func, args, kwargs))
                 meta = x._meta_data if is_change_meta_op else x._meta_data.data
                 meta_to_lazy[meta] = t
                 return meta
@@ -385,29 +388,27 @@ def replace(x):
 
         meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))
 
-        def factory_fn(*a, **kw):
-            new_tensor = self.materialize() if type(self) is LazyTensor else self
-            return new_tensor.to(*args, **kwargs)
+        if meta_data is self._meta_data and device == self.device:
+            return self
 
-        return LazyTensor(factory_fn, meta_data=meta_data, device=device)
+        def factory_fn(t: torch.Tensor, **kw):
+            return t.to(*args, **kwargs)
 
-    def cpu(self, *args, **kwargs):
-        if self.device.type == "cpu":
-            return self.to(*args, **kwargs)
-        return self.to(*args, device="cpu", **kwargs)
+        return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)
 
-    def cuda(self, device=None, non_blocking=False):
-        if device is not None:
-            return self.to(device=device, non_blocking=non_blocking)
-        return self.to(device="cuda:0", non_blocking=non_blocking)
+    def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
+        return self.to(device=torch.device("cpu"), memory_format=memory_format)
+
+    def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
+        device = torch.device(device or "cuda")
+        return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)
 
     def clone(self) -> "LazyTensor":
-        def factory_fn():
+        def factory_fn(t: torch.Tensor, **kw):
             # if self is materialized, return self
-            new_tensor = self.materialize() if type(self) is LazyTensor else self
-            return new_tensor.clone()
+            return t.clone()
 
-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
 
         return target
 
@@ -423,17 +424,16 @@ def __deepcopy__(self, memo):
         if id(self) in memo:
             return memo[id(self)]
 
-        def factory_fn():
+        def factory_fn(t: torch.Tensor, **kw):
             # if self is materialized, return self
-            new_tensor = self.materialize() if type(self) is LazyTensor else self
-            return _copy_tensor(new_tensor, new_tensor.requires_grad)
+            return _copy_tensor(t, t.requires_grad)
 
         if self._materialized_data is not None:
             # self is early materialized
             copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
-            target = LazyTensor(factory_fn, meta_data=self._meta_data)
+            target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
 
         if isinstance(self, Parameter):
             # hack isinstance check of parameter
@@ -464,14 +464,12 @@ def data(self, other: "LazyTensor"):
         if other is self:
             return
 
-        self._op_buffer.append(other._factory_method)
-
         def replace(x):
             if x is other:
                 return self
             return x
 
-        for func, args, kwargs in other._op_buffer:
+        for func, args, kwargs in [other._factory_method, *other._op_buffer]:
             self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
 
     def tolist(self) -> list:
diff --git a/tests/test_lazy/test_ops.py b/tests/test_lazy/test_ops.py
new file mode 100644
index 000000000000..e6b936198547
--- /dev/null
+++ b/tests/test_lazy/test_ops.py
@@ -0,0 +1,64 @@
+import copy
+
+import pytest
+import torch
+import torch.nn as nn
+from lazy_init_utils import SUPPORT_LAZY
+from torch.nn import Parameter
+
+from colossalai.lazy import LazyInitContext
+
+
+@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
+def test_lazy_ops():
+    with LazyInitContext():
+        x = torch.rand(2, 3)
+        assert tuple(x.shape) == (2, 3)
+        assert x.device.type == "cpu"
+        assert x.requires_grad is False
+        y = x.cuda()
+        assert tuple(y.shape) == (2, 3)
+        assert y.device.type == "cuda"
+        assert y.requires_grad is False
+        assert x.cpu() is x
+        p = Parameter(torch.empty(2, 3))
+        assert tuple(p.shape) == (2, 3)
+        assert p.device.type == "cpu"
+        assert p.requires_grad is True
+        assert isinstance(p, Parameter)
+    x.materialize()
+    assert tuple(x.shape) == (2, 3)
+    assert x.device.type == "cpu"
+    assert x.requires_grad is False
+    y.materialize()
+    assert tuple(y.shape) == (2, 3)
+    assert y.device.type == "cuda"
+    assert y.requires_grad is False
+    p.materialize()
+    assert tuple(p.shape) == (2, 3)
+    assert p.device.type == "cpu"
+    assert p.requires_grad is True
+    assert isinstance(p, Parameter)
+
+    with LazyInitContext():
+        x = torch.empty(2, 3)
+        x.uniform_()
+    x.materialize()
+    assert tuple(x.shape) == (2, 3)
+
+    with LazyInitContext():
+        model = nn.Linear(3, 4)
+        model = model.cuda()
+        model_copied = copy.deepcopy(model)
+    LazyInitContext.materialize(model)
+    assert model.weight.device.type == "cuda"
+    assert model.bias.device.type == "cuda"
+    LazyInitContext.materialize(model_copied)
+    assert model_copied.weight.device.type == "cuda"
+    assert model_copied.bias.device.type == "cuda"
+    assert torch.equal(model.weight, model_copied.weight)
+    assert torch.equal(model.bias, model_copied.bias)
+
+
+if __name__ == "__main__":
+    test_lazy_ops()
From 4feb9fa1a1060cc80921d45ddddc9058d008c51b Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 20 Sep 2023 16:00:54 +0800
Subject: [PATCH 4/5] [lazy] fix requires grad and refactor api

---
 colossalai/lazy/lazy_init.py | 76 +++++++++---------------------------
 1 file changed, 19 insertions(+), 57 deletions(-)

diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 0041823d9373..a2c346ec2ce6 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -1,17 +1,14 @@
 from types import MethodType
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from packaging import version
 from torch import Tensor
 from torch.nn import Parameter
 from torch.utils._pytree import tree_map
 
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor import distribute_tensor
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.logging import get_dist_logger
 
 from .construction import ConstructorManager
 
@@ -46,7 +43,7 @@
 # These ops cannot be unwrapped using .data
 _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
 
 # These ops are not related to tensor values and should not be rerun
-_NO_RERUN_OPS = ["requires_grad_", "__get__", "numel", "size", "dim"]
+_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]
 
 _LEGACY_TENSOR_CONSTRUCTOR = {
     "FloatTensor": torch.float,
     "DoubleTensor": torch.double,
@@ -106,7 +103,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
     return tensor.data.tolist()
 
 
-def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad: bool) -> torch.Tensor:
+def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.
     The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters,
     which is common in huggingface models.
@@ -125,7 +122,7 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad: bool
         # to fit UninitializedParameter
         delattr(tensor, "_is_param")
     tensor.data = target
-    tensor.requires_grad = requires_grad
+    tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
     tensor.tolist = MethodType(_data_tolist, tensor)
@@ -190,9 +187,7 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
         r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
         r._meta_data = meta_data
-        # TODO: test this
-        # if isinstance(r, nn.Parameter):
-        #     r = nn.Parameter(r)
+
         return r
 
     def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
@@ -216,25 +211,9 @@ def materialize(self) -> torch.Tensor:
         Returns:
             torch.Tensor: The materialized tensor (self).
         """
-        # requires_grad attr is mounted on meta tensor, it should be copied back
-        requires_grad = self.requires_grad
         target = self._materialize_data()
         self.clean()
-        return _convert_cls(self, target, requires_grad)
-
-    def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
-        """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
-
-        Args:
-            layout (Layout): Distribution layout.
-
-        Returns:
-            torch.Tensor: The distributed tensor (self).
-        """
-        target = self._materialize_data()
-        self.clean()
-        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
-        return _convert_cls(self, local_tensor)
+        return _convert_cls(self, target)
 
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking.
         This should be called after all tensors are materialized."""
@@ -281,9 +260,12 @@ def replace(x):
         packed = None
 
         for func, args, kwargs in self._op_buffer:
-            self._pre_op_fn()
-            o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
-            target = o if isinstance(o, torch.Tensor) else target  # if func returns non-Tensor, discard the value
+            if func.__name__ == "requires_grad_":
+                packed = (func, args, kwargs)
+            else:
+                self._pre_op_fn()
+                o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
+                target = o if isinstance(o, torch.Tensor) else target  # if func returns non-Tensor, discard the value
 
         # super-dainiu: set requires_grad after all inplace-ops are done
         if packed is not None:
@@ -633,23 +615,6 @@ def apply_fn(name: str, p: LazyTensor):
 
         return _apply_to_lazy_module(module, apply_fn, verbose)
 
-    @staticmethod
-    def distribute(
-        module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
-    ) -> nn.Module:
-        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
-
-        Args:
-            module (nn.Module): Target ``nn.Module``
-            layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
-            verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
-        """
-
-        def apply_fn(name: str, p: LazyTensor):
-            p.distribute(device_mesh, sharding_spec_dict[name])
-
-        return _apply_to_lazy_module(module, apply_fn, verbose)
-
 
 def _apply_to_lazy_module(
     module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
@@ -689,20 +654,17 @@ def _apply_to_lazy_module(
 
     if verbose:
         non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
-        _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
-        _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
-        _print_rank_0(
-            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
+        logger = get_dist_logger()
+        logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0])
+        logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0])
+        logger.info(
+            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%",
+            ranks=[0],
         )
 
     return module
 
 
-def _print_rank_0(*args, **kwargs):
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        print(*args, **kwargs)
-
-
 def _is_int_tuple(args) -> bool:
     if not isinstance(args, tuple):
         return False
From ea1e02492e2216964f1a5c4e10e0bb5a54c68359 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 20 Sep 2023 16:07:03 +0800
Subject: [PATCH 5/5] [lazy] fix requires grad

---
 colossalai/lazy/lazy_init.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index a2c346ec2ce6..f29e997da495 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -260,8 +260,8 @@ def replace(x):
         packed = None
 
         for func, args, kwargs in self._op_buffer:
-            if func.__name__ == "requires_grad_":
-                packed = (func, args, kwargs)
+            if func == torch.Tensor.requires_grad_:
+                packed = func, args, kwargs  # requires grad should be set at last
             else:
                 self._pre_op_fn()
                 o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
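
Editor's note — usage sketch, not part of the patch series: the snippet below illustrates the behavior the five patches converge on, adapted from the new tests/test_lazy/test_ops.py above. Device movements such as .cuda() issued under LazyInitContext are only recorded on LazyTensors and are replayed when the module is materialized, and requires_grad is preserved; the model and shapes here are arbitrary examples.

import torch.nn as nn

from colossalai.lazy import LazyInitContext

with LazyInitContext():
    model = nn.Linear(3, 4)  # parameters are LazyTensors; nothing is really allocated yet
    model = model.cuda()     # the device move is recorded on the lazy tensors, not executed

LazyInitContext.materialize(model)  # factory calls and recorded ops replay here
assert model.weight.device.type == "cuda"
assert model.weight.requires_grad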