4 changes: 2 additions & 2 deletions colossalai/checkpoint_io/utils.py
@@ -6,7 +6,7 @@
 import torch
 import torch.nn as nn

-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import is_distributed_tensor

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -92,7 +92,7 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if type(weight) != DTensor:
+        if not is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
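The functional check above works on plain torch.Tensor objects, so checkpoint code no longer needs the DTensor wrapper class; a distributed tensor is just a tensor whose layout metadata is attached. A minimal sketch of the pattern (the helper name split_plain_and_distributed is hypothetical, not part of this PR):

import torch
from colossalai.tensor.d_tensor import is_distributed_tensor

def split_plain_and_distributed(state_dict: dict) -> tuple:
    # Hypothetical helper: separate plain tensors (sharded by size, as in
    # shard_checkpoint) from distributed tensors (saved via their own path).
    plain, distributed = {}, {}
    for key, weight in state_dict.items():
        (distributed if is_distributed_tensor(weight) else plain)[key] = weight
    return plain, distributed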
2 changes: 1 addition & 1 deletion colossalai/device/device_mesh.py
@@ -270,7 +270,7 @@ def __deepcopy__(self, memo) -> "DeviceMesh":
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k != 'process_groups_dict':
+            if k != '_process_group_dict':
                 setattr(result, k, __import__("copy").deepcopy(v, memo))
             else:
                 # process group cannot be copied
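The rename matters because __deepcopy__ copies every attribute except the process-group dict, which wraps communicator handles that cannot be deep-copied; with the stale name, the handles would be fed to deepcopy and fail. A self-contained sketch of the same pattern (class and attribute names illustrative, not the real DeviceMesh):

import copy

class MeshLike:
    """Illustrative stand-in for an object holding uncopyable handles."""

    def __init__(self, shape, process_group_dict):
        self.shape = shape
        self._process_group_dict = process_group_dict

    def __deepcopy__(self, memo):
        result = self.__class__.__new__(self.__class__)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != '_process_group_dict':
                setattr(result, k, copy.deepcopy(v, memo))
            else:
                # process groups cannot be deep-copied; share the reference
                setattr(result, k, v)
        return result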
4 changes: 2 additions & 2 deletions colossalai/lazy/lazy_init.py
@@ -9,7 +9,7 @@

 from colossalai._analyzer._subclasses import MetaTensor
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import distribute_tensor
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
@@ -184,7 +184,7 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to
         """
         target = self._materialize_data()
         self.clean()
-        local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor
+        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
         return _convert_cls(self, local_tensor)

     def clean(self) -> None:
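distribute_tensor(tensor, device_mesh, sharding_spec) shards the full tensor according to the spec, so the old DTensor(...).local_tensor round trip collapses into one functional call. A minimal sketch (device_mesh and sharding_spec come from the surrounding LazyTensor machinery and are assumed given):

from colossalai.tensor.d_tensor import distribute_tensor

def materialize_local_shard(full_tensor, device_mesh, sharding_spec):
    # Old path: DTensor(full_tensor, device_mesh, sharding_spec).local_tensor
    # New path: a single call that returns the rank-local shard, as in the
    # distribute() method above.
    return distribute_tensor(full_tensor, device_mesh, sharding_spec)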
25 changes: 13 additions & 12 deletions colossalai/shardformer/layer/embedding.py
@@ -13,8 +13,7 @@

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
-from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param

 from ._operation import gather_forward_split_backward, reduce_input
 from .parallel_module import ParallelModule
@@ -69,18 +68,17 @@ def __init__(self,
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(process_group)
-        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)

         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output

-        if device is None:
-            device = get_current_device()
-
-        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+        # Parameters.
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_colwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -194,7 +192,7 @@ def __init__(self,
                 **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -208,8 +206,11 @@ def __init__(self,
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

-        self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+        # parameter
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -252,7 +253,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup,

     def reset_parameters(self, weight_initializer) -> None:
         with self.randomizer.fork_rng(enable_cpu=True):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
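Both embedding variants now share one recipe: allocate the weight at its full logical size, shard it (colwise along the embedding dim for Embedding1D, rowwise along the vocab dim for the vocab-parallel variant), and wrap the shard as an nn.Parameter that carries its layout. A hedged sketch of that recipe (assumes torch.distributed is initialized; the helper name is illustrative):

import torch
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param

def make_sharded_embedding_weight(num_embeddings, embedding_dim, process_group,
                                  shard_vocab=False, device=None, dtype=None):
    # Allocate at the full logical size so every rank derives its shard from
    # the same tensor, then keep only the local partition as the parameter.
    weight = torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype)
    shard_fn = shard_rowwise if shard_vocab else shard_colwise
    return sharded_tensor_to_param(shard_fn(weight, process_group))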
30 changes: 14 additions & 16 deletions colossalai/shardformer/layer/linear.py
@@ -14,7 +14,7 @@

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
+from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
 from colossalai.utils.cuda import get_current_device

 from ._operation import (
@@ -76,22 +76,21 @@ def __init__(self,
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        self.out_features_per_partition = divide(out_features, self.num_partitions)
-
         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()
         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))

+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if bias:
-            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+            bias = torch.empty(self.out_features, **factory_kwargs)
+            sharded_bias = shard_colwise(bias, self.process_group)
+            self.bias = sharded_tensor_to_param(sharded_bias)
         else:
             self.bias = None

@@ -128,7 +127,6 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
                                *args,
                                **kwargs)

-        # TODO: copy the sharded weights
         with torch.no_grad():
             # the weigh to the linear layer is a transpose
             # thus shard on row is equal to shard on column
@@ -137,7 +135,6 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis
         if bias:
             sharded_bias = shard_colwise(module.bias.data, process_group)
             linear_1d.bias.copy_(sharded_bias)
-
         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -212,21 +209,20 @@ def __init__(self,
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        # Divide the weight matrix along the last dimension.
-        self.input_size_per_partition = divide(in_features, self.num_partitions)
-
         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()

         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_colwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if self.stream_chunk_num > 1:
             # TODO() work for inference only
@@ -340,3 +336,5 @@ def forward(self, input_: Tensor) -> Tensor:
             return output
         else:
             return output, self.bias
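The linear layers follow the same full-allocate-then-shard recipe, with the subtlety the diff's comments call out: nn.Linear stores its weight as (out_features, in_features), so the column-parallel layer shards the stored matrix rowwise (splitting out_features) and its bias colwise, while the row-parallel layer shards the weight colwise (splitting in_features). A sketch under the same assumptions as the embedding example:

import torch
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param

def make_col_parallel_linear_params(in_features, out_features, process_group,
                                    device=None, dtype=None):
    # weight is (out_features, in_features): splitting the output dimension
    # means sharding the stored matrix along its rows; the bias follows the
    # output dimension, so it is sharded colwise (its only dimension).
    weight = torch.empty(out_features, in_features, device=device, dtype=dtype)
    bias = torch.empty(out_features, device=device, dtype=dtype)
    return (sharded_tensor_to_param(shard_rowwise(weight, process_group)),
            sharded_tensor_to_param(shard_colwise(bias, process_group)))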
2 changes: 0 additions & 2 deletions colossalai/shardformer/layer/linear_conv.py
@@ -31,7 +31,6 @@

 class LinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.
-    Specially created for HuggingFace's GPT2 model.

     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
     its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
@@ -189,7 +188,6 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:

 class LinearConv1D_Row(ParallelModule):
     r""" Linear layer with row parallelism
-    Specially created for HuggingFace's GPT2 model.

     Args:
         in_features (int): size of each input sample.
142 changes: 142 additions & 0 deletions colossalai/shardformer/layer/parallel_module.py
@@ -1,11 +1,23 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import itertools
 from abc import ABC, abstractmethod
 from typing import List, Union

+import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
+
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    get_device_mesh,
+    get_sharding_spec,
+    is_distributed_tensor,
+    sharded_tensor_to_param,
+    to_global,
+)

 __all__ = ['ParallelModule']

@@ -25,3 +37,133 @@ def from_native_module(module: nn.Module,
                 in the ith axis of the device mesh. Defaults to None, which means the global process group.
         """
         pass
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        for name, param in self._parameters.items():
+            if param is not None:
+                param_ = param if keep_vars else param.detach()
+
+                if is_distributed_tensor(param_):
+                    destination[prefix + name] = to_global(param_)
+                else:
+                    destination[prefix + name] = param_
+
+        for name, buf in self._buffers.items():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                destination[prefix + name] = buf if keep_vars else buf.detach()
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
+            destination[extra_state_key] = self.get_extra_state()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        r"""Copies parameters and buffers from :attr:`state_dict` into only
+        this module, but not its descendants. This is called on every submodule
+        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+        For state dicts without metadata, :attr:`local_metadata` is empty.
+        Subclasses can achieve class-specific backward compatible loading using
+        the version number at `local_metadata.get("version", None)`.
+
+        .. note::
+            :attr:`state_dict` is not the same object as the input
+            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+            it can be modified.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            local_metadata (dict): a dict containing the metadata for this module.
+                See
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+
+            if key in state_dict:
+                input_param = state_dict[key]
+                if not torch.overrides.is_tensor_like(input_param):
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                      'received {}'.format(key, type(input_param)))
+                    continue
+
+                if is_distributed_tensor(param):
+                    # shard the input param
+                    device_mesh = get_device_mesh(param)
+                    sharding_spec = get_sharding_spec(param)
+                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
+                    input_param = sharded_tensor_to_param(sharded_tensor)
+
+                # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they dont have the hook to do the checks
+                # in such case, it will error when accessing the .shape attribute.
+                is_param_lazy = torch.nn.parameter.is_lazy(param)
+                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                    input_param = input_param[0]
+
+                if not is_param_lazy and input_param.shape != param.shape:
+                    # local shape should match the one in checkpoint
+                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                      'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
+                    continue
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'whose dimensions in the model are {} and '
+                                      'whose dimensions in the checkpoint are {}, '
+                                      'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
+                                                                           ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):]
+                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
+                    if input_name not in self._modules and input_name not in local_state:
+                        unexpected_keys.append(key)
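Taken together, the two overrides make checkpoints layout-agnostic: _save_to_state_dict gathers each distributed parameter to its global shape, and _load_from_state_dict re-shards the checkpoint tensor to whatever layout the live parameter already carries. A condensed sketch of the round trip (function names are illustrative; assumes a distributed environment):

from colossalai.tensor.d_tensor import (
    distribute_tensor,
    get_device_mesh,
    get_sharding_spec,
    is_distributed_tensor,
    to_global,
)

def param_for_checkpoint(param):
    # Save path: persist the gathered, global tensor.
    return to_global(param) if is_distributed_tensor(param) else param

def checkpoint_to_param(live_param, ckpt_tensor):
    # Load path: re-shard the global checkpoint tensor to the device mesh
    # and sharding spec already attached to the live parameter.
    if is_distributed_tensor(live_param):
        device_mesh = get_device_mesh(live_param)
        sharding_spec = get_sharding_spec(live_param)
        return distribute_tensor(ckpt_tensor, device_mesh, sharding_spec)
    return ckpt_tensor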
24 changes: 22 additions & 2 deletions colossalai/tensor/d_tensor/__init__.py
@@ -1,4 +1,24 @@
-from .d_tensor import DTensor
+from .api import (
+    compute_global_numel,
+    distribute_tensor,
+    get_device_mesh,
+    get_global_shape,
+    get_layout,
+    get_sharding_spec,
+    is_distributed_tensor,
+    is_sharded,
+    redistribute,
+    shard_colwise,
+    shard_rowwise,
+    sharded_tensor_to_param,
+    to_global,
+)
+from .layout import Layout
 from .sharding_spec import ShardingSpec

-__all__ = ['DTensor', 'ShardingSpec']
+__all__ = [
+    'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
+    'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
+    'redistribute', 'get_layout',
+    'Layout', 'ShardingSpec'
+]
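Note the added comma after 'get_layout': without it, Python's implicit string concatenation would silently merge it with 'Layout' into one bogus entry. With the re-export in place, downstream code imports the functional API from the package root instead of reaching into private modules, e.g.:

# before this PR: from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor import is_distributed_tensor, shard_rowwise, to_global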