Merged
192 changes: 86 additions & 106 deletions colossalai/utils/model/lazy_init_context.py
@@ -2,13 +2,13 @@
# coding: utf-8

import torch
from colossalai.tensor import ColoParameter
import torch.nn as nn
from colossalai.tensor import ColoParameter, ColoTensor

import types
import inspect
import typing
from typing import List, Callable
from colossalai.utils.model.utils import substitute_init_recursively
import copy


class LazyInitContext():
@@ -18,8 +18,7 @@ class LazyInitContext():

Note:
This API is only experimental and subject to future changes.
It should be integrated with meta tensor initialization in the future.


Usage:
with LazyInitContext() as ctx:
model = nn.Linear(10, 10)
@@ -36,14 +35,17 @@ class LazyInitContext():
assert not model.weight.is_meta and torch.all(model.weight == 0)

Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
"""

tensor_set_value_func = ['zero_']
tensor_set_value_func = ['zero_', 'fill_']
Contributor: Update the Note in the description above?

Contributor Author: Sure.


def __init__(self, extra_torch_tensor_func: List[str] = None):
self._intercepted_init_func_cache = []
def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
# TODO: hijack the torch constructor functions as well
self._to_meta = to_meta
self._intercepted_nn_init_func_cache = {}
self._nn_init_methods = self._get_nn_init_methods()
self._torch_mod_cls = torch.nn.modules.module.Module

@@ -53,14 +55,20 @@ def __init__(self, extra_torch_tensor_func: List[str] = None):
else:
self._torch_tensor_funcs = self.tensor_set_value_func

def _cache_func(self, func):
@property
def to_meta(self):
return self._to_meta

def _cache_init_func(self, func):
"""
This method wraps the ``torch.nn.init`` method so that the function call
is cached instead of being executed.
This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
so that the function call is cached instead of being executed.
"""

def wrapped_init_func(*args, **kwargs):
self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))
def wrapped_init_func(tensor, *args, **kwargs):
if tensor not in self._intercepted_nn_init_func_cache:
self._intercepted_nn_init_func_cache[tensor] = []
self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))

return wrapped_init_func
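For illustration, a minimal, self-contained sketch of this interception pattern (the names `init_call_cache` and `make_lazy` are hypothetical, not part of this PR):

```python
import torch

# Record nn.init calls instead of executing them, keyed by the target tensor
# (torch.Tensor hashes by identity, so each tensor gets its own entry).
init_call_cache = {}

def make_lazy(func):

    def wrapped(tensor, *args, **kwargs):
        init_call_cache.setdefault(tensor, []).append((func, args, kwargs))
        return tensor

    return wrapped

orig_uniform_ = torch.nn.init.uniform_
torch.nn.init.uniform_ = make_lazy(orig_uniform_)
try:
    t = torch.empty(2, 2)
    torch.nn.init.uniform_(t)    # recorded, not executed
    assert t in init_call_cache
finally:
    torch.nn.init.uniform_ = orig_uniform_    # always restore the original

# Replay the cached call later, on demand.
func, args, kwargs = init_call_cache[t][-1]
func(t, *args, **kwargs)
```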

@@ -76,17 +84,10 @@ def _get_nn_init_methods(self):
for name in nn_init_method_names:
nn_init_methods.append((name, getattr(torch.nn.init, name)))

def _has_tensor_in_arg(func):
hints = typing.get_type_hints(func)
for k, v in hints.items():
if v is torch.Tensor:
return True
return False

def _is_init_method(item):
name, func = item
if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
or not _has_tensor_in_arg(func)):

if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
return False
else:
return True
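The naming convention `_is_init_method` relies on can be checked directly; a small sketch, assuming only that `torch.nn.init`'s in-place initializers keep their conventional trailing underscore:

```python
import types

import torch

# Public functions in torch.nn.init ending with an underscore are the
# in-place initializers (uniform_, zeros_, kaiming_uniform_, ...).
init_names = [
    name for name in dir(torch.nn.init)
    if isinstance(getattr(torch.nn.init, name), types.FunctionType)
    and not name.startswith('_') and name.endswith('_')
]
assert 'uniform_' in init_names and 'zeros_' in init_names
```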
@@ -103,11 +104,13 @@ def _wrap_module_init(self, func):
has_device = 'device' in inspect.signature(func).parameters

def layer_lazy_init(module, *args, **kwargs):
self._intercepted_init_func_cache.append(
dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs)))
# if this module's constructor takes a device argument,
# we set it to meta so the module initializes on the meta backend
if has_device:
kwargs['device'] = 'meta'
func(module, *args, **kwargs)

# if device is not found, we initialize it and convert to meta
if not has_device:
module.to('meta')
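Both branches above leave the module's tensors on the meta device. A standalone sketch of the two paths (`Scale` is a made-up module; this assumes PyTorch >= 1.9, where factory `device` kwargs and the meta device are available):

```python
import torch
import torch.nn as nn

# Path 1: constructors that accept `device` build their weights on meta directly.
linear = nn.Linear(4, 4, device='meta')
assert linear.weight.is_meta

# Path 2: modules without a `device` argument are initialized normally,
# then moved onto the meta device afterwards.
class Scale(nn.Module):

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

assert Scale().to('meta').scale.is_meta
```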

@@ -122,7 +125,7 @@ def _get_tmp_origin_func_ref(self, name):
def _patch_nn_init_funcs(self):
# patch nn.init functions
for name, func in self._nn_init_methods:
setattr(torch.nn.init, name, self._cache_func(func))
setattr(torch.nn.init, name, self._cache_init_func(func))

def _unpatch_nn_init_funcs(self):
# unpatch nn.init functions
@@ -150,7 +153,7 @@ def _patch_torch_tensor_funcs(self):
origin_func_name = self._get_tmp_origin_func_ref(func_name)
origin_func = getattr(torch.Tensor, func_name)
setattr(torch.Tensor, origin_func_name, origin_func)
setattr(torch.Tensor, func_name, self._cache_func(origin_func))
setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))

def _unpatch_torch_tensor_funcs(self):
for func_name in self._torch_tensor_funcs:
@@ -159,17 +162,18 @@ def _unpatch_torch_tensor_funcs(self):
setattr(torch.Tensor, func_name, origin_func)
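The stash-and-restore trick used for `torch.Tensor` methods can be shown in isolation; a minimal sketch (the `_orig_zero_` stash name is illustrative, not the one produced by `_get_tmp_origin_func_ref`):

```python
import torch

# Stash the original method under a temporary attribute, then patch it.
setattr(torch.Tensor, '_orig_zero_', torch.Tensor.zero_)
setattr(torch.Tensor, 'zero_', lambda self: self)    # no-op stand-in

t = torch.ones(2)
t.zero_()    # intercepted: tensor is left unchanged
assert t.sum().item() == 2.0

# Restore the original from the stash.
setattr(torch.Tensor, 'zero_', getattr(torch.Tensor, '_orig_zero_'))
delattr(torch.Tensor, '_orig_zero_')
assert t.zero_().sum().item() == 0.0
```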

def __enter__(self):
self._patch_submodule_init()
self._patch_torch_tensor_funcs()
self._patch_nn_init_funcs()

if self._to_meta:
self._patch_submodule_init()
return self

def __exit__(self, *args, **kwargs):
self._unpatch_submodule_init()
# build module_rebuild_dict in reverse order to make sure we get the correct init func for inherited classes.
self.module_rebuild_dict = {}
self._intercepted_init_func_cache.reverse()
for cache in self._intercepted_init_func_cache:
self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs'])
self._intercepted_init_func_cache.reverse()
if self._to_meta:
self._unpatch_submodule_init()
self._unpatch_nn_init_funcs()
self._unpatch_torch_tensor_funcs()

def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
"""
@@ -178,80 +182,56 @@ def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back:
Args:
model (`torch.nn.Module`): the model instantiated under the context.
device (str): the device on which weights are initialized

"""
# build param mapping
param_id_to_name = dict()
for name, param in model.named_parameters():
param_id_to_name[id(param)] = name
for name, buffer in model.named_buffers():
param_id_to_name[id(buffer)] = name

assert model in self.module_rebuild_dict, 'We only support rebuilding modules that were intercepted during initialization.'

def _process_arg(arg):
"""
Process args recursively. If arg is a torch.nn.Module instance in module_rebuild_dict,
we need to rebuild it with real parameters. If arg is a tuple or list, we will process
the element of arg with this function again.
"""
if torch.is_tensor(arg):
tensor_id = id(arg)
if tensor_id in param_id_to_name:
arg = _replace_meta_param_with_real_param(arg)

elif isinstance(arg, torch.nn.Module):
if arg in self.module_rebuild_dict:
arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back)

elif isinstance(arg, (tuple, list)):
rst_list = []
for element in arg:
processed_element = _process_arg(element)
rst_list.append(processed_element)
arg = rst_list
return arg

def _replace_meta_param_with_real_param(meta_param):
if meta_param.device != 'meta':
return meta_param
tensor_id = id(meta_param)
param_full_name = param_id_to_name[tensor_id]
real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)

if '.' in param_full_name:
submodule_name, param_name = param_full_name.rsplit('.', 1)
submodule = model.get_submodule(submodule_name)
else:
submodule = model
param_name = param_full_name
setattr(submodule, param_name, real_param)

# execute call_back function on the materialized tensor
# this is where sharding comes in
if call_back:
call_back(real_param)
return real_param

func, args, kwargs = self.module_rebuild_dict[model]
args = list(args)

# check args for parameter replacement
for idx, arg in enumerate(args):
arg = _process_arg(arg)
args[idx] = arg

# check kwargs for parameter replacement
for arg_name, arg in kwargs.items():
if arg_name == 'device':
arg = device

def _init_recursively(module: nn.Module):
# recursively initialize the module
for mod in module.children():
_init_recursively(mod)

# initialize and shard tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
_init_and_shard(module, name, param)

for name, buf in module.named_buffers(recurse=False):
_init_and_shard(module, name, buf)

@torch.no_grad()
def _init_and_shard(module, name, tensor):
# check whether the tensor is a buffer or parameter
is_param = isinstance(tensor, nn.parameter.Parameter)

# get sharding spec
dist_spec = getattr(tensor, 'dist_spec', None)
pg = getattr(tensor, 'pg', None)

# convert the tensor from meta to materialized one
if tensor.is_meta:
materialized_tensor = torch.empty_like(tensor, device=device)
# if this tensor is a meta tensor, it must have an init function
assert tensor in self._intercepted_nn_init_func_cache
tensor = materialized_tensor

# apply init function
if tensor in self._intercepted_nn_init_func_cache:
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
init_func(tensor, *args, **kwargs)

# convert it to ColoTensor or ColoParameter
if is_param:
tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
else:
arg = _process_arg(arg)
kwargs[arg_name] = arg
tensor = ColoTensor.from_torch_tensor(tensor)

# apply sharding
if dist_spec:
tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)

# override the original tensor
with torch.no_grad():
setattr(module, name, tensor)

# build user specified model
with torch.no_grad():
func(model, *args, **kwargs)
_init_recursively(model)

return model
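The meta-to-real conversion inside `_init_and_shard` hinges on `torch.empty_like` with a device override; a tiny pure-PyTorch sketch:

```python
import torch

# A meta tensor carries shape and dtype but no storage; empty_like with a
# device override yields a real, uninitialized tensor of the same shape.
meta = torch.empty(4, 4, device='meta')
real = torch.empty_like(meta, device='cpu')
assert meta.is_meta and not real.is_meta
assert real.shape == meta.shape and real.dtype == meta.dtype
```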
35 changes: 25 additions & 10 deletions tests/test_utils/test_lazy_init_ctx.py
@@ -10,27 +10,42 @@
torch.manual_seed(MANUAL_SEED)


def test_lazy_init():
cpu_rng_state = torch.get_rng_state()
origin_model = resnet34(num_classes=10)
origin_param_dict = dict(origin_model.named_parameters())
torch.set_rng_state(cpu_rng_state)
ctx = LazyInitContext()
def test_lazy_init_with_meta():
ctx = LazyInitContext(to_meta=True)
with ctx:
model = resnet34(num_classes=10)

for param in model.parameters():
assert param.is_meta
for buffer in model.buffers():
assert buffer.is_meta

ctx.lazy_init_parameters(model)

for name, param in model.named_parameters():
assert not param.is_meta, name

for buffer in model.buffers():
assert not buffer.is_meta


def test_lazy_init_without_meta():
ctx = LazyInitContext(to_meta=False)
with ctx:
model = resnet34(num_classes=10)

for param in model.parameters():
assert not param.is_meta
for buffer in model.buffers():
assert not buffer.is_meta
param_dict = dict(model.named_parameters())
for key in origin_param_dict.keys():
assert origin_param_dict[key].data.equal(param_dict[key].data)

conv1_weight_before_init = model.conv1.weight.clone()
ctx.lazy_init_parameters(model)
conv1_weight_after_init = model.conv1.weight.clone()

assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)


if __name__ == '__main__':
test_lazy_init()
test_lazy_init_with_meta()
test_lazy_init_without_meta()