From cd360daef088925fb4f99609ad9c9a0dfb922491 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 14 Sep 2023 11:20:08 +0800
Subject: [PATCH 1/4] [shardformer] fix whisper test failed

---
 colossalai/shardformer/policies/whisper.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 5d496f08e1db..5f04bb80cb69 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -57,6 +57,11 @@ def module_policy(self):
             warnings.warn(
                 "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
+        #TODO using the jit fused add_and_dropout affect the accuracy
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":

From 84dbe9644f9a17a340446b62300c88be58c88d16 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 14 Sep 2023 11:20:08 +0800
Subject: [PATCH 2/4] [shardformer] fix whisper test failed

---
 colossalai/shardformer/policies/whisper.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 5d496f08e1db..5f04bb80cb69 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -57,6 +57,11 @@ def module_policy(self):
             warnings.warn(
                 "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
+        #TODO using the jit fused add_and_dropout affect the accuracy
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":

From 1446c1df68ac8d3f8aa0f9e897e1cd01010028cc Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 14 Sep 2023 11:29:19 +0800
Subject: [PATCH 3/4] [shardformer] fix whisper test failed

---
 colossalai/shardformer/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 559f9a56f61e..b1573ae163a0 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -114,7 +114,7 @@ We will follow this roadmap to develop Shardformer:
 | bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
 | chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] |
 | vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
-| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] |
+| whisper | [x] | [x] | [x] | [x] | [x] | [ ] | [x] | [ ] | [ ] |
 | sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
 | blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] |
 | roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |

From ffe58fcaf8057f5fc51855a8b3872e6b9e2f5466 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Thu, 14 Sep 2023 11:35:28 +0800
Subject: [PATCH 4/4] [shardformer] fix whisper test failed

---
 colossalai/shardformer/policies/whisper.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 5f04bb80cb69..31ba82166b31 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -60,7 +60,7 @@ def module_policy(self):
         #TODO using the jit fused add_and_dropout affect the accuracy
         if self.shard_config.enable_jit_fused:
             self.shard_config.enable_jit_fused = False
-            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the sequence parallelism flag.")
+            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.")
 
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
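For context, the guard added by patches 1/4 and 2/4 (and re-worded by 4/4) follows a common policy pattern: when a feature flag is enabled that the model's shard policy cannot honor, the policy resets the flag and emits a warning instead of failing. The sketch below is a minimal, self-contained illustration of that pattern only; `SimpleShardConfig` and `toy_whisper_policy` are hypothetical stand-ins, not the ColossalAI API.

```python
import warnings
from dataclasses import dataclass


@dataclass
class SimpleShardConfig:
    """Hypothetical stand-in for a shard configuration with feature flags."""
    enable_jit_fused: bool = False
    enable_tensor_parallelism: bool = False


def toy_whisper_policy(shard_config: SimpleShardConfig) -> dict:
    """Illustrative policy builder: drop flags the model cannot support."""
    policy: dict = {}

    # Mirrors the patched guard: the JIT-fused add+dropout kernel is treated
    # as unsupported for Whisper, so the flag is reset and the caller warned.
    if shard_config.enable_jit_fused:
        shard_config.enable_jit_fused = False
        warnings.warn("Whisper doesn't support jit fused operator now, "
                      "will ignore the jit fused operator flag.")

    # ... tensor-parallel module replacements would be registered here ...
    return policy


if __name__ == "__main__":
    cfg = SimpleShardConfig(enable_jit_fused=True)
    toy_whisper_policy(cfg)              # emits the warning once
    assert cfg.enable_jit_fused is False  # flag was silently disabled
```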