2 changes: 1 addition & 1 deletion colossalai/shardformer/README.md
@@ -252,7 +252,7 @@ class ModelSharder:

def shard(self) -> None:
"""
-Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
+Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
"""
...

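For readers skimming the diff, the docstring above names a four-stage flow. Below is a minimal schematic of that flow; the placeholder helper functions are illustrative assumptions, not the actual ModelSharder internals.

```python
# Placeholder stages; names mirror the docstring, not the real ModelSharder API.
def preprocess(model): ...            # e.g. resolve the sharding policy and parallel context
def replace_model_class(model): ...   # e.g. swap the top-level model class for a sharded variant
def replace_module(model): ...        # e.g. recursively swap sub-modules with sharded layers
def postprocess(model): ...           # e.g. final parameter binding / clean-up

def shard(model):
    """Schematic of the order implied by the docstring above."""
    preprocess(model)
    replace_model_class(model)
    replace_module(model)
    postprocess(model)
```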
12 changes: 6 additions & 6 deletions colossalai/shardformer/layer/loss.py
@@ -48,13 +48,13 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index:

# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
-down_shreshold = rank * delta
-up_shreshold = down_shreshold + delta
-mask = (target < down_shreshold) | (target >= up_shreshold)
-masked_target = target.clone() - down_shreshold
+down_threshold = rank * delta
+up_threshold = down_threshold + delta
+mask = (target < down_threshold) | (target >= up_threshold)
+masked_target = target.clone() - down_threshold
masked_target[mask] = 0

-# reshape the logist and target
+# reshape the logits and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
@@ -79,7 +79,7 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index:
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
loss = torch.sum(loss).div_(torch.sum(loss != 0.0))

-# caculate the softmax
+# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d)

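The renamed variables above implement the vocab-parallel masking: each rank owns a contiguous slice `[down_threshold, up_threshold)` of the global vocabulary, and any target outside that slice (or set to `ignore_index`) is masked. A self-contained sketch of that indexing with made-up sizes:

```python
import torch

# Illustrative numbers only: vocab of 10 split across 4 ranks, viewed from rank 1.
global_vocab_size, world_size, rank = 10, 4, 1
delta = (global_vocab_size + world_size - 1) // world_size   # ceil-divide -> 3 ids per rank
down_threshold = rank * delta                                 # 3: first id owned by this rank
up_threshold = down_threshold + delta                         # 6: one past the last owned id

target = torch.tensor([0, 3, 5, 9, -100])                     # -100 plays the role of ignore_index
mask = (target < down_threshold) | (target >= up_threshold)   # True = not owned by this rank
masked_target = target.clone() - down_threshold               # shift owned ids into [0, delta)
masked_target[mask] = 0                                       # park masked entries at index 0

print(mask)            # tensor([ True, False, False,  True,  True])
print(masked_target)   # tensor([0, 0, 2, 0, 0])
```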
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/basepolicy.py
@@ -66,7 +66,7 @@ class Policy(ABC):
like BertPolicy for Bert Model or OPTPolicy for OPT model.

Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
-built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`.
+built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
"""

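As the corrected docstring notes, `policy = None` selects a built-in policy, while a custom one is written by inheriting from `Policy`. A minimal stand-alone sketch of that subclassing pattern; the stand-in base class and the `module_policy` hook name are illustrative assumptions, not the actual Shardformer API:

```python
from abc import ABC, abstractmethod

class Policy(ABC):
    """Stand-in for colossalai.shardformer's Policy base class (illustration only)."""
    @abstractmethod
    def module_policy(self) -> dict:   # hypothetical hook name
        ...

class MyModelPolicy(Policy):
    """Inherit from Policy and override only the hooks you want to change."""
    def module_policy(self) -> dict:
        # Placeholder replacement rules; real policies describe layer substitutions.
        return {"replace_embedding": True, "replace_linear": True}
```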
2 changes: 1 addition & 1 deletion colossalai/shardformer/shard/sharder.py
@@ -73,7 +73,7 @@ def _recursive_replace_layer(
layer (torch.nn.Module): The object of layer to shard
origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
attr_replacement (Dict): The attribute dict to modify
-param_replacement (List[Callable]): The function list to get parameter shard information in polic
+param_replacement (List[Callable]): The function list to get parameter shard information in policy
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
"""
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
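The corrected docstring describes a recursive walk that matches children against `origin_cls` (by class or by class name) and applies the replacement rules. A simplified, generic sketch of that traversal; it is not the Shardformer implementation and handles only attribute replacement:

```python
from typing import Dict, Union
import torch.nn as nn

def recursive_replace(module: nn.Module,
                      origin_cls: Union[str, type],
                      attr_replacement: Dict[str, object]) -> None:
    """Walk the module tree; patch attributes on every child matching origin_cls."""
    for _, child in module.named_children():
        matches = (origin_cls == child.__class__.__name__
                   if isinstance(origin_cls, str)
                   else isinstance(child, origin_cls))
        if matches:
            for attr, value in attr_replacement.items():
                setattr(child, attr, value)
        # Recurse so nested matches (e.g. layers inside blocks) are also handled.
        recursive_replace(child, origin_cls, attr_replacement)
```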