From 5cc9385f0821e495d6cf66a7a2af26c159de0fa0 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 30 May 2024 08:19:56 +0000 Subject: [PATCH] fix --- colossalai/shardformer/policies/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 713175c6cc13..a9c982231825 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -351,7 +351,7 @@ def module_policy(self): policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism: + if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription(