From b786cfd20760da209a1d9261dc450e0d7e8f23d8 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 13 Sep 2023 13:13:53 +0800 Subject: [PATCH] [shardformer] fix GPT2DoubleHeadsModel --- colossalai/shardformer/modeling/gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index bc99be4cc391..84deafefeadd 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -94,9 +94,9 @@ def gpt2_model_forward( if hidden_states is None: raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") input_shape = hidden_states.size()[:-1] - batch_size = input_shape[0] device = hidden_states.device hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) + batch_size = hidden_states.shape[0] # GPT2Attention mask. if attention_mask is not None: