39 changes: 19 additions & 20 deletions colossalai/zero/init_ctx/init_context.py
@@ -1,46 +1,45 @@
import contextlib
import functools
from typing import Optional
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn as nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses


class ZeroContextConfig(object):
@dataclass
class ZeroContextConfig:
"""The configuration used to control zero context initialization.

Args:
target_device (torch.device): The device where param data are after exiting the context.
replicated (bool, optional): Whether the param is replicated across data parallel group.
is_replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
"""

def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
super().__init__()
target_device: torch.device
is_replicated: bool = True
shard_param: bool = False

if shard_param:
assert replicated, "Non-replicated parameters can't be sharded."
def __post_init__(self):
if self.shard_param:
assert self.is_replicated, "Non-replicated parameters can't be sharded."

# replicated no-shard parameters should locate in cuda, since we will broadcast them soon
if replicated and not shard_param:
assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."

self.target_device = target_device
self.is_replicated: bool = replicated
self.shard_param: bool = shard_param
if self.is_replicated and not self.shard_param:
assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."


class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
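As a usage sketch (not part of this diff; the import path `colossalai.zero.init_ctx.init_context` is assumed from the file being edited), the refactored dataclass is constructed with the renamed `is_replicated` field, and `__post_init__` enforces the same invariants the old `__init__` checked:

```python
import torch

# Assumed import path, taken from the file shown in this diff.
from colossalai.zero.init_ctx.init_context import ZeroContextConfig

# Valid: replicated, non-sharded parameters must sit on a CUDA device.
cfg = ZeroContextConfig(target_device=torch.device('cuda', 0),
                        is_replicated=True,
                        shard_param=False)

# Invalid combinations are rejected in __post_init__, e.g. sharding
# non-replicated parameters raises:
#   AssertionError: Non-replicated parameters can't be sharded.
# ZeroContextConfig(target_device=torch.device('cuda', 0),
#                   is_replicated=False,
#                   shard_param=True)
```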
@@ -74,7 +73,7 @@ def __init__(self,
self.seed = seed
self.dp_process_group = gpc.get_group(ParallelMode.DATA)

self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

ZeroContextMgr().current_context = self

@@ -124,7 +123,7 @@ def calc_fanin_fanout(tensor: torch.Tensor):
return fan_in, fan_out

def _pre_context_exec(self):
"""
"""
The Callback function when entering the context
"""
self.logger = get_dist_logger("ZeroInitContext")
@@ -248,7 +247,7 @@ def hijack_context_config(self, **kwargs):

def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
replicated=is_replicated,
is_replicated=is_replicated,
shard_param=False)


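Callers of `no_shard_zero_context` keep the same public keyword; only the name forwarded to `ZeroContextConfig` changes from `replicated` to `is_replicated`. A minimal usage sketch follows (assuming the same module path as above; the MOE-expert layer is illustrative only):

```python
import torch.nn as nn

from colossalai.zero.init_ctx.init_context import no_shard_zero_context

# Parameters created inside this context stay un-sharded and are marked
# non-replicated (e.g. per-expert MOE weights), matching the docstring above.
with no_shard_zero_context(is_replicated=False):
    expert_ffn = nn.Linear(1024, 4096)
```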