41 changes: 27 additions & 14 deletions colossalai/lazy/lazy_init.py
@@ -5,6 +5,7 @@
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor
@@ -95,8 +96,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the converted tensor
"""
cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
if cls_to_become is Parameter:
# to fit UninitializedParameter
delattr(tensor, '_is_param')
tensor.data = target
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
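
For context on the `_is_param` flag deleted above: recent PyTorch builds route `isinstance(x, Parameter)` through a metaclass hook, so any tensor carrying a truthy `_is_param` attribute passes the check. A minimal sketch of that mechanism, paraphrased from `torch.nn.parameter` (treat the exact code as illustrative):

    import torch

    class _ParameterMeta(torch._C._TensorMeta):
        # isinstance(t, Parameter) succeeds for real Parameters *or* for any
        # tensor instance that carries a truthy `_is_param` attribute.
        def __instancecheck__(self, instance):
            return super().__instancecheck__(instance) or (
                isinstance(instance, torch.Tensor) and getattr(instance, "_is_param", False))

`LazyTensor` exploits this hook by setting `_is_param = True` on stand-ins for parameters (see `__deepcopy__` below); once `_convert_cls` swaps the class to a real `Parameter`, the leftover flag is deleted so the materialized tensor behaves like an ordinary parameter again.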
@@ -190,10 +194,10 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to
def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
"""
self._factory_method = None
self._op_buffer = None
self._materialized_data = None
self._meta_data = None
delattr(self, '_factory_method')
delattr(self, '_op_buffer')
delattr(self, '_materialized_data')
delattr(self, '_meta_data')
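
Replacing the `None` assignments with `delattr` drops the instance attributes entirely: the buffers become garbage-collectable immediately, and any accidental use of a cleaned tensor raises `AttributeError` instead of silently returning `None`. A plain-Python illustration of the difference (not ColossalAI code):

    class Holder:
        pass

    h = Holder()
    h.payload = list(range(1_000_000))

    delattr(h, "payload")    # the list is garbage-collectable right away
    try:
        h.payload            # stale access now fails loudly...
    except AttributeError:
        pass                 # ...rather than silently yielding None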

@staticmethod
def _replace_with_materialized(x):
@@ -346,20 +350,19 @@ def __deepcopy__(self, memo):
def factory_fn():
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
copied = new_tensor.detach().clone()
if new_tensor.requires_grad:
copied.requires_grad_()
return copied
return _copy_tensor(new_tensor, new_tensor.requires_grad)

if self._materialized_data is not None:
# self is early materialized
copied = self._materialized_data.detach().clone()
if self.requires_grad:
copied.requires_grad_()
copied = _copy_tensor(self._materialized_data, self.requires_grad)
target = LazyTensor(lambda: None, concrete_data=copied)
else:
target = LazyTensor(factory_fn, meta_data=self._meta_data)

if isinstance(self, Parameter):
# hack isinstance check of parameter
target._is_param = True

memo[id(self)] = target
return target

@@ -404,6 +407,10 @@ def tolist(self) -> list:
def __hash__(self):
return id(self)

def __rpow__(self, other):
dtype = torch.result_type(self, other)
return torch.tensor(other, dtype=dtype, device=self.device)**self
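
The new `__rpow__` handles expressions with the lazy tensor on the right of `**`, such as `2 ** t`, which Python dispatches to `type(t).__rpow__` when the left operand cannot handle the tensor. Its dtype promotion mirrors eager behavior; a quick sanity check with ordinary PyTorch (nothing ColossalAI-specific):

    import torch

    t = torch.tensor([1.0, 2.0, 3.0])
    dtype = torch.result_type(t, 2)            # promotion rule used by __rpow__ above
    base = torch.tensor(2, dtype=dtype, device=t.device)
    assert torch.equal(base ** t, 2 ** t)      # both give tensor([2., 4., 8.])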


class LazyInitContext:
"""Context manager for lazy initialization. Enables initializing the model without allocating real memory.
@@ -524,7 +531,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

@staticmethod
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
"""Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
"""Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

Args:
module (nn.Module): Target ``nn.Module``
@@ -541,7 +548,7 @@ def distribute(module: nn.Module,
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
"""Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

Args:
module (nn.Module): Target ``nn.Module``
@@ -613,3 +620,9 @@ def _is_int_tuple(args) -> bool:
if not isinstance(x, int):
return False
return True


def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
copied = tensor.data.clone()
copied.requires_grad = requires_grad
return copied
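
Going through `.data` in `_copy_tensor` sidesteps autograd history, so the copy is a fresh leaf on which `requires_grad` can be set directly, equivalent to the `detach().clone().requires_grad_(...)` pattern that `__deepcopy__` previously inlined twice. A quick check in plain PyTorch:

    import torch

    src = torch.randn(2, 2, requires_grad=True)
    copied = src.data.clone()                  # .data detaches from the autograd graph
    copied.requires_grad = src.requires_grad   # legal: `copied` is a fresh leaf tensor
    assert copied.is_leaf and copied.data_ptr() != src.data_ptr()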
5 changes: 4 additions & 1 deletion colossalai/shardformer/layer/embedding.py
@@ -9,8 +9,8 @@
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -95,6 +95,7 @@ def from_native_module(module: nn.Embedding,
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
LazyInitContext.materialize(module)
# get the attributes
num_embedding = module.num_embeddings
embedding_dim = module.embedding_dim
@@ -223,6 +224,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup,
r"""
Convert a native pytorch embedding module to a parallel module.
"""
LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
@@ -243,6 +245,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup,
process_group=process_group,
*args,
**kwargs)

with torch.no_grad():
# shard and slice the weight along the vocabulary (num_embeddings) dimension
# the shape of the weight is (num_embeddings, embedding_dim)
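
The `LazyInitContext.materialize(module)` calls added to each `from_native_module` (here and in the layers that follow) let the shardformer wrappers accept lazily-initialized models: parameters are materialized just before their attributes and weights are read for sharding, and the call should leave eagerly-built modules untouched. A hedged sketch of the intended flow (the zero-argument constructor is assumed from the docstring above):

    import torch.nn as nn
    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        embed = nn.Embedding(50304, 768)   # weight is a LazyTensor: no real storage yet

    # from_native_module() now starts with this call, so the attribute reads and
    # weight sharding that follow touch real, just-materialized storage.
    LazyInitContext.materialize(embed)
    print(embed.weight.shape)              # torch.Size([50304, 768])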
3 changes: 3 additions & 0 deletions colossalai/shardformer/layer/linear.py
@@ -12,6 +12,7 @@
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -106,6 +107,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
@@ -242,6 +244,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
4 changes: 4 additions & 0 deletions colossalai/shardformer/layer/normalization.py
@@ -4,6 +4,8 @@
import torch
import torch.nn as nn

from colossalai.lazy import LazyInitContext

__all__ = ['FusedLayerNorm', 'FusedRMSNorm']

FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
raise ImportError(
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')

LazyInitContext.materialize(module)
# get the attributes of the module
normalized_shape = module.normalized_shape
eps = module.eps
@@ -84,6 +87,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
)

LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm
if module.__class__.__name__ == "LlamaRMSNorm":
normalized_shape = module.weight.shape[0]
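
`FusedRMSNorm.from_native_module` detects Hugging Face's `LlamaRMSNorm` by class name, which avoids a hard import of `transformers`; since that class stores only a weight vector, `normalized_shape` is recovered from the weight itself. A self-contained sketch of the probe (the stand-in class below is illustrative, not the real Hugging Face implementation):

    import torch
    import torch.nn as nn

    class LlamaRMSNorm(nn.Module):               # stand-in for transformers' class
        def __init__(self, hidden_size: int):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))

    norm = LlamaRMSNorm(4096)
    if norm.__class__.__name__ == "LlamaRMSNorm":
        normalized_shape = norm.weight.shape[0]  # no normalized_shape attribute to read
    assert normalized_shape == 4096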
7 changes: 5 additions & 2 deletions colossalai/shardformer/layer/qkv_fused_linear.py
@@ -12,6 +12,7 @@
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import (
@@ -231,6 +232,7 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis
process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
n_fused (int): The number of layers to be fused. In GPT2, the Q, K, V projections are fused into one weight.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
@@ -380,6 +382,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.weight.shape[0]
out_features = module.weight.shape[1]
@@ -428,9 +431,9 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

origin_device = self.bias.device
self.bias = self.bias.cuda()
self.bias.data = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)
self.bias.data = self.bias.to(origin_device)

def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
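
The `self.bias.data = ...` change above matters because `nn.Module.__setattr__` refuses to rebind a registered parameter to a plain tensor, and `Tensor.cuda()`/`Tensor.to()` on a `Parameter` return plain tensors. Mutating `.data` moves the storage while keeping the same `Parameter` registered, so the broadcast result lands in the module's own parameter. A minimal reproduction (plain PyTorch; assumes a CUDA device):

    import torch.nn as nn

    linear = nn.Linear(4, 4)
    # linear.bias = linear.bias.cuda()   # TypeError: cannot assign 'torch.cuda.FloatTensor'
    #                                    # as parameter 'bias' (Parameter or None expected)
    linear.bias.data = linear.bias.data.cuda()   # moves storage; bias stays a Parameter
    assert isinstance(linear.bias, nn.Parameter) and linear.bias.is_cuda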
38 changes: 21 additions & 17 deletions colossalai/shardformer/policies/bert.py
@@ -23,11 +23,12 @@ def preprocess(self):
Resize the embedding layer to make the vocabulary size divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
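
The guard added above confirms tensor parallelism is enabled before padding; the padding itself rounds the vocabulary up to the next multiple of the tensor-parallel world size so the embedding can be split evenly. Worked through with illustrative numbers (GPT-2's 50257-token vocabulary on 8 ranks):

    vocab_size, world_size = 50257, 8
    if vocab_size % world_size != 0:
        new_vocab_size = vocab_size + world_size - vocab_size % world_size
    assert new_vocab_size == 50264           # 50257 % 8 == 1, so pad by 8 - 1 = 7
    assert new_vocab_size % world_size == 0  # each rank holds 50264 // 8 == 6283 rows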

def module_policy(self):
@@ -166,10 +167,11 @@ def module_policy(self):
return module_policy

def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
if self.shard_config.enable_tensor_parallelism:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model


@@ -185,10 +187,11 @@ def module_policy(self):
return module_policy

def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
if self.shard_config.enable_tensor_parallelism:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model


@@ -204,10 +207,11 @@ def module_policy(self):
return module_policy

def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
if self.shard_config.enable_tensor_parallelism:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model


26 changes: 12 additions & 14 deletions colossalai/shardformer/policies/bloom.py
@@ -17,11 +17,12 @@ def preprocess(self):
r"""
Resize the embedding layer to make the vocabulary size divisible by world_size
"""
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model

def module_policy(self):
@@ -128,16 +129,13 @@ def module_policy(self):
return policy

def postprocess(self):
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

for k, v in binding_map.items():
param = getattr_(self.model, k)

if not isinstance(param, nn.Parameter):
param = nn.Parameter(param)
if self.shard_config.enable_tensor_parallelism:
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

# tie weights
setattr_(self.model, v, param)
for k, v in binding_map.items():
param = getattr_(self.model, k)
# tie weights
setattr_(self.model, v, param)
return self.model
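
These `postprocess` hooks re-tie the language-model head to the (now possibly sharded) input embedding so both names resolve to a single `Parameter`. The explicit `nn.Parameter(...)` wrapping was dropped, presumably because lazy tensors now pass the `isinstance(..., nn.Parameter)` check (see the `_is_param` hack in lazy_init.py). The underlying tying pattern in plain PyTorch, for reference:

    import torch.nn as nn

    embed = nn.Embedding(10, 4)
    lm_head = nn.Linear(4, 10, bias=False)
    lm_head.weight = embed.weight           # tie: both modules share one Parameter
    assert lm_head.weight is embed.weight   # one storage, updated through either path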


29 changes: 16 additions & 13 deletions colossalai/shardformer/policies/gpt2.py
@@ -21,11 +21,12 @@ def preprocess(self):
r"""
Resize the embedding layer to make the vocabulary size divisible by world_size
"""
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model

def module_policy(self):
@@ -142,10 +143,11 @@ def module_policy(self):
return module_policy

def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
if self.shard_config.enable_tensor_parallelism:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model


@@ -172,10 +174,11 @@ def module_policy(self):
return module_policy

def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
if self.shard_config.enable_tensor_parallelism:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model


13 changes: 7 additions & 6 deletions colossalai/shardformer/policies/llama.py
@@ -15,13 +15,14 @@ def config_sanity_check(self):
pass

def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size

if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)

return self.model
