6 changes: 6 additions & 0 deletions colossalai/lazy/__init__.py
@@ -0,0 +1,6 @@
+from .lazy_init import LazyInitContext, LazyTensor
+
+__all__ = [
+    'LazyInitContext',
+    'LazyTensor',
+]
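
The new package re-exports the lazy-initialization API at a stable top-level path. For orientation, a minimal usage sketch follows; it assumes only that LazyInitContext is a context manager that defers parameter allocation (as its name and the updated test import below suggest), so treat it as illustrative rather than the exact API:

# Minimal usage sketch (assumption: LazyInitContext is a context manager
# that defers real allocation; exact keyword arguments may differ).
import torch.nn as nn
from colossalai.lazy import LazyInitContext

with LazyInitContext():
    # Modules built inside the context hold LazyTensor parameters;
    # no real device memory is allocated until materialization.
    model = nn.Linear(1024, 1024)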
colossalai/lazy/lazy_init.py
@@ -350,7 +350,14 @@ def factory_fn():
                copied.requires_grad_()
            return copied

-        target = LazyTensor(factory_fn, meta_data=self._meta_data)
+        if self._materialized_data is not None:
+            # self is early materialized
+            copied = self._materialized_data.detach().clone()
+            if self.requires_grad:
+                copied.requires_grad_()
+            target = LazyTensor(lambda: None, concrete_data=copied)
+        else:
+            target = LazyTensor(factory_fn, meta_data=self._meta_data)

        memo[id(self)] = target
        return target
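
The new branch keeps deepcopy correct for tensors that were materialized before the copy: re-running factory_fn on an already-materialized tensor could allocate a second buffer and lose any in-place updates, so the concrete data is cloned directly instead. A self-contained sketch of the same rule, using a hypothetical stand-in class (LazySketch) with only the state this branch touches:

# Standalone sketch of the deepcopy rule above; LazySketch is a
# hypothetical stand-in for LazyTensor, not the real class.
import copy
import torch

class LazySketch:

    def __init__(self, factory, concrete_data=None):
        self._factory = factory
        self._materialized_data = concrete_data    # set once materialized

    def __deepcopy__(self, memo):
        if self._materialized_data is not None:
            # Early materialized: clone the live values instead of
            # re-running the factory, which could diverge from them.
            data = self._materialized_data.detach().clone()
            target = LazySketch(lambda: None, concrete_data=data)
        else:
            # Still lazy: the copy stays lazy and defers allocation.
            target = LazySketch(self._factory)
        memo[id(self)] = target
        return target

lazy = LazySketch(lambda: torch.zeros(4))
lazy._materialized_data = lazy._factory()    # simulate early materialization
clone = copy.deepcopy(lazy)                  # copies the data, not the factory
assert clone._materialized_data is not lazy._materialized_data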
242 changes: 0 additions & 242 deletions colossalai/utils/model/lazy_init_context.py

This file was deleted.

60 changes: 32 additions & 28 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -2,21 +2,21 @@
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
-from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
-from colossalai.utils.model.experimental import LazyTensor

from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@@ -96,34 +96,38 @@ def __init__(self,
                param_name = m_name + '.' + p_name if m_name else p_name
                self.name2param[param_name] = p_var
        super().__init__(module, process_group=ColoProcessGroup())
-        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
+        self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
        self._cast_buffers()

-    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
-
-        r"""
-        Args:
-            memo: a memo to store the set of modules already added to the result
-            prefix: a prefix that will be added to the name of the module
-            remove_duplicate: whether to remove the duplicated module instances in the result
-                or not
-        """
-
-        if memo is None:
-            memo = set()
-        self_non_persistent_set = set()
-        if module not in memo:
-            if remove_duplicate:
-                memo.add(module)
-            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
-            for name, sub_module in module._modules.items():
-                if sub_module is None:
-                    continue
-                submodule_prefix = prefix + ('.' if prefix else '') + name
-                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
-                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
-        return self_non_persistent_set
+    def _get_non_persistent_buffers_set(self,
+                                        module,
+                                        memo: Optional[Set[nn.Module]] = None,
+                                        prefix: str = '',
+                                        remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
+
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(
+                map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+            for name, sub_module in module._modules.items():
+                if sub_module is None:
+                    continue
+                submodule_prefix = prefix + ('.' if prefix else '') + name
+                child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
+                                                                                remove_duplicate)
+                self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set

    def _post_forward(self):
        """This function is only triggered for inference.
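
The reformatted helper walks the module tree and returns the fully-qualified names of all non-persistent buffers, i.e. those registered with persistent=False, which never appear in state_dict and therefore need separate tracking in a DDP-style wrapper. A simplified standalone sketch of the same traversal, leaning on named_modules to handle the memo/deduplication that the method above does by hand; it relies on the same private nn.Module attribute (_non_persistent_buffers_set):

# Simplified equivalent for illustration; uses the same private
# nn.Module attribute (_non_persistent_buffers_set) as the method above.
import torch
import torch.nn as nn

def non_persistent_buffer_names(module: nn.Module) -> set:
    names = set()
    for prefix, sub in module.named_modules():    # deduplicates shared submodules
        for key in sub._non_persistent_buffers_set:
            names.add(prefix + ('.' if prefix else '') + key)
    return names

# A buffer registered with persistent=False is excluded from state_dict,
# which is why GeminiDDP records these names at construction time.
m = nn.Module()
m.register_buffer('step_count', torch.zeros(1), persistent=False)
assert non_persistent_buffer_names(m) == {'step_count'}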
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -8,10 +8,10 @@
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
+from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.model.experimental import LazyInitContext
from colossalai.zero import ColoInitContext
from tests.kit.model_zoo import model_zoo
