53 changes: 23 additions & 30 deletions colossalai/shardformer/layer/_operation.py
@@ -211,43 +211,36 @@ def backward(ctx, grad_output):
handle.wait()

else:
# create a new stream to calculate the gradient
calculate_stream = torch.cuda.Stream()

# do all gather in default stream
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

# calculate gradient in calculate_stream
with torch.cuda.stream(calculate_stream):
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None

# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()

torch.cuda.current_stream().wait_stream(calculate_stream)
# do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradients and prepare data while the all-gather is in flight
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()

# do reduce-scatter in async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
with torch.cuda.stream(calculate_stream):
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
print(grad_output.shape, input_parallel.shape)
grad_weight = grad_output.t().matmul(input_parallel)

torch.cuda.current_stream().wait_stream(calculate_stream)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()

return output, grad_weight, grad_bias, None, None, None, None
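Note: distilled from the hunk above, the non-stream path now works as follows: launch the all-gather asynchronously, compute every gradient that depends only on grad_output and weight while the collective is in flight, wait for the gathered activations, then overlap the asynchronous reduce-scatter with the grad_weight matmul. A minimal standalone sketch of the same pattern (the function name and shapes are illustrative, assuming an initialized torch.distributed process group):

import torch
import torch.distributed as dist

def backward_with_overlap(grad_output, input_, weight, process_group, dim=1):
    # illustrative sketch of the all-gather/compute overlap in the hunk above
    world_size = dist.get_world_size(process_group)

    # launch the all-gather asynchronously on the default stream
    input_ = input_.contiguous()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)

    # these gradients need only grad_output and weight, so they can be
    # computed while the collective is still in flight
    grad_input = grad_output.matmul(weight)
    grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
    grad_bias = grad_output_2d.sum(dim=0)
    input_list = [c.contiguous() for c in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty_like(input_)

    # grad_weight needs the gathered activations, so block here
    gather_handle.wait()

    # overlap the async reduce-scatter with the grad_weight matmul
    rs_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
    input_parallel = torch.cat(tensor_list, dim=dim)
    input_parallel = input_parallel.reshape(-1, input_parallel.shape[-1])
    grad_weight = grad_output_2d.t().matmul(input_parallel)
    rs_handle.wait()

    return output, grad_weight, grad_bias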
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/linear.py
@@ -75,7 +75,7 @@ def __init__(self,
gather_output: bool = False,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
overlap: bool = False,
overlap: torch.cuda.Stream = None,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
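For context on the signature change above: every call site touched in this PR passes overlap straight from ShardConfig.enable_sequence_overlap, i.e. a boolean flag rather than a stream. A hypothetical call, with constructor arguments beyond those shown in the diff assumed:

from colossalai.shardformer.layer import Linear1D_Col

# hypothetical wiring that mirrors the policy changes below;
# tp_group is an assumed tensor-parallel process group
layer = Linear1D_Col(
    in_features=1024,
    out_features=4096,
    process_group=tp_group,
    seq_parallel=True,
    overlap=shard_config.enable_sequence_overlap,  # a bool at all call sites in this PR
)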
21 changes: 17 additions & 4 deletions colossalai/shardformer/policies/bert.py
@@ -56,6 +56,7 @@ def module_policy(self):

policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
@@ -71,17 +71,26 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
@@ -99,7 +109,10 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="output.dense",
11 changes: 9 additions & 2 deletions colossalai/shardformer/policies/bloom.py
@@ -45,6 +45,7 @@ def module_policy(self):
policy = {}

use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -55,7 +56,10 @@
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={'seq_parallel': use_sequence_parallel}),
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
@@ -67,7 +71,10 @@
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={'seq_parallel': use_sequence_parallel}),
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
4 changes: 3 additions & 1 deletion colossalai/shardformer/policies/chatglm2.py
@@ -50,6 +50,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}

use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
sub_module_replacement=[
@@ -81,7 +82,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
target_module=col_nn.Linear1D_Col,
kwargs={
'seq_parallel': use_sequence_parallel,
'seq_parallel_dim': 0
'seq_parallel_dim': 0,
'overlap': overlap
}),
SubModuleReplacementDescription(suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
9 changes: 9 additions & 0 deletions colossalai/shardformer/shard/shard_config.py
@@ -20,6 +20,8 @@ class ShardConfig:
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -29,6 +31,7 @@
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False

# pipeline_parallel_size: int
# data_parallel_size: int
@@ -41,6 +44,11 @@ def tensor_parallel_size(self):
return self._tensor_parallel_size

def __post_init__(self):
if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
raise ValueError(
"enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
@@ -59,3 +67,4 @@ def _turn_on_all_optimization(self):
self.enable_flash_attention = True
self.enable_jit_fused = True
self.enable_sequence_parallelism = True
self.enable_sequence_overlap = True
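A minimal usage sketch of the new flag and its validation (assuming torch.distributed is already initialized, since __post_init__ queries the tensor-parallel world size):

from colossalai.shardformer import ShardConfig

# valid: overlap requires sequence parallelism, which requires tensor parallelism
config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,
)

# raises ValueError in __post_init__: overlap without sequence parallelism
ShardConfig(enable_sequence_overlap=True)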
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_layer/test_linear_1d.py
@@ -168,7 +168,7 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool

@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [False, True])
@parameterize('overlap', [True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)