4 changes: 4 additions & 0 deletions deepspeed/module_inject/replace_module.py
@@ -335,6 +335,10 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module

     def set_lm_head(module):
+        if is_autotp_training_mode():
+            # we need to handle autoTP training mode separately.
+            return
+
         embedding_weight = None
         for n, p in module.named_parameters():
             if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
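For context: the guard added above short-circuits set_lm_head during AutoTP training, where the LM head is sharded and handled by the training path rather than by the inference-time embedding tying that follows. A minimal sketch of the pattern, with a stubbed is_autotp_training_mode and a generic module standing in for DeepSpeed's internals (both are assumptions for illustration, not the library's actual code):

import torch.nn as nn

def is_autotp_training_mode() -> bool:
    # Stub for illustration: DeepSpeed consults global tensor-parallel
    # training state here.
    return False

def set_lm_head(module: nn.Module):
    if is_autotp_training_mode():
        # AutoTP training shards and ties the LM head itself, so the
        # inference-time embedding lookup below must be skipped.
        return
    embedding_weight = None
    for n, p in module.named_parameters():
        if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
            embedding_weight = p
    # ... inference path continues: tie embedding_weight to the LM head ...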
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
@@ -424,7 +424,7 @@ def _configure_tensor_parallel_states(self, model):
         # sanity check
         # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
         assert self.zero_optimization_stage(
-        ) <= 1, "Currently, the compatibility between 'autotp' and 'zero_stage > 1' has not been validated"
+        ) <= 2, "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated"

         self.mpu = groups
         self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size())
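With the relaxed assertion, AutoTP can now be initialized alongside ZeRO stage 2, while stage 3 still trips the sanity check. A hedged sketch of such a configuration; the tensor_parallel/autotp_size keys follow DeepSpeed's AutoTP training config, and the tiny model is a placeholder:

import torch
import deepspeed

# Placeholder model; launch with e.g. `deepspeed --num_gpus 2 script.py`
# so the world size is divisible by autotp_size.
model = torch.nn.Linear(16, 16)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "tensor_parallel": {"autotp_size": 2},
    # Stage 2 now passes the sanity check above; stage 3 would still assert.
    "zero_optimization": {"stage": 2},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)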
5 changes: 3 additions & 2 deletions deepspeed/runtime/utils.py
@@ -1134,9 +1134,10 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[List, Dict]) -> bool:
         if inputs1.keys() != inputs2.keys():
             return False
         for key in inputs1:
-            val1 = inputs1[key].to(get_accelerator().current_device())
-            val2 = inputs2[key].to(get_accelerator().current_device())
+            val1, val2 = inputs1[key], inputs2[key]
             if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
+                val1 = val1.to(get_accelerator().current_device())
+                val2 = val2.to(get_accelerator().current_device())
                 if not torch.equal(val1, val2):
                     return False
             elif val1 != val2:
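The old code called .to() on every dict value, which raises AttributeError whenever a value is not a tensor (an int step count, a string, and so on); moving the device transfer inside the isinstance check fixes that. A minimal sketch of the corrected dict branch, using DeepSpeed's get_accelerator helper (compare_dicts is a hypothetical name for this excerpt, not the library function):

from typing import Dict

import torch
from deepspeed.accelerator import get_accelerator

def compare_dicts(inputs1: Dict, inputs2: Dict) -> bool:
    if inputs1.keys() != inputs2.keys():
        return False
    for key in inputs1:
        val1, val2 = inputs1[key], inputs2[key]
        if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
            # Only tensors support .to(); the pre-fix code applied it to
            # every value and crashed on plain Python scalars.
            val1 = val1.to(get_accelerator().current_device())
            val2 = val2.to(get_accelerator().current_device())
            if not torch.equal(val1, val2):
                return False
        elif val1 != val2:
            return False
    return True

With this version, compare_dicts({"step": 3}, {"step": 3}) returns True instead of raising.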
4 changes: 2 additions & 2 deletions tests/unit/model_parallelism/test_autotp_training.py
@@ -360,7 +360,7 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, group
     return model, base_model


-@pytest.mark.parametrize("zero_stage", [0, 1])
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
 @pytest.mark.parametrize("tp_size", [2, 4])
 class TestSave(DistributedTest):

@@ -492,7 +492,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
         compare_lr_scheduler_states(trained_model, loaded_model)


-@pytest.mark.parametrize("zero_stage", [0, 1])
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
 @pytest.mark.parametrize("tp_size", [2, 4])
 class TestTpGradNorm(DistributedTest):

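Since stacked pytest.mark.parametrize decorators multiply, adding stage 2 grows each test class's matrix from 2x2 to 3x2 configurations. A standalone sketch of the expansion, using a plain function in place of the DistributedTest classes:

import pytest

# Stacked parametrize marks take the cross-product, so these two lists
# now generate six cases: (0,2) (0,4) (1,2) (1,4) (2,2) (2,4).
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
@pytest.mark.parametrize("tp_size", [2, 4])
def test_matrix_expansion(zero_stage: int, tp_size: int):
    assert zero_stage in (0, 1, 2)
    assert tp_size in (2, 4)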