3 changes: 1 addition & 2 deletions colossalai/_analyzer/_subclasses/_monkey_patch.py
@@ -2,8 +2,6 @@
import torch.distributed as dist
from packaging import version

aten = torch.ops.aten

__all__ = [
"_TorchFactoryMethod",
"_TorchOverrideableFactoryMethod",
@@ -51,6 +49,7 @@
]

if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
_AliasATen = [
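Hoisting `aten = torch.ops.aten` into the version-guarded branch keeps the module importable on torch < 1.12.0, where some of the `aten` overload handles referenced by the op lists (e.g. `_AliasATen`) do not resolve. A minimal sketch of the pattern, with an illustrative two-entry op list rather than the full one from the file:

```python
# Sketch of the version-gating pattern above; the op list is illustrative.
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    aten = torch.ops.aten
    # These overload handles only resolve on torch >= 1.12.0.
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
    ]
else:
    _AliasATen = []
```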
50 changes: 0 additions & 50 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -16,61 +16,15 @@
from colossalai.checkpoint_io.utils import save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
from colossalai.zero.gemini.memory_tracer import MemStats

from .plugin_base import Plugin

__all__ = ['GeminiPlugin']


def convert_to_colo_param(module: nn.Module) -> None:
"""Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.

Args:
module (nn.Module): Module to be converted.
"""
converted_modules = set() # handle shared modules
converted_params = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference

def convert_recursively(m: nn.Module):
for child in m.children():
if child not in converted_modules:
converted_modules.add(child)
convert_recursively(child)

for name, p in m.named_parameters(recurse=False):
assert not isinstance(p, ColoParameter)
if p in converted_params:
target = converted_params[p]
else:
target = _convert_to_coloparam(p, p.device, p.dtype)
converted_params[p] = target
setattr(m, name, target)
target.shared_param_modules.append(m)

convert_recursively(module)

# optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
module._converted_params = converted_params


def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
"""Replace param in optimizer's group with converted ColoParameter.

Args:
optimizer (Optimizer): Optimizer to be replaced.
converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
"""
for group in optimizer.param_groups:
for i, p in enumerate(group['params']):
if p in converted_params:
group['params'][i] = converted_params[p]


class GeminiCheckpointIO(GeneralCheckpointIO):

def __init__(self) -> None:
@@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper):

def __init__(self, module: nn.Module, gemini_config: dict) -> None:
super().__init__(module)
# TODO(ver217): only support Gemini now
convert_to_colo_param(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)

def unwrap(self):
@@ -125,8 +77,6 @@ def unwrap(self):
class GeminiOptimizer(OptimizerWrapper):

def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
replace_param_in_group(optimizer, module.module._converted_params)
del module.module._converted_params
optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
super().__init__(optimizer)

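With `convert_to_colo_param` and `replace_param_in_group` gone, `GeminiModel` passes the raw module straight to `zero_model_wrapper`; conversion to `ColoParameter` now happens in-place inside `ZeroDDP._preprocess_param` (see the `gemini_ddp.py` hunks below), so optimizer param groups keep valid references without any post-hoc patching. A rough sketch of the resulting user-facing flow, assuming the process group has already been set up via `colossalai.launch()`:

```python
# Rough sketch of the plugin flow after this change.
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

plugin = GeminiPlugin(placement_policy='cuda')
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)  # a plain nn.Module; ColoInitContext is no longer required
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda out: out.mean()

# boost() wraps the model in GeminiModel/ZeroDDP, which now converts
# parameters to ColoParameter internally (see _preprocess_param below).
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```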
15 changes: 10 additions & 5 deletions colossalai/nn/optimizer/nvme_optimizer.py
@@ -1,9 +1,10 @@
import torch
import math
import os
import tempfile
import math
from typing import Callable, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter
from typing import Optional, List, Dict, Callable


class NVMeOptimizer(torch.optim.Optimizer):
@@ -42,8 +43,9 @@ def __init__(self,
self.offloader = None
self.is_on_nvme: Dict[Parameter, bool] = {}
self.offloaded_numel: int = 0
self.total_numel: int = self._get_numel()
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
# As params may not be materialized here, these attributes are initialized on the first step
self.total_numel: Optional[int] = None
self.can_offload_numel: Optional[int] = None

self.prefetch_params: List[Parameter] = []
self.param_to_prefetch_idx: Dict[Parameter, int] = {}
@@ -77,6 +79,9 @@ def _setup_prefetch_params(self) -> List[Parameter]:
self.prefetch_params.append(p)

def _pre_step(self, *state_keys: str) -> None:
if self.total_numel is None:
self.total_numel = self._get_numel()
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
self._setup_prefetch_params()
if self.offloader is None or len(self.prefetch_params) == 0:
return
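Deferring `total_numel` until `_pre_step` means the optimizer can be constructed over parameters that are not yet materialized (e.g. lazy or meta tensors). A small self-contained sketch of the same lazy-computation pattern in isolation:

```python
# Minimal sketch of the deferred-initialization pattern: totals that are
# not yet valid at construction time are computed on the first step.
import math
from typing import Optional

class DeferredTotals:
    def __init__(self, params, offload_fraction: float):
        self.params = params
        self.offload_fraction = offload_fraction
        # Params may not be materialized yet, so postpone the count.
        self.total_numel: Optional[int] = None
        self.can_offload_numel: Optional[int] = None

    def pre_step(self):
        if self.total_numel is None:
            self.total_numel = sum(p.numel() for p in self.params)
            self.can_offload_numel = math.floor(self.total_numel * self.offload_fraction)
```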
16 changes: 11 additions & 5 deletions colossalai/utils/model/experimental.py
@@ -7,7 +7,7 @@
from torch import Tensor
from torch.utils._pytree import tree_map

from colossalai.fx.profiler.tensor import MetaTensor
from colossalai._analyzer._subclasses import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout

@@ -37,7 +37,7 @@
# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']

_LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float,
@@ -75,6 +75,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, kwargs)


def _data_tolist(tensor: torch.Tensor) -> list:
"""tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
"""
return tensor.data.tolist()


def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.

@@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(torch.Tensor.tolist, target)
tensor.tolist = MethodType(_data_tolist, tensor)
return tensor


@@ -144,7 +150,7 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
if meta_data is None:
device = kwargs.get('device', 'cpu')
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
meta_data = MetaTensor(elem, device=device)
elem = meta_data._tensor
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
@@ -255,7 +261,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
tree_map(cls._replace_with_materialized, args)
tree_map(cls._replace_with_materialized, kwargs)
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ == "__setitem__")
or func.__name__ in ('__setitem__', '__set__'))

is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS

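Two details in this hunk are easy to miss: `__set__` joins `_CHANGE_META_OPS` and the in-place check so attribute assignment is not unwrapped through `.data`, and `tolist` is now bound to the lazy tensor itself via `_data_tolist` rather than to `target`, which avoids keeping the materialized `target` alive. A small sketch of that per-instance rebinding, demonstrated on a plain tensor purely for illustration:

```python
# Sketch of the per-instance tolist rebinding done in _convert_cls.
# Binding to the tensor itself (not to `target`) keeps no extra
# reference alive, and .data yields a plain torch.Tensor to convert.
from types import MethodType

import torch

def _data_tolist(tensor: torch.Tensor) -> list:
    return tensor.data.tolist()

t = torch.ones(2, 2)
t.tolist = MethodType(_data_tolist, t)  # instance attribute shadows the class method
print(t.tolist())  # [[1.0, 1.0], [1.0, 1.0]]
```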
9 changes: 6 additions & 3 deletions colossalai/zero/gemini/chunk/search_utils.py
@@ -46,9 +46,10 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:


def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
if strict_ddp_flag:
if strict_ddp_flag and type(local_param) is ColoParameter:
return local_param.numel_global()
else:
# if local_param is not a ColoParameter, we assume it is replicated
return local_param.numel()


@@ -67,11 +68,13 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if is_ddp_ignored(param):
continue

if strict_ddp_flag:
if strict_ddp_flag or type(param) is not ColoParameter:
# if the model is not initialized with ColoInitContext, we assume it is replicated
# TODO(ver217): integrate DTensor
param_key = dist.get_world_size()
else:
param_key = param.process_group.dp_world_size()
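The relaxed checks mean `classify_params_by_dp_degree` no longer requires the model to be built under `ColoInitContext`: anything that is not exactly a `ColoParameter` is treated as replicated and keyed by the full world size. A condensed sketch of the keying rule, assuming the default process group is initialized:

```python
# Condensed sketch of the DP-degree keying rule from the hunk above.
import torch.distributed as dist

from colossalai.tensor.colo_parameter import ColoParameter

def param_key(param, strict_ddp_flag: bool) -> int:
    if strict_ddp_flag or type(param) is not ColoParameter:
        # non-ColoParameter (or strict DDP): assume replication over all ranks
        return dist.get_world_size()
    # a ColoParameter carries its own process group with a DP world size
    return param.process_group.dp_world_size()
```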
27 changes: 23 additions & 4 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -1,7 +1,7 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
@@ -14,6 +14,7 @@
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.utils.model.experimental import LazyTensor

from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -55,7 +56,6 @@ def __init__(self,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
@@ -67,7 +67,6 @@ def __init__(self,
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()

self._cast_buffers()
self._logger = get_dist_logger()

if self.gemini_manager._premade_memstats_:
@@ -91,6 +90,8 @@ def __init__(self,
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._cast_buffers()

def _post_forward(self):
"""This function is only triggered for inference.
@@ -478,7 +479,8 @@ def load_fp32_parameter(chunk_slice, data):
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
self._preprocess_param(p)
assert type(p) is ColoParameter

# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
@@ -531,10 +533,27 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):

def _cast_buffers(self):
for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor):
buffer.materialize()
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()

def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
"""Convert parameter to ColoParameter in-place.
Args:
p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
"""
if type(p) is ColoParameter:
# model is initialized with ColoInitContext
return
requires_grad = p.requires_grad
if isinstance(p, LazyTensor):
# model is initialized with LazyInitContext
p.materialize()
p.__class__ = ColoParameter
p.__init__(p, requires_grad=requires_grad)


class GeminiDDP(ZeroDDP):

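`_preprocess_param` converts in place by swapping `__class__`, so the parameter keeps its object identity: references held elsewhere (optimizer param groups, the `name2param` map) stay valid, which is exactly what made the old `replace_param_in_group` dance removable. A condensed sketch of the pattern, assuming `LazyTensor.materialize()` fills in real data before the swap:

```python
# Condensed sketch of the in-place conversion in _preprocess_param.
import torch.nn as nn

from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils.model.experimental import LazyTensor

def to_colo_param_inplace(p: nn.Parameter) -> None:
    if type(p) is ColoParameter:
        return  # already a ColoParameter, nothing to do
    requires_grad = p.requires_grad
    if isinstance(p, LazyTensor):
        p.materialize()  # allocate real data before the class swap
    # Swap the class and re-run __init__ on the same object, preserving identity.
    p.__class__ = ColoParameter
    p.__init__(p, requires_grad=requires_grad)
```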
32 changes: 29 additions & 3 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -1,21 +1,31 @@
from contextlib import nullcontext

import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext
from tests.kit.model_zoo import model_zoo


def check_gemini_plugin(early_stop: bool = True):
@parameterize('init_method', ['lazy', 'none', 'colo'])
def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
"""check gemini plugin over model zoo

Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
is_support_meta = is_compatible_with_meta()
if not is_support_meta and init_method == 'lazy':
return

from colossalai.utils.model.experimental import LazyInitContext
passed_models = []
failed_info = {} # (model_name, error) pair

@@ -40,10 +50,25 @@ def check_gemini_plugin(early_stop: bool = True):
]:
continue

if init_method == 'lazy' and name in [
'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3',
'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0',
'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf',
'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s'
]:
continue

try:
if init_method == 'colo':
ctx = ColoInitContext()
elif init_method == 'lazy':
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
with ctx:
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
@@ -76,6 +101,7 @@ def check_gemini_plugin(early_stop: bool = True):
torch.cuda.empty_cache()

if dist.get_rank() == 0:
print(f'Init method: {init_method}')
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
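For reference, a minimal sketch of how a `parameterize`-style decorator drives the check once per init method; this is illustrative only, and the real `colossalai.testing.parameterize` may differ in details:

```python
# Hypothetical parameterize-style decorator, for illustration only.
def parameterize(arg_name, values):
    def decorator(fn):
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize('init_method', ['lazy', 'none', 'colo'])
def demo(init_method: str = 'none'):
    print(f'checking with init_method={init_method}')

demo()  # runs three times, once per init method
```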