4 changes: 4 additions & 0 deletions colossalai/shardformer/layer/layers.py
@@ -329,7 +329,11 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None:
src_rank = 0
else:
src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)

origin_device = self.bias.device
self.bias = self.bias.cuda()
dist.broadcast(self.bias, src=src_rank, group=self.process_group)
self.bias = self.bias.to(origin_device)

def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
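A minimal sketch (not part of the PR) of the pattern added above, assuming torch.distributed has been initialized with the NCCL backend; the helper name is illustrative only. NCCL collectives operate on CUDA tensors, so a CPU-resident bias is moved to the GPU for the broadcast and then returned to its original device.

import torch
import torch.distributed as dist

def broadcast_via_cuda(param: torch.Tensor, src_rank: int, group: dist.ProcessGroup = None) -> torch.Tensor:
    # Remember where the tensor lives so the caller's placement is preserved.
    origin_device = param.device
    param = param.cuda()                      # NCCL requires the tensor to be on a CUDA device
    dist.broadcast(param, src=src_rank, group=group)
    return param.to(origin_device)            # move back to the original device
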
40 changes: 36 additions & 4 deletions colossalai/tensor/d_tensor/api.py
@@ -10,9 +10,21 @@
from .sharding_spec import ShardingSpec


def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
def shard_rowwise(tensor: torch.Tensor,
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
inplace: bool = False) -> DTensor:
"""
Shard the first dim of the given tensor
Shard the first dim of the given tensor.

Args:
tensor (torch.Tensor): The tensor to be sharded.
group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
If None, the tensor will be sharded with respect to the global process group.
Defaults to None.
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.

Returns:
DTensor: The sharded tensor.
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
@@ -24,12 +36,28 @@ def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})

if not inplace:
tensor = tensor.detach().clone()

return DTensor(tensor, device_mesh, sharding_spec)
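
A usage sketch for the new inplace flag, assuming torch.distributed is already initialized and that shard_rowwise is imported from colossalai.tensor.d_tensor.api as defined in this file:

import torch
from colossalai.tensor.d_tensor.api import shard_rowwise

weight = torch.randn(8, 4)

# Default behaviour: the tensor is detached and cloned first, so `weight` itself is untouched.
d_weight = shard_rowwise(weight)

# inplace=True skips the copy and wraps `weight` directly.
d_weight_inplace = shard_rowwise(weight, inplace=True)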


def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
def shard_colwise(tensor: torch.Tensor,
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
inplace: bool = False) -> DTensor:
"""
Shard the first dim of the given tensor
Shard the last dim of the given tensor.

Args:
tensor (torch.Tensor): The tensor to be sharded.
group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
If None, the tensor will be sharded with respect to the global process group.
Defaults to None.
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.

Returns:
DTensor: The sharded tensor.
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
@@ -41,4 +69,8 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

if not inplace:
tensor = tensor.detach().clone()

return DTensor(tensor, device_mesh, sharding_spec)
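
A companion sketch for shard_colwise under the same assumptions; per the sharding spec {-1: [0]} it partitions the last dimension of the tensor across the 1D device mesh:

import torch
from colossalai.tensor.d_tensor.api import shard_colwise

proj = torch.randn(4, 16)

# Shards dim -1; with the default inplace=False the original `proj` is left intact.
d_proj = shard_colwise(proj)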