Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions colossalai/nn/layer/parallel_1d/_operation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist

from colossalai.core import global_context as gpc

try:
Expand Down Expand Up @@ -72,6 +73,7 @@ def backward(ctx, grad_output):
total_input = input
grad_input = grad_output.matmul(weight)

grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
Expand Down
9 changes: 6 additions & 3 deletions colossalai/nn/layer/parallel_1d/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ def __init__(self,
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')

self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features

# Parameters.
# Initialize weight.
Expand Down Expand Up @@ -612,7 +613,8 @@ def __init__(self,
raise ValueError('cannot skip bias addition if bias is None')

# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features

# Parameters.
# Initialize weight.
Expand Down Expand Up @@ -884,7 +886,8 @@ def __init__(self,

tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings_per_partition = num_embeddings
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

Expand Down
177 changes: 177 additions & 0 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
## ShardFormer

### Intro
Enable models from huggingface.co to be parallelized and used with Colossal-AI according to a custom policy.

### Quick start
1. Usage
- Use
``` python
from colossalai.shardformer.shard.shardmodel import ShardModel
from transformers import BertForMaskedLM

# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# make the huggingface model paralleled to ShardModel
# auto policy:
shardmodel = ShardModel(model).model

# custom policy:
from xxx import <POLICYCLASS>
shardmodel = ShardModel(model, <POLICYCLASS>).model


# do anything as normal
...
```
- Policy

If you want to parallelize the model in a custom way, just override the policy class for the huggingface model.

You should do:

1. Inherit Policy class
2. Overwrite argument_policy method
- In this method you need to list which layers class you wanna modify and the attributes and parameters in those layers.
3. Overwrite inject_policy method [Optional]
- If you need to modify the forward or backward progress.
4. Overwrite or add the param recording functions
   - These functions use a suffix to record the path of the weight or bias for the layer.
5. Overwrite binding

More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument

class CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
"""
        Return a dict: the key is the layer class to be modified and the value is an Argument object holding the parameter settings and parameter functions

Args:
model_config: The config of transformer model
shard_setting: The config of distributed model

Return:
Dict for the modify policy,
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
...
]
),
...
}

"""
raise NotImplementedError

@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
"""
Return the dict for the inject model

Return:
The injected model, key is the original model and value is the new shardmodel
"""
return ()

@staticmethod
def binding_policy() -> Dict:
"""
Return the dict for the binding model
"""
return NotImplementedError

@staticmethod
def attn_in() -> List:
"""
Attention qkv layer

Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError

@staticmethod
def attn_out() -> List:
"""
Attention output projection layer

Returns:
List[Layer]: List of layer object
"""
return NotImplementedError

@staticmethod
def mlp_in() -> List:
"""
h -> 4h mlp layer

Returns:
List[Layer]: List of layer object
"""
return NotImplementedError

@staticmethod
def mlp_out() -> List:
"""
4h -> h mlp layer

Returns:
List[Layer]: List of layer object
"""
return NotImplementedError

@staticmethod
def embedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums

Return:
List[Layer]: List of layer object
"""
return NotImplementedError

@staticmethod
def unembedding() -> List:
"""
        Partially slice the unembedding (LM head) layer
        vocab_size->vocab_size//gpu_nums

Return:
List[Layer]: List of layer object
"""
return NotImplementedError

```

2. Simple example
``` shell
# inference
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference
# train
colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train
```
16 changes: 9 additions & 7 deletions colossalai/shardformer/model/modeling_bert.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any, Dict, List, Type

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Dict, List, Type


from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput


class BertForMaskedLM_(BertForMaskedLM):

def forward(
self,
input_ids=None,
Expand All @@ -23,7 +25,7 @@ def forward(
return_dict=None,
**kwargs,
):
print("[Inject OK] Injected forward method")
# print("[Inject OK] Injected forward method")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
Expand All @@ -46,9 +48,9 @@ def forward(
masked_lm_loss = None

# if input_ids is not None:
# masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
# masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if not return_dict:
Expand All @@ -60,4 +62,4 @@ def forward(
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
)
25 changes: 16 additions & 9 deletions colossalai/shardformer/policies/autopolicy.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,47 @@
import torch.nn as nn


def build_policies():
    r"""
    Build the registry of auto policies.

    Imports are intentionally local so that `transformers` (and the policy
    modules that depend on it) are only loaded when a policy is requested,
    not when this module is imported.

    Return:
        Dict mapping a huggingface model class to its policy class.
    """
    auto_policy_dict = {}

    from transformers.models.bert.modeling_bert import BertForMaskedLM

    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers.models.bert.modeling_bert import BertForSequenceClassification

    from .bert import BertForSequenceClassificationPolicy
    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

    return auto_policy_dict

def get_autopolicy(model: nn.Module):
    r"""
    Return the auto policy for the model.

    Args:
        model (:class:`nn.Module`): The model to get the auto policy for

    Return:
        :class:`Policy`: The auto policy class registered for the model's class

    Raises:
        NotImplementedError: if no policy is registered for ``type(model)``.
    """
    auto_policy_dict = build_policies()
    # Lookup is by exact class, not isinstance: subclasses need their own policy.
    policy = auto_policy_dict.get(model.__class__, None)
    if policy is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
        )
    return policy


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
Expand Down
Loading