This repository is derived from the original CoRAG repository, which provides the model, inference data, and evaluation code for the paper Chain-of-Retrieval Augmented Generation.
Based on the original paper's methodology, this fork adds:
- Training Scripts: Full fine-tuning scripts with Chain-of-RAG masking strategy (see src/train/)
- Dataset Construction Scripts: Data preparation pipeline that includes context retrieval via Graph API (see src/train/prepare_training_data.py)
| Dataset | Short description |
|---|---|
| corag/multihopqa | MultihopQA datasets and model predictions |
| corag/kilt | KILT benchmark with top-k retrieval results |
| corag/kilt-corpus | KILT-version of Wikipedia for retrieval |
| corag/kilt-corpus-embeddings | Corpus embeddings from e5-large-v2 |
| Model | Short description |
|---|---|
| corag/CoRAG-Llama3.1-8B-MultihopQA | CoRAG-8B fine-tuned for multihop QA |
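The corpus embeddings listed above back the dense retrieval step. As a rough, self-contained illustration of top-k search over precomputed embeddings, here is a minimal cosine-similarity lookup with toy 2-dimensional vectors; the actual pipeline uses e5-large-v2 embeddings and a proper vector index rather than plain Python:

```python
import math

# Toy stand-in for a corpus embedding table (e.g. corag/kilt-corpus-embeddings);
# real embeddings are 1024-dimensional e5-large-v2 vectors.
corpus = {"doc1": [1.0, 0.0], "doc2": [0.6, 0.8], "doc3": [0.0, 1.0]}

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

query = [0.9, 0.1]  # toy query embedding
# Rank documents by cosine similarity and keep the top 2.
top_k = sorted(corpus, key=lambda d: cosine(query, corpus[d]), reverse=True)[:2]
```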
```shell
pip install -r requirements.txt
```
This repository includes comprehensive training scripts for fine-tuning models with the CoRAG methodology. The training process consists of two main steps:
The prepare_training_data.py script converts datasets into the training format, performing real-time context retrieval:
```shell
python3 src/train/prepare_training_data.py \
    --dataset corag/multihopqa \
    --task 2wikimultihopqa \
    --output_file data/train_with_graph_retrieval.jsonl \
    --retrieve \
    --top_k 3 \
    --graph_api_url http://localhost:8023/retrieve
```

Key Features:
- Real-time retrieval of relevant documents via Graph API for each sub-query
- Configurable top-k document selection
- Support for multiple datasets (2wikimultihopqa, hotpotqa, musique, bamboogle)
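The script writes one JSON object per line. A record might look like the sketch below; the field names here are illustrative only, and the authoritative schema is whatever prepare_training_data.py actually emits:

```python
import json

# Hypothetical layout of one training example; real field names are
# defined by src/train/prepare_training_data.py and may differ.
example = {
    "query": "Who directed the film that features character X?",  # user question
    "subqueries": ["Which film features character X?",
                   "Who directed that film?"],
    "subanswers": ["Film Y", "Director Z"],
    # With --top_k 3, three documents are retrieved per sub-query.
    "context": [["doc a", "doc b", "doc c"],
                ["doc d", "doc e", "doc f"]],
    "answer": "Director Z",
}

line = json.dumps(example)  # one line of the output .jsonl file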
The train.py script implements full fine-tuning with Chain-of-RAG masking strategy:

```shell
bash src/train/run_training.sh
```

Training Characteristics:
- Loss Computation: Applied to SubQuery, SubAnswer, and Final Answer (model's reasoning)
- Masked (No Loss): User Query and Retrieved Context (external inputs)
- Supports DeepSpeed ZeRO-3 for distributed training
- Compatible with Qwen, Llama, and other instruction-tuned models
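The masking rule above can be sketched as follows. This is a simplified illustration of the idea, not the exact code in src/train/: tokens the model should generate keep their ids as labels, while external inputs are set to -100 so that cross-entropy ignores them (the standard ignore index in PyTorch/Transformers causal LM training):

```python
IGNORE_INDEX = -100  # positions with this label contribute no loss

def build_labels(token_ids, spans):
    """spans: list of (start, end, is_model_output) over token positions.
    Model outputs (sub-queries, sub-answers, final answer) keep their token
    ids; external inputs (user query, retrieved context) are masked."""
    labels = [IGNORE_INDEX] * len(token_ids)
    for start, end, is_model_output in spans:
        if is_model_output:
            labels[start:end] = token_ids[start:end]
    return labels

# Toy example: tokens 0-3 are the user query (masked),
# tokens 4-7 are a generated sub-query (loss applied).
tokens = [101, 102, 103, 104, 201, 202, 203, 204]
labels = build_labels(tokens, [(0, 4, False), (4, 8, True)])
```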
For detailed training parameters and configuration, see src/train/README.md.
Here we provide an example for running inference with CoRAG-8B on the MultihopQA dataset. We tested this on a machine with 8 A100 GPUs (40GB).
- Download embeddings and start the E5 search server.

```shell
bash scripts/download_embeddings.sh

# The server logs will be in e5_server.log
bash scripts/start_e5_server.sh
```

- Start the vLLM server and load the CoRAG-8B model.
```shell
# The server logs will be in vllm_server.log
bash scripts/start_vllm_server.sh corag/CoRAG-Llama3.1-8B-MultihopQA
```

- Run the inference script. By default, we will use greedy decoding with max path length L = 6.
```shell
# It will evaluate on [2wikimultihopqa, bamboogle, hotpotqa, musique] sequentially.
bash scripts/eval_multihopqa.sh
```

At the end, you will see evaluation metrics similar to the following (for the MuSiQue dataset):
```json
{
    "em": 27.679,
    "f1": 38.532,
    "accuracy": 27.141,
    "num_samples": 2417,
    "max_path_length": 6,
    "decode_strategy": "greedy",
    "token_consumed": 23818600,
    "average_token_consumed_per_sample": 9854.613156805957
}
```

Due to the randomness of the sampling process, the results may vary slightly between runs, especially for small datasets like Bamboogle.
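Conceptually, decoding with max path length L alternates sub-query generation, retrieval, and sub-answer generation until the model emits a final answer. The schematic loop below uses placeholder generate/retrieve callables (not the repo's actual interfaces) to show the control flow:

```python
def corag_decode(question, generate, retrieve, max_path_length=6, top_k=3):
    """Schematic chain-of-retrieval loop (placeholder callables, not the
    repo's real API). Each step generates a sub-query, retrieves documents
    for it, and generates a sub-answer; the accumulated chain then
    conditions the final answer."""
    chain = []
    for _ in range(max_path_length):
        subquery = generate("subquery", question, chain)
        if subquery is None:  # model signals no further retrieval is needed
            break
        docs = retrieve(subquery, top_k=top_k)
        subanswer = generate("subanswer", question, chain, subquery, docs)
        chain.append((subquery, docs, subanswer))
    return generate("final_answer", question, chain)
```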
If you find this repository useful, please consider citing our paper:
```bibtex
@article{wang2025chain,
  title={Chain-of-Retrieval Augmented Generation},
  author={Wang, Liang and Chen, Haonan and Yang, Nan and Huang, Xiaolong and Dou, Zhicheng and Wei, Furu},
  journal={arXiv preprint arXiv:2501.14342},
  year={2025}
}
```