From 7f48c11037d6424883add902ccc8305f9de4f74b Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Tue, 30 Jul 2024 11:29:52 -0700
Subject: [PATCH 1/2] Remove size check between attn_weights and kv_seq_len

---
 src/transformers/models/phi3/modeling_phi3.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 6bbfae984548..8907da7cf3ba 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -401,12 +401,6 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights += causal_mask

From 6d655fe2a071010c47bfc68c7b845c49e6950d35 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Tue, 30 Jul 2024 14:21:00 -0700
Subject: [PATCH 2/2] add unit tests

---
 tests/models/phi3/test_modeling_phi3.py | 101 +++++++++++++++++++++++-
 1 file changed, 98 insertions(+), 3 deletions(-)

diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py
index 1ddc73961bfe..ec3986ff2338 100644
--- a/tests/models/phi3/test_modeling_phi3.py
+++ b/tests/models/phi3/test_modeling_phi3.py
@@ -19,7 +19,7 @@
 
 from parameterized import parameterized
 
-from transformers import Phi3Config, is_torch_available, set_seed
+from transformers import Phi3Config, StaticCache, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_torch,
     slow,
@@ -43,6 +43,55 @@
         Phi3Model,
     )
 
+    end_of_text_token = 32000
+
+    class Phi3MiniWithStaticCache(torch.nn.Module):
+        def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
+            super().__init__()
+            self.model = model
+            self.cache = StaticCache(
+                config=model.config,
+                max_batch_size=max_batch_size,
+                max_cache_len=max_seq_len,
+                device=self.model.device,
+                dtype=self.model.dtype,
+            )
+
+        def forward(
+            self,
+            input_ids: torch.LongTensor = None,
+        ) -> torch.FloatTensor:
+            return self.model.forward(
+                input_ids=input_ids,
+                use_cache=True,
+                return_dict=True,
+                past_key_values=self.cache,
+            ).logits
+
+        @staticmethod
+        def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
+            model = Phi3MiniWithStaticCache(model, 1, max_seq_len + prompt_tokens.shape[-1])
+
+            response_tokens = []
+
+            for input_pos in range(prompt_tokens.shape[-1]):
+                result = model.forward(
+                    input_ids=prompt_tokens[:, input_pos : input_pos + 1],
+                )
+                response_tokens.append(prompt_tokens[0][input_pos].item())
+
+            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
+            response_tokens.append(current_token)
+
+            while current_token != end_of_text_token and len(response_tokens) < max_seq_len:
+                result = model.forward(
+                    input_ids=torch.tensor([[current_token]], dtype=torch.long),
+                )
+                current_token = torch.argmax(result[:, -1, :], dim=-1).item()
+                response_tokens.append(current_token)
+
+            return response_tokens
+
 
 class Phi3ModelTester:
     def __init__(
@@ -429,7 +478,30 @@ def test_phi3_mini_4k_instruct_generation(self):
         output_text = tokenizer.batch_decode(outputs)
 
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit"
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for incorporating these fruits into your"
+        ]
+
+        self.assertListEqual(output_text, EXPECTED_OUTPUT)
+
+    def test_phi3_mini_4k_instruct_with_static_cache(self):
+        model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
+            },
+            {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
+        ]
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+
+        response_tokens = Phi3MiniWithStaticCache.generate(model, inputs, 64)
+
+        output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
+
+        EXPECTED_OUTPUT = [
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some"
         ]
 
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
@@ -467,7 +539,30 @@ def test_phi3_mini_128k_instruct_generation(self):
         output_text = tokenizer.batch_decode(outputs)
 
         EXPECTED_OUTPUT = [
-            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and healthy ways. Here are some ideas:\n\n1."
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways. Here are some creative and healthy"
+        ]
+
+        self.assertListEqual(output_text, EXPECTED_OUTPUT)
+
+    def test_phi3_mini_128k_instruct_with_static_cache(self):
+        model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-128k-instruct")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-128k-instruct")
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
+            },
+            {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
+        ]
+        inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+
+        response_tokens = Phi3MiniWithStaticCache.generate(model, inputs, 64)
+
+        output_text = tokenizer.batch_decode(torch.tensor([response_tokens], dtype=torch.long, device=torch_device))
+
+        EXPECTED_OUTPUT = [
+            "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways"
         ]
 
         self.assertListEqual(output_text, EXPECTED_OUTPUT)
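
Context (not part of the patch): the assertion deleted in PATCH 1/2 required attn_weights to have shape (bsz, num_heads, q_len, kv_seq_len). When decoding with a StaticCache, which the new tests in PATCH 2/2 exercise, past_key_value.update() hands back key/value tensors pre-allocated to max_cache_len, so the last dimension of the attention scores reflects the cache length rather than the number of tokens seen, and the old check would reject shapes that are valid in that setting. Below is a minimal sketch of that shape behavior, assuming the same StaticCache API the patch itself uses; the tiny config values and variable names are illustrative assumptions, not values from the PR.

    import torch
    from transformers import Phi3Config, StaticCache

    # Toy config; the sizes here are arbitrary and chosen only to keep the example small.
    config = Phi3Config(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

    head_dim = config.hidden_size // config.num_attention_heads  # 16
    # A single new token's key/value entering layer 0 at position 0.
    key = torch.zeros(1, config.num_key_value_heads, 1, head_dim)
    value = torch.zeros_like(key)
    key_states, _ = cache.update(key, value, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([0])})

    # The cache returns tensors padded to max_cache_len, so attention scores computed
    # against key_states get a last dimension of 16 even though only one token has been
    # seen, which is the kind of mismatch the removed size check would have flagged.
    print(key_states.shape)  # torch.Size([1, 4, 16, 16])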