Merged
31 changes: 17 additions & 14 deletions colossalai/tensor/d_tensor/sharding_spec.py
@@ -116,21 +116,21 @@ def build_difference_2d_dict(self):

def dim_diff(self, other):
'''
The difference between two _DimSpec.
The difference between two DimSpec.

Argument:
other(_DimSpec): the dim spec to compare with.
other(DimSpec): the dim spec to compare with.

Return:
difference(int): the difference between two _DimSpec.

Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
```python
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))

Output:
5
# output: 5
```
'''
difference = self.difference_dict[(str(self), str(other))]
return difference
@@ -142,9 +142,13 @@ class ShardingSpec:
[R, R, S0, S1], which means

Argument:
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
and the value of the key describe which logical axis will be sharded in that dimension.
sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
dim_size (int): The number of dimensions of the tensor to be sharded.
dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
and the value of the key describes which logical axes will be sharded in that dimension. Defaults to None.
E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1.
sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
Generally, users should specify either dim_partition_dict or sharding_sequence.
If both are given, users must ensure that they are consistent with each other. Defaults to None.
'''
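To make the `dim_partition_dict` / `sharding_sequence` relationship concrete, here is a minimal sketch (not ColossalAI's actual implementation; the helper name is hypothetical) of how a partition dict maps onto a sharding sequence like `[S01, R, R]`:

```python
def to_sharding_sequence(dim_size, dim_partition_dict):
    """Illustrative helper: build a readable sharding sequence.

    Each tensor dimension is 'R' (replicated) unless dim_partition_dict
    shards it, in which case it becomes 'S' followed by the logical mesh
    axes it is sharded over, e.g. {0: [0, 1]} -> 'S01'.
    """
    sequence = ["R"] * dim_size
    for dim, axes in dim_partition_dict.items():
        sequence[dim] = "S" + "".join(str(a) for a in axes)
    return sequence

print(to_sharding_sequence(3, {0: [0, 1]}))   # ['S01', 'R', 'R']
print(to_sharding_sequence(4, {2: [0], 3: [1]}))  # ['R', 'R', 'S0', 'S1']
```

The second call reproduces the `[R, R, S0, S1]` example from the class docstring above.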

def __init__(self,
@@ -208,6 +212,7 @@ def spec_diff(self, other):
pair of sharding sequence.

Example:
```python
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
@@ -219,10 +224,8 @@
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

Output:
25

# output: 25
```
Argument:
other(ShardingSpec): The ShardingSpec to compare with.
