diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 713175c6cc13..a9c982231825 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -351,7 +351,7 @@ def module_policy(self): policy = super().module_policy() - if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism: + if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { LlamaForCausalLM: ModulePolicyDescription(