Skip to content
Merged

Fix rag #38585

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 61 additions & 18 deletions tests/models/rag/test_modeling_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from unittest.mock import patch

import numpy as np
import requests

from transformers import BartTokenizer, T5Tokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
Expand Down Expand Up @@ -49,7 +50,7 @@
if is_torch_available() and is_datasets_available() and is_faiss_available():
import faiss
import torch
from datasets import Dataset
from datasets import Dataset, load_dataset

from transformers import (
AutoConfig,
Expand Down Expand Up @@ -679,6 +680,24 @@ def config_and_inputs(self):
@require_tokenizers
@require_torch_non_multi_accelerator
class RagModelIntegrationTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.dataset_path = cls.temp_dir.name
cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")

ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
ds.save_to_disk(cls.dataset_path)

url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
response = requests.get(url, stream=True)
with open(cls.index_path, "wb") as fp:
fp.write(response.content)

@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()

def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
Expand Down Expand Up @@ -722,8 +741,9 @@ def get_rag_config(self):
max_combined_length=300,
dataset="wiki_dpr",
dataset_split="train",
index_name="exact",
index_path=None,
index_name="custom",
passages_path=self.dataset_path,
index_path=self.index_path,
use_dummy_dataset=True,
retrieval_vector_size=768,
retrieval_batch_size=8,
Expand Down Expand Up @@ -841,8 +861,8 @@ def test_rag_token_generate_beam(self):
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)

# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"
Comment on lines -844 to -845
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why we have this short expected outputs.

It's already not matching back to 2024/04 and maybe even earlier.

https://huggingface.slack.com/archives/C06LR9PQA00/p1712388818140089?thread_ts=1712388743.278929&cid=C06LR9PQA00

EXPECTED_OUTPUT_TEXT_1 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Uses"'
EXPECTED_OUTPUT_TEXT_2 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Ways"'

self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
Expand Down Expand Up @@ -903,7 +923,10 @@ def test_data_questions(self):
def test_rag_sequence_generate_batch(self):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
"facebook/rag-sequence-nq",
index_name="custom",
passages_path=self.dataset_path,
index_path=self.index_path,
)
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device
Expand All @@ -926,12 +949,13 @@ def test_rag_sequence_generate_batch(self):

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

# PR #31938 cause the output being changed from `june 22, 2018` to `june 22 , 2018`.
EXPECTED_OUTPUTS = [
" albert einstein",
" june 22, 2018",
" june 22 , 2018",
" amplitude modulation",
" tim besley ( chairman )",
" june 20, 2018",
" june 20 , 2018",
" 1980",
" 7.0",
" 8",
Expand All @@ -943,9 +967,9 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
"facebook/rag-sequence-nq",
index_name="exact",
use_dummy_dataset=True,
dataset_revision="b24a417",
index_name="custom",
passages_path=self.dataset_path,
index_path=self.index_path,
)
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
torch_device
Expand Down Expand Up @@ -981,10 +1005,10 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):

EXPECTED_OUTPUTS = [
" albert einstein",
" june 22, 2018",
" june 22 , 2018",
" amplitude modulation",
" tim besley ( chairman )",
" june 20, 2018",
" june 20 , 2018",
" 1980",
" 7.0",
" 8",
Expand All @@ -995,7 +1019,7 @@ def test_rag_sequence_generate_batch_from_context_input_ids(self):
def test_rag_token_generate_batch(self):
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
"facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
"facebook/rag-token-nq", index_name="custom", passages_path=self.dataset_path, index_path=self.index_path
)
rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
torch_device
Expand Down Expand Up @@ -1023,10 +1047,10 @@ def test_rag_token_generate_batch(self):

EXPECTED_OUTPUTS = [
" albert einstein",
" september 22, 2017",
" september 22 , 2017",
" amplitude modulation",
" stefan persson",
" april 20, 2018",
" april 20 , 2018",
" the 1970s",
" 7.1. 2",
" 13",
Expand All @@ -1037,6 +1061,24 @@ def test_rag_token_generate_batch(self):
@require_torch
@require_retrieval
class RagModelSaveLoadTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.TemporaryDirectory()
cls.dataset_path = cls.temp_dir.name
cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")

ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
ds.save_to_disk(cls.dataset_path)

url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
response = requests.get(url, stream=True)
with open(cls.index_path, "wb") as fp:
fp.write(response.content)

@classmethod
def tearDownClass(cls):
cls.temp_dir.cleanup()

def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
Expand All @@ -1060,8 +1102,9 @@ def get_rag_config(self):
max_combined_length=300,
dataset="wiki_dpr",
dataset_split="train",
index_name="exact",
index_path=None,
index_name="custom",
passages_path=self.dataset_path,
index_path=self.index_path,
use_dummy_dataset=True,
retrieval_vector_size=768,
retrieval_batch_size=8,
Expand Down