colossalai/utils/model/experimental.py (231 changes: 155 additions & 76 deletions)
@@ -1,11 +1,15 @@
from typing import Callable, List, Optional, Union
from types import MethodType
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.utils._pytree import tree_map

from colossalai.fx.profiler.tensor import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@@ -30,6 +34,11 @@

_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']

# These ops cannot be unwrapped using .data. Quoting PyTorch's guidance for such metadata changes:
# "If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a
# `with torch.no_grad():` block."
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
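A minimal illustration (a sketch, not part of the patch) of why these ops must see the wrapper itself rather than a .data alias: .data returns a detached view, so a metadata change made through it never reaches the original tensor.

import torch

x = torch.zeros(2)
x.data.requires_grad_(True)    # flips the flag only on the detached alias
print(x.requires_grad)         # False: the original tensor never saw the change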

_LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float,
'DoubleTensor': torch.double,
@@ -43,6 +52,8 @@
'BoolTensor': torch.bool,
}

_EMPTY_DATA = torch.empty(0)


class _MyTensor(Tensor):
"""This class is only for correctness verification.
@@ -64,6 +75,29 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, kwargs)


def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.

We change the class of the lazy tensor in place because this handles shared modules/parameters gracefully, which are common in huggingface models.
If we instead created a new tensor and updated the module via ``setattr(module, name, param)``, shared parameters would fall out of sync, and we would have to track and update every shared parameter manually.

Args:
tensor (LazyTensor): the LazyTensor to be converted
target (torch.Tensor): target tensor

Returns:
torch.Tensor: the converted tensor
"""
cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
tensor.data = target
tensor.requires_grad = target.requires_grad
# subclasses of torch.Tensor do not have the tolist() method;
# overwrite it after materialization or distribution
tensor.tolist = MethodType(torch.Tensor.tolist, target)
return tensor
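
A short sketch of why the in-place conversion matters (illustrative only; the tied modules below are hypothetical, not from this diff):

import torch.nn as nn

emb = nn.Embedding(10, 4)
dec = nn.Linear(4, 10, bias=False)
dec.weight = emb.weight            # weight tying: two attributes, one Parameter object
assert dec.weight is emb.weight
# Mutating that single object in place, as _convert_cls does, updates both modules;
# rebinding one attribute via setattr would silently break the tie.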


class LazyTensor(torch.Tensor):
"""A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).

@@ -112,14 +146,8 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
elem = meta_data._tensor
r = torch.Tensor._make_wrapper_subclass(cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad)
# As a meta tensor's __class__ cannot be changed to torch.Tensor, we use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
r._meta_data = meta_data
return r

@@ -129,15 +157,28 @@ def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data

def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).

Returns:
torch.Tensor: The materialized tensor.
torch.Tensor: The materialized tensor (self).
"""
target = self._materialize_data()
if isinstance(self, nn.Parameter):
target = nn.Parameter(target, requires_grad=self.requires_grad)
return target
self.clean()
return _convert_cls(self, target)
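
A usage sketch (assuming a LazyTensor built through its factory path; not part of the patch):

x = LazyTensor(torch.randn, 2, 2)    # records the factory call; nothing is allocated yet
y = x.materialize()                  # replays torch.randn(2, 2) and converts x in place
assert y is x and type(x) is torch.Tensor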

def distribute(self, layout: Layout) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

Args:
layout (Layout): Distribution layout.

Returns:
torch.Tensor: The distributed tensor (self).
"""
target = self._materialize_data()
self.clean()
local_tensor = DTensor(target, layout).local_tensor
return _convert_cls(self, local_tensor)

def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
@@ -216,6 +257,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ == "__setitem__")

is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS

if isinstance(func, torch._C.ScriptMethod):
# FIXME(ver217): torch script functions are not verified

@@ -239,10 +282,10 @@ def unwrap(x):
if isinstance(x, LazyTensor):
if x._materialized_data is not None:
# for early materialized tensor, use its materialized data directly
return x._materialized_data.data
return x._materialized_data if is_change_meta_op else x._materialized_data.data
t = x if is_inplace else x.clone()
t._op_buffer.append((func, args, kwargs))
meta = x._meta_data.data
meta = x._meta_data if is_change_meta_op else x._meta_data.data
meta_to_lazy[meta] = t
return meta
return x
@@ -290,13 +333,36 @@ def data(self):

@data.setter
def data(self, other: 'LazyTensor'):
"""This is sightly different from oringinal `data` setter.

E.g.:
>>> a = torch.randn(3, 3) # a is a Tensor
>>> b = torch.rand(2, 2)
>>> a.data = b
>>> b.add_(1) # this will affect a
>>> x = torch.randn(3, 3) # x is a LazyTensor
>>> y = torch.rand(2, 2) # y is a LazyTensor
>>> x.data = y
>>> y.add_(1) # this will not affect x

"""
if other is self:
return
# TODO(ver217): to avoid infinite recursion, do early materialization
self._materialized_data = other._materialize_data()

self._op_buffer.append(other._factory_method)

def replace(x):
if x is other:
return self
return x

for func, args, kwargs in other._op_buffer:
self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))

def tolist(self) -> list:
t = self.materialize()
# Though self.__class__ is modified to torch.Tensor, on the C++ side it is still a subclass of torch.Tensor,
# and subclasses of torch.Tensor do not have the tolist() method
t = self._materialize_data()
return t.tolist()

def __hash__(self):
@@ -421,71 +487,84 @@ def __exit__(self, exc_type, exc_val, exc_tb):
setattr(torch, name, orig)

@staticmethod
def materialize(module: torch.nn.Module, verbose: bool = False):
"""Initialize all ``nn.Parameter`` from ``LazyTensor``.
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
"""Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

Args:
module (torch.nn.Module): Target ``nn.Module``
module (nn.Module): Target ``nn.Module``
verbose (bool): Whether to print lazy initialization rate. Defaults to False.
"""
if verbose:
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
non_lazy_numel = 0

# do post cleaning to handle shared parameter
visited_lazy_tensors: List[LazyTensor] = []
# handle shared module
visited_modules = set()

@torch.no_grad()
def init_recursively(module: nn.Module):
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
# recursively initialize the module
for mod in module.children():
if id(mod) not in visited_modules:
visited_modules.add(id(mod))
init_recursively(mod)

# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if verbose:
param_cnt += 1
if getattr(param, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
non_lazy_numel += param.numel()
if hasattr(param, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(param)
setattr(module, name, param.materialize())

for name, buf in module.named_buffers(recurse=False):
if verbose:
buf_cnt += 1
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
else:
non_lazy_numel += buf.numel()
if hasattr(buf, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(buf)
setattr(module, name, buf.materialize())

init_recursively(module)
def apply_fn(name: str, p: LazyTensor):
p.materialize()

return _apply_to_lazy_module(module, apply_fn, verbose)
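
An end-to-end sketch of the intended flow (hedged: it assumes LazyInitContext needs no constructor arguments, which this diff does not show):

with LazyInitContext():
    model = nn.Linear(8, 8)                   # parameters are created as LazyTensors
model = LazyInitContext.materialize(model)    # replays the recorded ops in place
assert isinstance(model.weight, nn.Parameter)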

@staticmethod
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

Args:
module (nn.Module): Target ``nn.Module``
layout_dict (dict): Dict of layouts for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
"""

def apply_fn(name: str, p: LazyTensor):
p.distribute(layout_dict[name])

return _apply_to_lazy_module(module, apply_fn, verbose)
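
A hedged sketch of building layout_dict (make_layout is a hypothetical helper; real Layout construction needs a device mesh and sharding spec, which this diff does not show). Keys must match the names yielded by named_parameters() and named_buffers():

layout_dict = {name: make_layout(name, p.shape)    # make_layout is hypothetical
               for name, p in model.named_parameters()}
model = LazyInitContext.distribute(model, layout_dict)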


for t in visited_lazy_tensors:
t.clean()
def _apply_to_lazy_module(module: nn.Module,
apply_fn: Callable[[str, torch.Tensor], None],
verbose: bool = False) -> nn.Module:
if verbose:
# verbose info
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
total_numel = 0
non_lazy_numel = 0

for name, p in module.named_parameters():
if verbose:
param_cnt += 1
total_numel += p.numel()
if getattr(p, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
non_lazy_numel += p.numel()
if isinstance(p, LazyTensor):
apply_fn(name, p)

for name, buf in module.named_buffers():
if verbose:
print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
return module
buf_cnt += 1
total_numel += buf.numel()
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
else:
non_lazy_numel += buf.numel()
if isinstance(buf, LazyTensor):
apply_fn(name, buf)

if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
_print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
_print_rank_0(
f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')

return module


def _print_rank_0(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)


def _is_int_tuple(args) -> bool: