Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,6 @@ def forward(

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights += causal_mask
Expand Down
101 changes: 98 additions & 3 deletions tests/models/phi3/test_modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from parameterized import parameterized

from transformers import Phi3Config, is_torch_available, set_seed
from transformers import Phi3Config, StaticCache, is_torch_available, set_seed
from transformers.testing_utils import (
require_torch,
slow,
Expand All @@ -43,6 +43,55 @@
Phi3Model,
)

# Token id that `Phi3MiniWithStaticCache.generate` treats as the stop signal:
# greedy decoding ends as soon as the model emits this id.
end_of_text_token = 32000

class Phi3MiniWithStaticCache(torch.nn.Module):
    # Thin wrapper that pairs a Phi-3 causal LM with a single `StaticCache`
    # so that every forward call goes through the same cache object.

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        # Args:
        #     model: the causal LM whose `forward` is wrapped.
        #     max_batch_size: passed to `StaticCache` as `max_batch_size`.
        #     max_seq_len: passed to `StaticCache` as `max_cache_len`.
        super().__init__()
        self.model = model
        # The cache is created with the wrapped model's own device and dtype.
        self.cache = StaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        # Runs the wrapped model on `input_ids`, always handing it the shared
        # cache, and returns only the `.logits` field of the model output.
        return self.model.forward(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=self.cache,
        ).logits

    @staticmethod
    def generate(model: Phi3ForCausalLM, prompt_tokens: torch.LongTensor, max_seq_len: int) -> list[int]:
        # Greedy decoding, one token per forward call, through a fresh
        # `Phi3MiniWithStaticCache` built for a batch of one.
        #
        # Args:
        #     model: the causal LM to decode with.
        #     prompt_tokens: 2-D tensor of prompt ids; only row 0 is used.
        #     max_seq_len: upper bound on the TOTAL returned length (prompt ids
        #         included). The first predicted token is always appended, even
        #         if the prompt alone already reaches this bound.
        #
        # Returns:
        #     The prompt ids of row 0 followed by the generated ids.
        #
        # Raises:
        #     ValueError: if `prompt_tokens` has no tokens along its last axis
        #         (previously this surfaced as an UnboundLocalError on `result`).
        prompt_len = prompt_tokens.shape[-1]
        if prompt_len == 0:
            raise ValueError("`prompt_tokens` must contain at least one token.")

        # The cache is sized for `max_seq_len + prompt_len` positions.
        cached_model = Phi3MiniWithStaticCache(model, 1, max_seq_len + prompt_len)

        # The returned sequence starts with the prompt ids themselves.
        response_tokens = prompt_tokens[0].tolist()

        # Decoding is inference-only; skip autograd bookkeeping for every step.
        with torch.no_grad():
            # Feed the prompt one position at a time; only the logits of the
            # last position are needed afterwards.
            for input_pos in range(prompt_len):
                result = cached_model.forward(
                    input_ids=prompt_tokens[:, input_pos : input_pos + 1],
                )

            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
            response_tokens.append(current_token)

            while current_token != end_of_text_token and len(response_tokens) < max_seq_len:
                # Build the next input on the same device as the prompt so the
                # loop also works when the inputs are not on CPU.
                next_input = torch.tensor([[current_token]], dtype=torch.long, device=prompt_tokens.device)
                result = cached_model.forward(input_ids=next_input)
                current_token = torch.argmax(result[:, -1, :], dim=-1).item()
                response_tokens.append(current_token)

        return response_tokens


class Phi3ModelTester:
def __init__(
Expand Down Expand Up @@ -429,7 +478,30 @@ def test_phi3_mini_4k_instruct_generation(self):
output_text = tokenizer.batch_decode(outputs)

EXPECTED_OUTPUT = [
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit"
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for incorporating these fruits into your"
]

self.assertListEqual(output_text, EXPECTED_OUTPUT)

def test_phi3_mini_4k_instruct_with_static_cache(self):
    # Decodes a fixed chat prompt through `Phi3MiniWithStaticCache.generate`
    # and compares the decoded text with a pinned reference string.
    checkpoint = "microsoft/phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    system_turn = {
        "role": "system",
        "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
    }
    user_turn = {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"}
    prompt_ids = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # 64 is handed to `generate` as its `max_seq_len` argument.
    generated_ids = Phi3MiniWithStaticCache.generate(model, prompt_ids, 64)

    # Wrap the flat id list as a batch of one before decoding.
    generated_batch = torch.tensor([generated_ids], dtype=torch.long, device=torch_device)
    decoded_text = tokenizer.batch_decode(generated_batch)

    expected_text = [
        "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some"
    ]

    self.assertListEqual(decoded_text, expected_text)
Expand Down Expand Up @@ -467,7 +539,30 @@ def test_phi3_mini_128k_instruct_generation(self):
output_text = tokenizer.batch_decode(outputs)

EXPECTED_OUTPUT = [
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and healthy ways. Here are some ideas:\n\n1."
"<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways. Here are some creative and healthy"
]

self.assertListEqual(output_text, EXPECTED_OUTPUT)

def test_phi3_mini_128k_instruct_with_static_cache(self):
    # Decodes a fixed chat prompt through `Phi3MiniWithStaticCache.generate`
    # and compares the decoded text with a pinned reference string.
    checkpoint = "microsoft/phi-3-mini-128k-instruct"
    model = Phi3ForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    system_turn = {
        "role": "system",
        "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
    }
    user_turn = {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"}
    prompt_ids = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # 64 is handed to `generate` as its `max_seq_len` argument.
    generated_ids = Phi3MiniWithStaticCache.generate(model, prompt_ids, 64)

    # Wrap the flat id list as a batch of one before decoding.
    generated_batch = torch.tensor([generated_ids], dtype=torch.long, device=torch_device)
    decoded_text = tokenizer.batch_decode(generated_batch)

    expected_text = [
        "<|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and nutritious ways"
    ]

    self.assertListEqual(decoded_text, expected_text)