👉 OneGen 👈

OneGen: Efficient One-Pass Unified Generation and Retrieval for LLMs

Awesome License: MIT

☁️ Google Drive (Data)
📄 arXiv • 𝕏 Blog • 🌐 Web

🤗 HF (Model)👇 • 🔭 Model Scope (Model)👇 • 🧊 Wise Model (Model)👇

| 🎯 Task Name | 🤗 HuggingFace | 🔭 ModelScope | 🧊 WiseModel |
| --- | --- | --- | --- |
| Entity Linking | Llama2-7B | Llama2-7B | Llama2-7B |
| Single-hop QA | Llama2-7B | Llama2-7B | Llama2-7B |
| Multi-hop QA | Llama2-7B | Llama2-7B | Llama2-7B |

📋TODO

  • Support LoRA training
  • Code documentation
  • Support vLLM inference
  • Support distributed embedding
  • Gradio

👀Overview

We introduce a One-pass Generation and retrieval framework (OneGen) for fine-tuning LLMs on generation, retrieval, or hybrid tasks. Our core idea is to integrate generation and retrieval into the same context by assigning the retrieval task to retrieval tokens that are generated autoregressively, enabling an LLM to perform both tasks in a single forward pass.

The following figure illustrates the training process. We first introduce the concept of token roles in LLMs. A token $x_i$ is the basic unit processed by an LLM. A token in the input of an LLM serves one of three roles:

  • Generating the next token, denoted as $role(x_i)=\texttt{GEN}$.
  • Providing context information, denoted as $role(x_i)=\texttt{CTX}$.
  • Representing a sentence, denoted as $role(x_i)=\texttt{RET}$.

Hence, we apply the cross-entropy loss for the token $x_i$ where $role(x_i)=\texttt{GEN}$ and apply the contrastive loss for the token $x_i$ where $role(x_i)=\texttt{RET}$. This is the training overview.
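
The role-based loss described above can be sketched in a few lines of plain Python. This is an illustrative sketch, not the repository's training code: the InfoNCE-style form of the contrastive loss, the averaging over positions, the weight `lam`, the choice to give CTX tokens no loss, and all function names are assumptions made for the example.

```python
import math

def cross_entropy(logits, target):
    """Negative log-likelihood of `target` under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def info_nce(query, positive, negatives, temperature=1.0):
    """InfoNCE-style contrastive loss (assumed form): the positive is class 0."""
    scores = [dot(query, positive) / temperature]
    scores += [dot(query, neg) / temperature for neg in negatives]
    return cross_entropy(scores, 0)

def role_based_loss(roles, logits, targets, hidden, positives, negatives, lam=1.0):
    """Cross-entropy on GEN positions, contrastive loss on RET positions.

    CTX positions are assumed to contribute no loss. `lam` is a hypothetical
    weight balancing the two terms.
    """
    gen_losses, ret_losses = [], []
    for i, role in enumerate(roles):
        if role == "GEN":
            gen_losses.append(cross_entropy(logits[i], targets[i]))
        elif role == "RET":
            ret_losses.append(info_nce(hidden[i], positives[i], negatives[i]))
    gen = sum(gen_losses) / len(gen_losses) if gen_losses else 0.0
    ret = sum(ret_losses) / len(ret_losses) if ret_losses else 0.0
    return gen + lam * ret

# Toy sequence: one context token, one generation token, one retrieval token.
roles = ["CTX", "GEN", "RET"]
logits = [None, [0.0, 0.0], None]          # vocabulary of size 2 at the GEN position
targets = [None, 0, None]
hidden = [None, None, [1.0, 0.0]]          # representation of the RET token
positives = [None, None, [1.0, 0.0]]       # embedding of the matching document
negatives = [None, None, [[0.0, 1.0]]]     # embeddings of non-matching documents
loss = role_based_loss(roles, logits, targets, hidden, positives, negatives)
print(loss)
```

In this toy setup each position contributes to at most one loss term, which is what lets both objectives be computed from one forward pass over the sequence.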

The following figure illustrates the inference process of different methods for the RAG task. First, both GritLM and OneGen need to deploy only a single model, which lowers deployment cost. However, GritLM achieves generation and retrieval within a single model by switching back and forth between causal attention and bidirectional attention. In addition, both GritLM and the Pipeline method require explicit queries, so each query needs two forward passes. In contrast, OneGen performs retrieval during generation, which avoids the second forward pass for queries and allows direct reuse of the KV cache, reducing inference cost.
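
To make the "retrieval during generation" idea concrete, here is a toy decoding loop. It is a sketch under stated assumptions, not the repository's inference code: the token name `[RET]`, the `step_fn` interface, the scripted stand-in for a model, and dot-product scoring are all hypothetical choices for illustration.

```python
RET_TOKEN = "[RET]"  # hypothetical name for the retrieval token

def retrieve(query_vec, index, k=1):
    """Return the texts of the k documents with the highest dot-product score."""
    def score(item):
        _, doc_vec = item
        return sum(q * d for q, d in zip(query_vec, doc_vec))
    ranked = sorted(index, key=score, reverse=True)
    return [text for text, _ in ranked[:k]]

def generate_with_retrieval(step_fn, prompt, index, max_steps=16, eos="</s>"):
    """Decode token by token; when the retrieval token is emitted, reuse the
    hidden state from that same step as the query instead of re-encoding one."""
    context = list(prompt)
    for _ in range(max_steps):
        token, hidden = step_fn(context)
        context.append(token)
        if token == eos:
            break
        if token == RET_TOKEN:
            context.extend(retrieve(hidden, index))
    return context

def make_scripted_model(script):
    """Stand-in for an LLM: replays fixed (token, hidden_state) pairs."""
    steps = iter(script)
    def step_fn(context):
        return next(steps)
    return step_fn

index = [("doc_about_A", [0.9, 0.1]), ("doc_about_B", [0.0, 1.0])]
script = [
    ("Who", [0.0, 0.0]),
    (RET_TOKEN, [1.0, 0.0]),   # hidden state of the retrieval token acts as the query
    ("answer", [0.0, 0.0]),
    ("</s>", [0.0, 0.0]),
]
out = generate_with_retrieval(make_scripted_model(script), ["Q:"], index)
print(out)
```

The point of the sketch is that no separate query-encoding pass appears anywhere in the loop: the query vector is a by-product of the step that generated the retrieval token.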

🔧Installation

git clone https://github.com/zjunlp/OneGen
cd OneGen
conda create -n onegen python=3.9 -y
conda activate onegen
pip install -r requirements.txt

🏃Quick Start

The Inference section covers running model predictions to produce output results; evaluating those results is covered in the Evaluation section. Single-hop QA is an exception: its inference script is described together with its evaluation.

Download the data

Download train_data.tar.gz and eval_data.tar.gz from Google Drive. After extracting, you will get two folders: train_data and eval_data. Move these two folders into the data directory. Use the following commands to extract the files:

tar -xzvf train_data.tar.gz
tar -xzvf eval_data.tar.gz

Please note that the training data we are using is available on Hugging Face, so you do not need to download train_data.tar.gz. Just run the training scripts!

Download the trained model (Optional)

The model weights trained on three tasks have been made public and are available for download on three platforms: 🤗Huggingface, 🔭ModelScope, and 🧊WiseModel. For detailed information, please refer to the table below:

| 🎯 Task Name | 🤗 HuggingFace | 🔭 ModelScope | 🧊 WiseModel |
| --- | --- | --- | --- |
| Entity Linking | Llama2-7B | Llama2-7B | Llama2-7B |
| Single-hop QA | Llama2-7B | Llama2-7B | Llama2-7B |
| Multi-hop QA | Llama2-7B | Llama2-7B | Llama2-7B |

Note

For the Entity Linking task, we provide precomputed entity embeddings. Click here to download them.

Training model from scratch (Optional)

We provide the training scripts for three tasks. If you are using a locally downloaded model, you can modify the info-model field in the workflow/{task}/{model}.json file. Update the model_path and tokenizer_path with the local paths. Note that the hyperparameters in the configuration files are set for 8xA800 GPUs. If you encounter OOM (Out of Memory) issues, please reduce the per_device_train_batch_size, n_pos_per_sent, n_neg_per_pos, and max_length.

# Entity Linking
deepspeed train.py --workflow workflow/entity_linking/llama2.json
# Single-Hop QA
deepspeed train.py --workflow workflow/self_rag/llama2.json
# Multi-hop QA
deepspeed train.py --workflow workflow/multi_hop_qa/llama2.json
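
If you prefer to script the switch to a locally downloaded model rather than edit the JSON by hand, a small helper along these lines may work. It is a sketch that assumes `info-model` is a top-level object in the workflow file containing the `model_path` and `tokenizer_path` keys; check your workflow file before relying on it. The function name and the example paths are hypothetical.

```python
import json
from pathlib import Path

def set_local_model(workflow_path, model_path, tokenizer_path, section="info-model"):
    """Point a workflow JSON file at a local model and tokenizer.

    Assumes `section` is a top-level object in the file (an assumption about
    the layout). Raises KeyError if the section is missing, so a wrong guess
    fails loudly instead of silently adding a new field.
    """
    path = Path(workflow_path)
    cfg = json.loads(path.read_text(encoding="utf-8"))
    cfg[section]["model_path"] = model_path
    cfg[section]["tokenizer_path"] = tokenizer_path
    path.write_text(json.dumps(cfg, indent=4) + "\n", encoding="utf-8")
    return cfg

# Example usage (paths are placeholders):
# set_local_model("workflow/entity_linking/llama2.json",
#                 "/your/path/to/model", "/your/path/to/tokenizer")
```

All other fields in the file are written back unchanged, apart from JSON re-indentation.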

Inference

Here are the inference scripts for the Entity Linking and Multi-hop QA tasks. The inference script for Single-Hop QA is introduced in the next section. You can modify fields such as model_path, tokenizer_path, file, and output_file_path in config/eval_config/{task}/{config}.json as needed.

During the model inference process, we now support using Faiss as the vector retrieval engine. You just need to set the use_faiss field in the inference section of the config.json file to true.

# Entity Linking (Need GPU)
python eval.py --config config/eval_config/entity_linking/llama2_wo_pkl.json
# Multi-hop QA (Need GPU)
python eval.py --config config/eval_config/multi_hop_qa/llama2.json

Evaluation

Below are the evaluation scripts for the Entity Linking and Multi-hop QA tasks. /your/path/to/result.jsonl is the file saved during the inference stage.

# Entity Linking (CPU)
bash scripts/eval_el.sh el /your/path/to/result.jsonl

# Multi-hop QA for HotpotQA dataset (CPU)
bash scripts/eval_multi_hop_qa.sh /your/path/to/result.jsonl hotpotqa

# Multi-hop QA for 2WIKI dataset (CPU)
bash scripts/eval_multi_hop_qa.sh /your/path/to/result.jsonl 2wiki

Here is the evaluation for the Single-Hop QA task, mainly based on Self-RAG:

# Single-hop QA using Self-RAG (Need GPU)
# [CUDA_VISIBLE_DEVICES] [MODE] [MODEL_PATH] [SAVE_TAG] [SAVED_DATASET_PATH] [N_DOC] [ENV] [SCORE]
bash scripts/eval_self_rag.sh 0 always_retrieve /your/path/to/model model_tag saved_rank_path 5 true true
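
Because the Self-RAG script takes eight positional arguments, it is easy to misplace one. The wrapper below is a convenience sketch that names each argument in the order listed in the comment above and builds the same command line. The function name is hypothetical, and mapping Python `False` to the string `false` is an assumption, since the example only shows `true`.

```python
import shlex

def build_self_rag_command(cuda_visible_devices, mode, model_path, save_tag,
                           saved_dataset_path, n_doc, env, score):
    """Return the argv list for scripts/eval_self_rag.sh in positional order."""
    def as_arg(value):
        # The example command passes flags as lowercase "true";
        # "false" for Python False is an assumption.
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    positional = [cuda_visible_devices, mode, model_path, save_tag,
                  saved_dataset_path, n_doc, env, score]
    return ["bash", "scripts/eval_self_rag.sh"] + [as_arg(v) for v in positional]

cmd = build_self_rag_command(
    cuda_visible_devices=0,
    mode="always_retrieve",
    model_path="/your/path/to/model",
    save_tag="model_tag",
    saved_dataset_path="saved_rank_path",
    n_doc=5,
    env=True,
    score=True,
)
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root.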

🚩Citation

If this work is helpful, please kindly cite as:

@inproceedings{EMNLP24_OneGen,
    title = "{O}ne{G}en: Efficient One-Pass Unified Generation and Retrieval for {LLM}s",
    author = "Zhang, Jintian  and
      Peng, Cheng  and
      Sun, Mengshu  and
      Chen, Xiang  and
      Liang, Lei  and
      Zhang, Zhiqiang  and
      Zhou, Jun  and
      Chen, Huajun  and
      Zhang, Ningyu",
    editor = "Al-Onaizan, Yaser  and
      Bansal, Mohit  and
      Chen, Yun-Nung",
    booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2024",
    month = nov,
    year = "2024",
    address = "Miami, Florida, USA",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2024.findings-emnlp.237",
    pages = "4088--4119",
}