165 changes: 133 additions & 32 deletions colossalai/utils/model/experimental.py
@@ -1,17 +1,16 @@
from typing import Callable, Optional, Union
from typing import Callable, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils._pytree import tree_map

from colossalai.fx.profiler import MetaTensor
from colossalai.fx.profiler.tensor import MetaTensor

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_TorchFactoryMethod = [
_NORMAL_FACTORY = [
"arange",
"empty",
"eye",
"full",
"linspace",
"logspace",
@@ -24,17 +23,39 @@
"tensor",
]

# factory functions that do not support the meta tensor backend
_NO_META_FACTORY = [
"eye",
]

_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']

_LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float,
'DoubleTensor': torch.double,
'HalfTensor': torch.half,
'BFloat16Tensor': torch.bfloat16,
'ByteTensor': torch.uint8,
'CharTensor': torch.int8,
'ShortTensor': torch.short,
'IntTensor': torch.int,
'LongTensor': torch.long,
'BoolTensor': torch.bool,
}


class _MyTensor(Tensor):
"""This class is only for correctness verification.
"""
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None

def __new__(cls, func, *args, dtype=None, device=None, **kwargs) -> '_MyTensor':
def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
cls._pre_op_fn()
data = func(*args, dtype=dtype, device=device, **kwargs)
if concrete_data is not None:
# keep the API uniform with LazyTensor
data = concrete_data
else:
data = func(*args, **kwargs)
return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)

@classmethod
@@ -66,11 +87,13 @@ class LazyTensor(torch.Tensor):
>>> x.add_(1) # modifying the origin tensor after cloning leads to wrong materialization
>>> z = x.tolist()
>>> x.zero_() # modifying the origin tensor after calling tolist() is not allowed
>>> x.data = torch.rand(2, 3) # directly setting the data of a lazy tensor is not allowed
>>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed


2. Cases that ``LazyTensor`` becomes eager (early materialization).
>>> b = a[:, 2:] # taking a slice of a lazy tensor triggers early materialization
>>> chunks = a.split(3) # this also triggers early materialization
>>> x.data = torch.rand(2, 3) # directly setting the data of a lazy tensor triggers early materialization

"""

@@ -79,12 +102,16 @@ class LazyTensor(torch.Tensor):
_pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None

@staticmethod
def __new__(cls, func, *args, meta_data=None, **kwargs):
if meta_data is None:
device = kwargs.get('device', 'cpu')
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
elem = meta_data._tensor
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
if concrete_data is not None:
# some ops don't support meta backend and should have concrete data
elem = concrete_data
else:
if meta_data is None:
device = kwargs.get('device', 'cpu')
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
elem = meta_data._tensor
r = torch.Tensor._make_wrapper_subclass(cls,
elem.size(),
strides=elem.stride(),
@@ -96,10 +123,10 @@ def __new__(cls, func, *args, meta_data=None, **kwargs):
r._meta_data = meta_data
return r

def __init__(self, func, *args, meta_data=None, **kwargs):
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = None # materialized data
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data

def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
@@ -212,7 +239,7 @@ def unwrap(x):
if isinstance(x, LazyTensor):
if x._materialized_data is not None:
# for early materialized tensor, use its materialized data directly
return x._materialized_data
return x._materialized_data.data
t = x if is_inplace else x.clone()
t._op_buffer.append((func, args, kwargs))
meta = x._meta_data.data
@@ -232,13 +259,10 @@ def wrap(y, i=None):
return lazy_y
elif type(y) is Tensor:
# for early materialized tensor
with torch._C.DisableTorchFunction():
meta = MetaTensor(y.new_empty(y.shape, dtype=y.dtype, device='meta'), fake_device=y.device)
lazy_y = LazyTensor(lambda: None, meta_data=meta)
lazy_y._materialized_data = y
return lazy_y
return LazyTensor(lambda: None, concrete_data=y)
return y

cls._pre_op_fn()
o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
if isinstance(o, (tuple, list)):
return type(o)(wrap(y, i=i) for i, y in enumerate(o))
@@ -266,7 +290,10 @@ def data(self):

@data.setter
def data(self, other: 'LazyTensor'):
raise NotImplementedError
if other is self:
return
# TODO(ver217): to avoid infinite recursion, do early materialization
self._materialized_data = other._materialize_data()

def tolist(self) -> list:
t = self.materialize()
@@ -330,18 +357,61 @@ def wrapper(*args, **kwargs):

return wrapper, target

def wrap_legacy_constructor(target, dtype):
# legacy constructor (e.g. torch.LongTensor())
def wrapper(*args, **kwargs):
if len(args) == 1 and isinstance(args[0], torch.Tensor):
# (Tensor other)
return args[0]
elif len(args) == 1:
# (object data, *, torch.device device)
kwargs = {**kwargs, 'dtype': dtype}
replaced, orig = self.overrides['tensor']
return replaced(*args, **kwargs)
elif _is_int_tuple(args):
# (tuple of ints size, *, torch.device device)
kwargs = {**kwargs, 'dtype': dtype}
replaced, orig = self.overrides['empty']
return replaced(*args, **kwargs)
else:
raise TypeError(
f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
)

return wrapper, target

def wrap_no_meta_factory(target):
# factory functions which don't support meta tensor backend
def wrapper(*args, **kwargs):
tensor = target(*args, **kwargs)
return self.tensor_cls(lambda: None, concrete_data=tensor)

return wrapper, target

self.overrides = {
target: wrap_factory_method(getattr(torch, target))
for target in _TorchFactoryMethod
for target in _NORMAL_FACTORY
if callable(getattr(torch, target, None))
}

self.overrides.update({
target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
for target in _TorchFactoryMethod
for target in _NORMAL_FACTORY
if callable(getattr(torch, target + '_like', None))
})

self.overrides.update({
target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
if callable(getattr(torch, target, None))
})

self.overrides.update({
target: wrap_no_meta_factory(getattr(torch, target))
for target in _NO_META_FACTORY
if callable(getattr(torch, target, None))
})

for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper)

Expand All @@ -363,34 +433,65 @@ def materialize(module: torch.nn.Module, verbose: bool = False):
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
non_lazy_numel = 0

# do post cleaning to handle shared parameters
visited_lazy_tensors: List[LazyTensor] = []
# handle shared modules
visited_modules = set()

@torch.no_grad()
def init_recursively(module: nn.Module):
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
# recursively initialize the module
for mod in module.children():
init_recursively(mod)
if id(mod) not in visited_modules:
visited_modules.add(id(mod))
init_recursively(mod)

# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if verbose:
param_cnt += 1
if param._materialized_data is None:
if getattr(param, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
setattr(module, name, param.materialize())
param.clean()
else:
non_lazy_numel += param.numel()
if hasattr(param, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(param)
setattr(module, name, param.materialize())

for name, buf in module.named_buffers(recurse=False):
if verbose:
buf_cnt += 1
if buf._materialized_data is None:
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
setattr(module, name, buf.materialize())
buf.clean()
else:
non_lazy_numel += buf.numel()
if hasattr(buf, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(buf)
setattr(module, name, buf.materialize())

init_recursively(module)

for t in visited_lazy_tensors:
t.clean()

if verbose:
print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
return module


def _is_int_tuple(args) -> bool:
if not isinstance(args, tuple):
return False
for x in args:
if not isinstance(x, int):
return False
return True
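
For context, here is a minimal usage sketch of the lazy-init API touched by this file. It is a sketch only: the toy model, shapes, and `verbose` flag are illustrative; the `LazyInitContext` / `ctx.materialize` calls follow the interface shown in the diff above and in the tests below.

import torch.nn as nn

from colossalai.utils.model.experimental import LazyInitContext

# Build the model inside the context: parameter and buffer creation goes through
# the patched factory functions (torch.empty, torch.eye, torch.LongTensor, ...)
# and produces LazyTensors instead of allocating real storage.
ctx = LazyInitContext()
with ctx:
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Materialize afterwards: shared modules are visited only once, each lazy
# parameter and buffer is replaced by a real torch.Tensor, and the lazy tensors
# are cleaned after the whole module tree has been processed.
model = ctx.materialize(model, verbose=True)
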
1 change: 1 addition & 0 deletions tests/kit/model_zoo/torchrec/__init__.py
@@ -0,0 +1 @@
from .torchrec import *
23 changes: 23 additions & 0 deletions tests/test_utils/test_lazy_init/test_models.py
@@ -0,0 +1,23 @@
import pytest

from tests.kit.model_zoo import model_zoo

# FIXME(ver217): uncomment this line
# from utils import check_lazy_init


# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
@pytest.mark.skip
@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
def test_torchvision_models_lazy_init(subset):
sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
continue
# FIXME(ver217): uncomment this line
# check_lazy_init(entry, verbose=True)


if __name__ == '__main__':
test_torchvision_models_lazy_init('torchvision')
69 changes: 69 additions & 0 deletions tests/test_utils/test_lazy_init/utils.py
@@ -0,0 +1,69 @@
import random
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch

from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
from tests.kit.model_zoo.registry import ModelAttribute

# model_fn, data_gen_fn, output_transform_fn, model_attr
TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]]


def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
s1 = m1.state_dict()
s2 = m2.state_dict()

assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}'

for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()):
assert n1 == n2
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'


def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
output_transform_fn: Callable[[Any], dict]) -> None:
data = data_gen_fn()

m1.eval()
m2.eval()
# run forward
with torch.no_grad():
outputs1 = m1(**data)
outputs2 = m2(**data)

# compare output
transformed_out1 = output_transform_fn(outputs1)
transformed_out2 = output_transform_fn(outputs2)

assert len(transformed_out1) == len(transformed_out2)

for key, out1 in transformed_out1.items():
out2 = transformed_out2[key]
assert torch.allclose(out1, out2, atol=1e-5), \
f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'


def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
_MyTensor._pre_op_fn = lambda *args: set_seed(seed)
LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
ctx = LazyInitContext(tensor_cls=_MyTensor)
with ctx:
model = model_fn()
ctx = LazyInitContext()
with ctx:
deferred_model = model_fn()
deferred_model = ctx.materialize(deferred_model, verbose=verbose)
assert_model_equal(model, deferred_model)
if check_forward:
assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
if verbose:
print(f'{model.__class__.__name__} pass')
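
For reference, this is roughly how `check_lazy_init` could be driven once the FIXMEs in `test_models.py` are resolved. The import path of `check_lazy_init` and the chosen subset are assumptions; the skip list mirrors the TODO above about weight norm.

from tests.kit.model_zoo import model_zoo
from tests.test_utils.test_lazy_init.utils import check_lazy_init  # assumed import path

# Compare eager init (_MyTensor) against deferred init (LazyTensor) for one subset.
sub_model_zoo = model_zoo.get_sub_registry('torchvision')
for name, entry in sub_model_zoo.items():
    # lazy init does not support weight norm, so skip those models
    if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
        continue
    check_lazy_init(entry, verbose=True)
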