12 changes: 6 additions & 6 deletions colossalai/device/device_mesh.py
@@ -59,7 +59,7 @@ def __init__(self,
# 2. directly supply the logical mesh id
assert mesh_shape is None or logical_mesh_id is None, \
"Only one of mesh_shape and logical_mesh_id can be specified." \
"Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
"Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id"

if logical_mesh_id is None:
self._mesh_shape = mesh_shape
@@ -74,7 +74,7 @@ def __init__(self,
assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
"physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
"Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
"Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again."
assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
"Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."

@@ -118,7 +118,7 @@ def __init__(self,
self._global_rank_of_current_process = None
self._is_initialized = False

- # attribute used to inidicate whether this objectd
+ # attribute used to indicate whether this object
# is created using DeviceMesh.from_process_group
# this attribute can be used to do some check in methods
# such get_process_group as no global rank information
@@ -395,7 +395,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
Example:

```python
- sphysical_mesh_id = torch.arange(0, 16)
+ physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)

# logical mesh will look like
@@ -438,7 +438,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
# the _local_rank refers to the local rank of the current process
for _local_rank in range(self.logical_mesh_id.shape[dim]):

- # if this dimension is not initailized yet,
+ # if this dimension is not initialized yet,
# initialize it with an empty array
if dim not in processes_in_the_same_process_group:
processes_in_the_same_process_group[dim] = []
@@ -447,7 +447,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()

# replace the local rank in the given dimension with the
- # lcoal rank of the current process iterated
+ # local rank of the current process iterated
process_coordinates[dim] = _local_rank
processes_in_the_same_process_group[dim].append(process_coordinates)
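The loop touched here can be exercised in isolation. Below is a standalone re-implementation of the same idea for a 4x4 logical mesh, using a plain dict in place of self._global_to_local_rank_mapping; it is a sketch of the algorithm only and does not call DeviceMesh internals:

```python
import torch

# 4x4 logical mesh, as in the docstring example above
logical_mesh_id = torch.arange(0, 16).reshape(4, 4)

# stand-in for self._global_to_local_rank_mapping: global rank -> [row, col]
global_to_local = {
    int(logical_mesh_id[row, col]): [row, col]
    for row in range(4) for col in range(4)
}

def collate_global_ranks(global_rank):
    """For each mesh dimension, collect the coordinates of every process that
    shares a process group with `global_rank` along that dimension."""
    processes_in_the_same_process_group = {}
    for dim in range(logical_mesh_id.dim()):
        processes_in_the_same_process_group[dim] = []
        for _local_rank in range(logical_mesh_id.shape[dim]):
            process_coordinates = global_to_local[global_rank].copy()
            # replace the coordinate along `dim` with the iterated local rank
            process_coordinates[dim] = _local_rank
            processes_in_the_same_process_group[dim].append(process_coordinates)
    return processes_in_the_same_process_group

# rank 5 sits at (1, 1): along dim 0 it is grouped with ranks 1, 5, 9, 13 (its column),
# along dim 1 with ranks 4, 5, 6, 7 (its row)
print(collate_global_ranks(5))
```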

2 changes: 1 addition & 1 deletion colossalai/tensor/d_tensor/comm_spec.py
@@ -28,7 +28,7 @@ class CommSpec:
to determine the buffer shape, and logical_process_axis

Argument:
- comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
+ comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
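For readers unfamiliar with the gather_dim/shard_dim terminology in this docstring, the single-process sketch below mimics what a gather-style pattern does to a tensor: shards split along shard_dim across N ranks are reassembled by concatenating along gather_dim. This only illustrates the documented semantics with plain torch; it does not use the CommSpec API itself:

```python
import torch

world_size = 4
full = torch.arange(32.0).reshape(8, 4)

# shard the tensor along shard_dim across `world_size` ranks
shard_dim = 0
shards = torch.chunk(full, world_size, dim=shard_dim)

# a gather along gather_dim (== shard_dim here) restores the full tensor
gather_dim = 0
gathered = torch.cat(shards, dim=gather_dim)
assert torch.equal(gathered, full)
```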
4 changes: 2 additions & 2 deletions colossalai/tensor/shape_consistency.py
@@ -339,7 +339,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
RS01 -> RR
'''
valid_spec_dict = {}
- comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
+ comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
@@ -362,7 +362,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
b_target_pair = (b_index, [])

gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
- comm_spec = CommSpec(comm_pathern,
+ comm_spec = CommSpec(comm_pattern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
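The nested f_index/b_index loop in this hunk simply enumerates every ordered pair of tensor dimensions considered as the forward/backward shard targets that a mix-gather would turn back into replicated (R) dimensions. A standalone illustration of that enumeration for a rank-3 tensor (the loop bounds mirror the code above; everything else is hypothetical):

```python
# candidate (f_index, b_index) dimension pairs the loop above would consider
tensor_dims = 3
pairs = [
    (f_index, b_index)
    for f_index in range(tensor_dims - 1)
    for b_index in range(f_index + 1, tensor_dims)
]
print(pairs)  # [(0, 1), (0, 2), (1, 2)]
```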
2 changes: 1 addition & 1 deletion tests/kit/model_zoo/transformers/t5.py
@@ -43,7 +43,7 @@ def data_gen_for_t5_model():
# output transform function
output_transform_fn = lambda x: x

- # define loss funciton
+ # define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -64,7 +64,7 @@ def check_torch_ddp_no_sync():
model = DummyModel()
criterion = lambda x: x.mean()
optimizer = SGD(model.parameters(), lr=1e-3)
- # create a custom dasetset with 0 to 10
+ # create a custom dataset with 0 to 10
dataset = torch.arange(0, 10)
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
@@ -15,7 +15,7 @@
from tests.kit.model_zoo import model_zoo


- # test baisc fsdp function
+ # test basic fsdp function
def run_fn(model_fn, data_gen_fn, output_transform_fn):
plugin = TorchFSDPPlugin()
booster = Booster(plugin=plugin)