Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def module_policy(self):
target_module=NopadBaichuanMLP,
),
SubModuleReplacementDescription(
suffix="self_attn.W_pack", target_module=FusedLinear1D_Col, kwargs={"n_fused": 3}
suffix="self_attn.W_pack",
target_module=FusedLinear1D_Col,
kwargs={"split_sizes": [self.model.config.hidden_size] * 3},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
Expand Down
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

__all__ = [
"Embedding1D",
Expand Down Expand Up @@ -34,4 +34,5 @@
"RingAttention",
"get_pad_info",
"all_to_all_comm",
"FusedLinear1D_Row",
]
4 changes: 2 additions & 2 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communicati
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape
bsz = input_.shape[0]

# using all_to_all_single when batch size is 1
if bsz == 1:
Expand Down Expand Up @@ -871,7 +871,7 @@ def backward(ctx, grad_output):
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape
bsz = grad_output.shape[0]

if bsz == 1:
return_grad = _all_to_all_single(
Expand Down
11 changes: 4 additions & 7 deletions colossalai/shardformer/layer/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,8 @@ def forward(self, input_: Tensor) -> Tensor:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
if self.seq_parallel_mode == "split_gather":
output_parallel = F.linear(input_, self.weight)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
)
Expand All @@ -445,8 +442,8 @@ def forward(self, input_: Tensor) -> Tensor:
ring=True,
)
else:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group)
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)

if not self.skip_bias_add:
if self.bias is not None:
Expand Down
Loading