13 changes: 9 additions & 4 deletions colossalai/shardformer/README.md
@@ -20,15 +20,20 @@
The sample API usage is given below:

``` python
-from colossalai.shardformer import shard_model
+from colossalai.shardformer import ShardConfig, shard_model
from transformers import BertForMaskedLM

# create huggingface model as normal
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

# make the huggingface model paralleled to ShardModel
# auto policy:
-sharded_model = shard_model(model)
+shardconfig = ShardConfig(
+    rank=rank,
+    world_size=world_size,
+    gather_output=True,
+)
+sharded_model = shard_model(model, config=shardconfig)

# custom policy:
from xxx import <POLICYCLASS>
@@ -72,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument

-CustomPolicy(Policy):
+class CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
@@ -235,7 +240,7 @@ CustomPolicy(Policy):
This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.

CLASS `Col_Layer(Layer)`:
-- gather_output (bool): Whether to gather the output of the layer
+- gather_output (bool): Whether to gather the output of this layer. Typically only the last layer's output needs to be gathered; intermediate layers do not.

This class inherited from `Layer`, representing the layer will be sliced along column.

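For context, a slightly fuller version of the README's auto-policy snippet may help. This is a minimal sketch assuming a torchrun-style launch: the `torch.distributed` setup that supplies `rank` and `world_size` is illustrative, and only `ShardConfig` and `shard_model` are the API shown in the README.

```python
# Minimal sketch assuming a torchrun-style launch; the distributed setup is
# illustrative, only ShardConfig/shard_model come from the README's API.
import torch.distributed as dist
from transformers import BertForMaskedLM
from colossalai.shardformer import ShardConfig, shard_model

dist.init_process_group(backend="nccl")        # assumes env vars set by torchrun
rank, world_size = dist.get_rank(), dist.get_world_size()

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
shardconfig = ShardConfig(rank=rank, world_size=world_size, gather_output=True)
sharded_model = shard_model(model, config=shardconfig)
```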
1 change: 1 addition & 0 deletions colossalai/shardformer/__init__.py
@@ -0,0 +1 @@
+from .shard import ShardConfig, shard_model
10 changes: 5 additions & 5 deletions colossalai/shardformer/layer/dist_crossentropy.py
@@ -14,7 +14,7 @@ class DistCrossEntropy(Function):
"""

@staticmethod
-def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
+def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
@@ -75,8 +75,8 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):

# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
-loss = torch.log(sum_exp_logits) - pred_logits
-loss = torch.sum(loss).div_(loss.numel())
+loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
+loss = torch.sum(loss).div_(torch.sum(loss != 0.0))

# caculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
@@ -101,5 +101,5 @@ def backward(ctx, grad_output):
return grad_logits, None, None


-def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels)
+def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
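As a sanity check on the `ignore_index` handling added above, here is a minimal single-process reference (a hypothetical helper, not the distributed kernel): losses at ignored positions are zeroed and the mean is taken over the remaining tokens, matching the semantics of `F.cross_entropy(..., ignore_index=...)`.

```python
import torch
import torch.nn.functional as F

def masked_mean_ce(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Per-token negative log-likelihood, with ignored positions zeroed out
    log_probs = torch.log_softmax(logits, dim=-1)
    safe_target = target.clamp(min=0)  # avoid indexing with -100
    per_token = -log_probs.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
    per_token = torch.where(target == ignore_index, torch.zeros_like(per_token), per_token)
    # Average over non-ignored tokens only
    return per_token.sum() / (target != ignore_index).sum()

logits = torch.randn(4, 10)
target = torch.tensor([1, -100, 3, 5])
print(torch.allclose(masked_mean_ce(logits, target),
                     F.cross_entropy(logits, target, ignore_index=-100)))  # True
```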
5 changes: 3 additions & 2 deletions colossalai/shardformer/policies/bert.py
@@ -141,7 +141,7 @@ def unembedding() -> List:
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
-# gather_output=True,
+gather_output=True,
)
]

@@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):

@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
-return (BertForMaskedLM, BertForMaskedLM_)
+# return (BertForMaskedLM, BertForMaskedLM_)
+return None


class BertForSequenceClassificationPolicy(BertPolicy):
18 changes: 8 additions & 10 deletions colossalai/shardformer/shard/shard_config.py
@@ -5,16 +5,14 @@

@dataclass
class ShardConfig:
"""
The config for sharding the huggingface model for test
r"""
The config for sharding the huggingface model

Args:
rank (int): The rank of local process
world_size (int): The world size of the distributed process
gather_output (bool): Whether to gather the output of the model of the last layer
"""
rank: int
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend = "nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False
master_addr: str = "127.0.0.1"
master_port: int = 29500
gather_output: bool = True
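Read as a whole, the dataclass after this change (reassembled from the added lines above) is simply:

```python
from dataclasses import dataclass

@dataclass
class ShardConfig:
    r"""
    The config for sharding the huggingface model

    Args:
        rank (int): The rank of the local process
        world_size (int): The world size of the distributed process
        gather_output (bool): Whether to gather the output of the last layer of the model
    """
    rank: int
    world_size: int = 2
    gather_output: bool = True
```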
4 changes: 3 additions & 1 deletion colossalai/shardformer/shard/sharder.py
@@ -65,6 +65,8 @@ def inject_model(
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
+if inject_policy is None:
+    return

if inject_policy is None:
return
@@ -148,7 +150,7 @@ def shard_one_layer(
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
-gather_output = policy_layer.gather_output
+gather_output = policy_layer.gather_output and self.shard_config.gather_output

if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
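The change to `shard_one_layer` means a column-parallel layer's output is gathered only when both the policy's `Col_Layer` flag and the global `ShardConfig.gather_output` allow it; a tiny sketch of that rule (hypothetical helper, not the sharder's API):

```python
def effective_gather_output(layer_gather: bool, config_gather: bool) -> bool:
    # Gather only if the layer is marked for gathering AND the config permits it
    return layer_gather and config_gather

assert effective_gather_output(True, True)          # e.g. the last (unembedding) layer
assert not effective_gather_output(True, False)     # config globally disables gathering
assert not effective_gather_output(False, True)     # intermediate layers stay split
```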
1 change: 0 additions & 1 deletion colossalai/shardformer/test/config.py

This file was deleted.

50 changes: 0 additions & 50 deletions colossalai/shardformer/test/module_test.py

This file was deleted.

124 changes: 0 additions & 124 deletions colossalai/shardformer/test/test.py

This file was deleted.
