From 579d869db50ab04747a08e8b155f2cbe564fd70f Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Mon, 18 Dec 2023 17:15:11 -0800 Subject: [PATCH] change KV splitting based on Megatron-LM --- .../gpt_bigcode/modeling_gpt_bigcode.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index ff673fb8d9..c2453bff21 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -273,7 +273,35 @@ def forward( # .split((self.head_dim, 2 * self.head_dim), dim=3) # ) - query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + # query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + + # Split the KV tensors based on Megatron-LM's way + c_states = self.c_attn(hidden_states) + new_tensor_shape = hidden_states.size()[:-1] + (self.kv_heads, ((self.num_heads // self.kv_heads + 2)* self.head_dim),) + c_states= c_states.view(*new_tensor_shape) + (query, key, value) = torch.split( + c_states, + [ + ( + self.num_heads + // self.kv_heads + * self.head_dim + ), + self.head_dim, + self.head_dim, + ], + dim=3, + ) + + query = query.reshape(query.size()[:-2] + (-1,)) + key = key.reshape(key.size()[:-2] + (-1,)) + value = value.reshape(value.size()[:-2] + (-1,)) + key_value = torch.cat([key, value], dim=-1) + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.kv_heads * self.head_dim), dim=-1) # key_value: (batch, sequence, 2 * kv_heads * head_dim) if layer_past is not None: