From e6705b4c0f2797e316529af4e1abdfe6e13eddb1 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 10 Aug 2023 15:28:54 +0800
Subject: [PATCH 1/3] [shardformer] gpt2 tests fix

[shardformer] test all optimizations (#4399)

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] gpt2 tests fix
---
 tests/test_shardformer/test_model/_utils.py          | 8 +++++---
 tests/test_shardformer/test_model/test_shard_gpt2.py | 8 +++-----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index cce21809d829..e4755256190c 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -206,7 +206,8 @@ def check_weight(org_model: Module,

         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
             sharded_weight_list = [
-                torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+                torch.zeros([*sharded_weight.shape]).to(sharded_weight.dtype).to('cuda')
+                for _ in range(dist.get_world_size(tp_group))
             ]
             dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
             sharded_weight = torch.cat(sharded_weight_list, dim=dim)
@@ -215,7 +216,7 @@ def check_weight(org_model: Module,
             print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

         assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
-            f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
+            f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"


 def check_grad(org_model: Module,
@@ -234,7 +235,8 @@ def check_grad(org_model: Module,

         if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
             shard_grad_list = [
-                torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
+                torch.zeros([*shard_grad.shape]).to(shard_grad.dtype).to('cuda')
+                for _ in range(dist.get_world_size(tp_group))
             ]
             dist.all_gather(shard_grad_list, shard_grad, tp_group)
             shard_grad = torch.cat(shard_grad_list, dim=dim)
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 3ac8fa26d860..274cfaa39ad1 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
         build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
-
     org_loss, org_output, sharded_loss, sharded_output = \
         run_forward_backward_with_hybrid_plugin(
             org_model,
@@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         if org_model.__class__.__name__ == 'GPT2Model':
             check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
-
         # check loss
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

     def unwrap(module):
@@ -92,13 +90,14 @@ def unwrap(module):
     'num_microbatches': 4,
     'enable_all_optimization': True,
     'use_lazy_init': True,
-    'precision': 'fp32',
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
     'enable_all_optimization': True,
-    'use_lazy_init': False,
+    'use_lazy_init': True,
     'precision': 'fp16',
     'initial_scale': 1,
 }, {
@@ -112,7 +111,6 @@ def unwrap(module):
 def run_gpt2_test(test_config):

     # TODO: add test_config for TP+DP after supporting & debugging it
-    # TODO: check and debug TP+AMP

     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')

From ae30ae11f658147aae2bae92840665140d636079 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 10 Aug 2023 16:23:37 +0800
Subject: [PATCH 2/3] [shardformer] update t5 to use all optimizations

---
 tests/kit/model_zoo/transformers/t5.py        |  8 ++--
 .../test_model/test_shard_t5.py               | 40 ++++++++++++++-----
 2 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
index 435cb6f46937..175d48963480 100644
--- a/tests/kit/model_zoo/transformers/t5.py
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -16,8 +16,8 @@ def data_gen_for_encoder_only():
     # config = T5Config(decoder_start_token_id=0)
     # tokenizer = T5Tokenizer.from_pretrained("t5-small")
     # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
-    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12, 1627, 5, 1, 12]]).long()
+    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
     return dict(input_ids=input_ids, attention_mask=attention_mask)


@@ -26,7 +26,7 @@ def data_gen_for_conditional_generation():
     #
     # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
     data = data_gen_for_encoder_only()
-    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
     data['labels'] = labels
     return data

@@ -35,7 +35,7 @@ def data_gen_for_t5_model():
     # decoder_inputs_ids is obtained with the following code
     # decoder_input_ids = model._shift_right(input_ids)
     data = data_gen_for_encoder_only()
-    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
     data['decoder_input_ids'] = decoder_input_ids
     return data

diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index d807ffa06296..7f36ab377615 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -37,11 +37,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
+        if test_config['precision'] == 'fp32':
+            atol, rtol = 1e-5, 1e-3
+        else:
+            atol, rtol = 5e-3, 5e-3

         if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

-        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

     # unwrap model
     t5 = org_model
@@ -50,14 +54,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']

     # check weights and gradients
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-5, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
+        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)

     # check weights after optimizer.step()
     org_optimizer.step()
     sharded_optimizer.step()
+    if test_config['precision'] == 'fp32':
+        atol, rtol = 1e-4, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)

     torch.cuda.empty_cache()

@@ -66,23 +78,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     'tp_size': 2,
     'pp_size': 2,
     'num_microbatches': 2,
-    'enable_fused_normalization': True,
-    'use_lazy_init': True
+    'enable_all_optimization': True,
+    'use_lazy_init': True,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 1,
     'pp_size': 2,
     'num_microbatches': 4,
-    'use_lazy_init': False
+    'use_lazy_init': False,
+    'precision': 'fp16',
+    'initial_scale': 1,
 }, {
     'tp_size': 4,
     'pp_size': 1,
-    'enable_fused_normalization': True,
-    'use_lazy_init': False
+    'enable_all_optimization': True,
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }, {
     'tp_size': 1,
     'pp_size': 4,
     'num_microbatches': 4,
-    'use_lazy_init': False
+    'use_lazy_init': False,
+    'precision': 'fp32',
 }])
 @clear_cache_before_run()
 def run_t5_test(test_config):
@@ -93,7 +111,7 @@ def run_t5_test(test_config):

     # TODO: add test_config for flash attention & jit operator after supporting

     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-    test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing
+    # test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

From 3e2a34130b2895802642dfd05458cb853d8dc62d Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Mon, 14 Aug 2023 15:46:32 +0800
Subject: [PATCH 3/3] [shardformer] update t5 to use all optimizations, fix

---
 colossalai/shardformer/README.md                    | 2 +-
 tests/test_shardformer/test_model/test_shard_t5.py | 1 -
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 1c11b4b85444..18e00a6a663d 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -31,7 +31,7 @@

 ### Quick Start

-The sample API usage is given below(If you enable the use of flash attention, please install xformers.):
+The sample API usage is given below (if you enable flash attention, please install `flash_attn`; in addition, xformers' `cutlass_op` provides a supplementary optimization, which requires the sequence length to be a multiple of 8):

 ```python
 from colossalai.shardformer import ShardConfig, ShardFormer
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 7f36ab377615..fb065b42250b 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -111,7 +111,6 @@ def run_t5_test(test_config):

     # TODO: add test_config for flash attention & jit operator after supporting

     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-    # test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
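For reference, a minimal sketch of the Quick Start flow that the README hunk in patch 3 describes. The `ShardFormer` driver class, the `enable_all_optimization` flag on `ShardConfig` (the same switch these patches toggle in the test configs), and the exact return value of `optimize()` are assumptions drawn from the surrounding tests, not a definitive rendering of the Colossal-AI API:

```python
# A minimal sketch, assuming ShardFormer and the enable_all_optimization flag
# exercised by the test configs in these patches; details may differ by version.
import colossalai
from transformers import GPT2Model

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes the script is launched with torchrun, so the distributed
# environment variables (rank, world size, etc.) are already set.
colossalai.launch_from_torch(config={})

model = GPT2Model.from_pretrained('gpt2')

# enable_all_optimization mirrors the test configs above: it is expected to
# switch on fused normalization, flash attention and JIT fusion together.
# With flash attention enabled, install flash_attn; to hit the xformers
# cutlass_op path, pad inputs so the sequence length is a multiple of 8.
shard_config = ShardConfig(enable_all_optimization=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model)  # return shape may vary by version
```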