From b2766ca2fccd909322708df6516982f54c2d946c Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 10 Aug 2023 15:28:54 +0800 Subject: [PATCH 1/2] [shardformer] gpt2 tests fix [shardformer] test all optimizations (#4399) [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] test all optimizations [shardformer] gpt2 tests fix --- tests/test_shardformer/test_model/_utils.py | 8 +++++--- tests/test_shardformer/test_model/test_shard_gpt2.py | 8 +++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index cce21809d829..e4755256190c 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -206,7 +206,8 @@ def check_weight(org_model: Module, if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ - torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) + torch.zeros([*sharded_weight.shape]).to(sharded_weight.dtype).to('cuda') + for _ in range(dist.get_world_size(tp_group)) ] dist.all_gather(sharded_weight_list, sharded_weight, tp_group) sharded_weight = torch.cat(sharded_weight_list, dim=dim) @@ -215,7 +216,7 @@ def check_weight(org_model: Module, print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \ - f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" def check_grad(org_model: Module, @@ -234,7 +235,8 @@ def check_grad(org_model: Module, if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): shard_grad_list = [ - torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) + torch.zeros([*shard_grad.shape]).to(shard_grad.dtype).to('cuda') + for _ in range(dist.get_world_size(tp_group)) ] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 3ac8fa26d860..274cfaa39ad1 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - org_loss, org_output, sharded_loss, sharded_output = \ run_forward_backward_with_hybrid_plugin( org_model, @@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, if org_model.__class__.__name__ == 'GPT2Model': check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) - # check loss check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) def unwrap(module): @@ -92,13 +90,14 @@ def unwrap(module): 'num_microbatches': 4, 'enable_all_optimization': True, 'use_lazy_init': True, - 'precision': 'fp32', + 'precision': 'fp16', + 'initial_scale': 1, }, { 'tp_size': 1, 'pp_size': 2, 'num_microbatches': 4, 'enable_all_optimization': True, - 'use_lazy_init': False, + 'use_lazy_init': True, 'precision': 'fp16', 'initial_scale': 1, }, { @@ -112,7 +111,6 @@ def unwrap(module): def run_gpt2_test(test_config): # TODO: add test_config for TP+DP after supporting & debugging it - # TODO: check and debug TP+AMP sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') From 00a3634fc2ff97e6521dbe5a804d5a132bc72ced Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 10 Aug 2023 16:35:35 +0800 Subject: [PATCH 2/2] [shardformer] gpt2 tests fix --- tests/test_shardformer/test_model/_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e4755256190c..bea0587e646b 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -206,8 +206,7 @@ def check_weight(org_model: Module, if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ - torch.zeros([*sharded_weight.shape]).to(sharded_weight.dtype).to('cuda') - for _ in range(dist.get_world_size(tp_group)) + torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group)) ] dist.all_gather(sharded_weight_list, sharded_weight, tp_group) sharded_weight = torch.cat(sharded_weight_list, dim=dim) @@ -234,10 +233,7 @@ def check_grad(org_model: Module, shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [ - torch.zeros([*shard_grad.shape]).to(shard_grad.dtype).to('cuda') - for _ in range(dist.get_world_size(tp_group)) - ] + shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim)