Add FlashRank Reranker module #818

Merged · 12 commits · Oct 11, 2024
1 change: 1 addition & 0 deletions autorag/nodes/passagereranker/__init__.py
@@ -14,3 +14,4 @@
from .openvino import OpenVINOReranker
from .voyageai import VoyageAIReranker
from .mixedbreadai import MixedbreadAIReranker
from .flashrank import FlashRankReranker
224 changes: 224 additions & 0 deletions autorag/nodes/passagereranker/flashrank.py
@@ -0,0 +1,224 @@
import json
from pathlib import Path

import pandas as pd
import torch
from tokenizers import AddedToken, Tokenizer
import onnxruntime as ort
import numpy as np
import os
import zipfile
import requests
from tqdm import tqdm
import collections
from typing import List, Dict, Tuple

from autorag.nodes.passagereranker.base import BasePassageReranker
from autorag.utils import result_to_dataframe
from autorag.utils.util import flatten_apply, sort_by_scores, select_top_k, make_batch

# Each supported model is published as a zip archive on the Hugging Face Hub;
# it unpacks to the ONNX weights plus the tokenizer files loaded below.
model_url = "https://huggingface.co/prithivida/flashrank/resolve/main/{}.zip"

model_file_map = {
    "ms-marco-TinyBERT-L-2-v2": "flashrank-TinyBERT-L-2-v2.onnx",
    "ms-marco-MiniLM-L-12-v2": "flashrank-MiniLM-L-12-v2_Q.onnx",
    "ms-marco-MultiBERT-L-12": "flashrank-MultiBERT-L12_Q.onnx",
    "rank-T5-flan": "flashrank-rankt5_Q.onnx",
    "ce-esci-MiniLM-L12-v2": "flashrank-ce-esci-MiniLM-L12-v2_Q.onnx",
    "miniReranker_arabic_v1": "miniReranker_arabic_v1.onnx",
}


class FlashRankReranker(BasePassageReranker):
    def __init__(
        self, project_dir: str, model: str = "ms-marco-TinyBERT-L-2-v2", *args, **kwargs
    ):
        """
        Initialize a FlashRank rerank node.

        :param project_dir: The project directory path.
        :param model: The model name for FlashRank rerank.
            You can get the list of available models from https://github.com/PrithivirajDamodaran/FlashRank.
            Default is "ms-marco-TinyBERT-L-2-v2".
            "rank_zephyr_7b_v1_full" is not supported due to a parallel inference issue.
        :param kwargs: Extra arguments. "cache_dir" (default "/tmp") and
            "max_length" (default 512) are consumed here; the rest are ignored.
        """
        super().__init__(project_dir)

        cache_dir = kwargs.pop("cache_dir", "/tmp")
        max_length = kwargs.pop("max_length", 512)

        self.cache_dir: Path = Path(cache_dir)
        self.model_dir: Path = self.cache_dir / model
        self._prepare_model_dir(model)
        model_file = model_file_map[model]

        self.session = ort.InferenceSession(str(self.model_dir / model_file))
        self.tokenizer: Tokenizer = self._get_tokenizer(max_length)

    def __del__(self):
        del self.session
        del self.tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        super().__del__()

    def _prepare_model_dir(self, model_name: str):
        if not self.cache_dir.exists():
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        if not self.model_dir.exists():
            self._download_model_files(model_name)

    def _download_model_files(self, model_name: str):
        local_zip_file = self.cache_dir / f"{model_name}.zip"
        formatted_model_url = model_url.format(model_name)

        with requests.get(formatted_model_url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            with (
                open(local_zip_file, "wb") as f,
                tqdm(
                    desc=local_zip_file.name,
                    total=total_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar,
            ):
                for chunk in r.iter_content(chunk_size=8192):
                    size = f.write(chunk)
                    bar.update(size)

        with zipfile.ZipFile(local_zip_file, "r") as zip_ref:
            zip_ref.extractall(self.cache_dir)
        os.remove(local_zip_file)

    def _get_tokenizer(self, max_length: int = 512) -> Tokenizer:
        with open(self.model_dir / "config.json") as f:
            config = json.load(f)
        with open(self.model_dir / "tokenizer_config.json") as f:
            tokenizer_config = json.load(f)
        with open(self.model_dir / "special_tokens_map.json") as f:
            tokens_map = json.load(f)
        tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))

        tokenizer.enable_truncation(
            max_length=min(tokenizer_config["model_max_length"], max_length)
        )
        tokenizer.enable_padding(
            pad_id=config["pad_token_id"], pad_token=tokenizer_config["pad_token"]
        )

        for token in tokens_map.values():
            if isinstance(token, str):
                tokenizer.add_special_tokens([token])
            elif isinstance(token, dict):
                tokenizer.add_special_tokens([AddedToken(**token)])

        # BERT-style checkpoints ship a vocab.txt; attach explicit token<->id
        # maps so the vocabulary can be inspected downstream.
        vocab_file = self.model_dir / "vocab.txt"
        if vocab_file.exists():
            tokenizer.vocab = self._load_vocab(vocab_file)
            tokenizer.ids_to_tokens = collections.OrderedDict(
                [(ids, tok) for tok, ids in tokenizer.vocab.items()]
            )
        return tokenizer

    def _load_vocab(self, vocab_file: Path) -> Dict[str, int]:
        vocab = collections.OrderedDict()
        with open(vocab_file, "r", encoding="utf-8") as reader:
            tokens = reader.readlines()
        for index, token in enumerate(tokens):
            token = token.rstrip("\n")
            vocab[token] = index
        return vocab

    @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
    def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
        queries, contents, _, ids = self.cast_to_run(previous_result)
        top_k = kwargs.pop("top_k")
        batch = kwargs.pop("batch", 64)
        return self._pure(queries, contents, ids, top_k, batch)

    def _pure(
        self,
        queries: List[str],
        contents_list: List[List[str]],
        ids_list: List[List[str]],
        top_k: int,
        batch: int = 64,
    ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
        """
        Rerank a list of contents with FlashRank rerank models.

        :param queries: The list of queries to use for reranking
        :param contents_list: The list of lists of contents to rerank
        :param ids_list: The list of lists of ids retrieved from the initial ranking
        :param top_k: The number of passages to be retrieved
        :param batch: The number of queries to be processed in a batch
        :return: Tuple of lists containing the reranked contents, ids, and scores
        """
        nested_list = [
            list(map(lambda x: [query, x], content_list))
            for query, content_list in zip(queries, contents_list)
        ]

        rerank_scores = flatten_apply(
            flashrank_run_model,
            nested_list,
            session=self.session,
            batch_size=batch,
            tokenizer=self.tokenizer,
        )

        df = pd.DataFrame(
            {
                "contents": contents_list,
                "ids": ids_list,
                "scores": rerank_scores,
            }
        )
        df[["contents", "ids", "scores"]] = df.apply(
            sort_by_scores, axis=1, result_type="expand"
        )
        results = select_top_k(df, ["contents", "ids", "scores"], top_k)

        return (
            results["contents"].tolist(),
            results["ids"].tolist(),
            results["scores"].tolist(),
        )


def flashrank_run_model(input_texts, tokenizer, session, batch_size: int):
    batch_input_texts = make_batch(input_texts, batch_size)
    results = []

    for batch_texts in tqdm(batch_input_texts):
        input_text = tokenizer.encode_batch(batch_texts)
        input_ids = np.array([e.ids for e in input_text])
        token_type_ids = np.array([e.type_ids for e in input_text])
        attention_mask = np.array([e.attention_mask for e in input_text])

        # Some exported models have no token_type_ids input (their tokenizer
        # emits all-zero segment ids), so only forward segment ids when the
        # tokenizer actually produced non-zero ones.
        use_token_type_ids = token_type_ids is not None and not np.all(
            token_type_ids == 0
        )

        onnx_input = {
            "input_ids": input_ids.astype(np.int64),
            "attention_mask": attention_mask.astype(np.int64),
        }
        if use_token_type_ids:
            onnx_input["token_type_ids"] = token_type_ids.astype(np.int64)

        outputs = session.run(None, onnx_input)

        logits = outputs[0]

        # A single-logit head yields one relevance score per pair: squash it
        # with a sigmoid. A two-logit head is a binary classifier: softmax the
        # logits and keep the probability of the "relevant" class.
        if logits.shape[1] == 1:
            scores = 1 / (1 + np.exp(-logits.flatten()))
        else:
            exp_logits = np.exp(logits)
            scores = exp_logits[:, 1] / np.sum(exp_logits, axis=1)
        results.extend(scores)
    return results
2 changes: 2 additions & 0 deletions autorag/support.py
@@ -129,6 +129,8 @@ def get_support_modules(module_name: str) -> Callable:
"autorag.nodes.passagereranker",
"MixedbreadAIReranker",
),
"flashrank_reranker": ("autorag.nodes.passagereranker", "FlashRankReranker"),
"FlashRankReranker": ("autorag.nodes.passagereranker", "FlashRankReranker"),
# passage_filter
"pass_passage_filter": ("autorag.nodes.passagefilter", "PassPassageFilter"),
"similarity_threshold_cutoff": (
8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.nodes.passagereranker.rst
@@ -52,6 +52,14 @@ autorag.nodes.passagereranker.flag\_embedding\_llm module
   :undoc-members:
   :show-inheritance:

autorag.nodes.passagereranker.flashrank module
----------------------------------------------

.. automodule:: autorag.nodes.passagereranker.flashrank
   :members:
   :undoc-members:
   :show-inheritance:

autorag.nodes.passagereranker.jina module
-----------------------------------------

29 changes: 29 additions & 0 deletions docs/source/nodes/passage_reranker/flashrank_reranker.md
@@ -0,0 +1,29 @@
---
myst:
  html_meta:
    title: AutoRAG - FlashRank Reranker
    description: Learn about the flashrank reranker module in AutoRAG
    keywords: AutoRAG,RAG,Advanced RAG,Reranker,FlashRank Reranker
---
# FlashRank Reranker
[FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) is an ultra-lite, super-fast Python library for adding re-ranking to your existing search & retrieval pipelines.

It is based on SoTA cross-encoders, with gratitude to all the model owners.
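To make the behavior concrete, here is a minimal sketch of using the upstream library on its own, assuming the `Ranker`/`RerankRequest` interface described in the FlashRank README (the query and passages are made-up examples):

```python
from flashrank import Ranker, RerankRequest

# First use downloads and caches the ONNX model locally.
ranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/tmp")

request = RerankRequest(
    query="How can I speed up transformer inference?",
    passages=[
        {"id": 1, "text": "Quantization reduces model size and latency."},
        {"id": 2, "text": "Bananas are an excellent source of potassium."},
    ],
)
results = ranker.rerank(request)  # passages with relevance scores, best first
```

Within AutoRAG, the same ONNX models are wrapped by the `FlashRankReranker` node added in this PR, so in practice you configure the module through YAML rather than calling the library directly.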

## **Module Parameters**

- **batch** : The size of a batch. If you have limited CUDA memory, decrease the size of the batch. (default: 64)
- **model** : The model id you want to use for reranking. Default is "ms-marco-TinyBERT-L-2-v2".
  - You can get the list of available models from [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank).
```{admonition} Note
"rank_zephyr_7b_v1_full" is an LLM-based reranker that uses llama-cpp.
Due to issues with parallel inference, "rank_zephyr_7b_v1_full" is not currently supported by AutoRAG.
```

## **Example config.yaml**

```yaml
- module_type: flashrank_reranker
  batch: 32
  model: "ms-marco-MiniLM-L-12-v2"
```
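
For a quick experiment outside a full pipeline, the node can also be invoked directly, mirroring the test code in this PR; a sketch, assuming an existing AutoRAG project (the path and `previous_result` are placeholders):

```python
from autorag.nodes.passagereranker import FlashRankReranker

# previous_result is the retrieval-stage DataFrame produced by an earlier node;
# top_k controls how many passages survive reranking.
result_df = FlashRankReranker.run_evaluator(
    project_dir="./my_project",  # hypothetical project directory
    previous_result=previous_result,
    top_k=1,
)
```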
2 changes: 1 addition & 1 deletion docs/source/nodes/passage_reranker/openvino_reranker.md
@@ -6,7 +6,7 @@ myst:
    keywords: AutoRAG,RAG,Advanced RAG,Reranker,OpenVINO Reranker
---
# OpenVINO Reranker
[OpenVINO™]() is an open-source toolkit for optimizing and deploying AI inference. The OpenVINO™ Runtime supports various hardware devices including x86 and ARM CPUs, and Intel GPUs. It can help to boost deep learning performance in Computer Vision, Automatic Speech Recognition, Natural Language Processing and other common tasks.
[OpenVINO™](https://github.com/openvinotoolkit/openvino) is an open-source toolkit for optimizing and deploying AI inference. The OpenVINO™ Runtime supports various hardware devices including x86 and ARM CPUs, and Intel GPUs. It can help to boost deep learning performance in Computer Vision, Automatic Speech Recognition, Natural Language Processing and other common tasks.

Hugging Face rerank model can be supported by OpenVINO through `OpenVINOReranker` class.

1 change: 1 addition & 0 deletions docs/source/nodes/passage_reranker/passage_reranker.md
@@ -73,4 +73,5 @@ time_reranker.md
openvino_reranker.md
voyageai_reranker.md
mixedbreadai_reranker.md
flashrank_reranker.md
```
1 change: 1 addition & 0 deletions sample_config/rag/full.yaml
@@ -76,6 +76,7 @@ node_lines:
- module_type: openvino_reranker
- module_type: voyageai_reranker
- module_type: mixedbreadai_reranker
- module_type: flashrank_reranker
- node_type: passage_filter
  strategy:
    metrics: [ retrieval_f1, retrieval_recall, retrieval_precision ]
50 changes: 50 additions & 0 deletions tests/autorag/nodes/passagereranker/test_flashrank_reranker.py
@@ -0,0 +1,50 @@
import pytest

from autorag.nodes.passagereranker import FlashRankReranker
from tests.autorag.nodes.passagereranker.test_passage_reranker_base import (
    base_reranker_test,
    base_reranker_node_test,
    queries_example,
    contents_example,
    ids_example,
    project_dir,
    previous_result,
)
from tests.delete_tests import is_github_action


@pytest.fixture
def flashrank_instance():
    return FlashRankReranker(project_dir, "ms-marco-TinyBERT-L-2-v2")


@pytest.mark.skipif(is_github_action(), reason="Skipping this test on GitHub Actions")
def test_flashrank_reranker(flashrank_instance):
    top_k = 1
    contents_result, id_result, score_result = flashrank_instance._pure(
        queries_example, contents_example, ids_example, top_k
    )
    base_reranker_test(contents_result, id_result, score_result, top_k)


@pytest.mark.skipif(is_github_action(), reason="Skipping this test on GitHub Actions")
def test_flashrank_reranker_batch_one(flashrank_instance):
    top_k = 1
    batch = 1
    contents_result, id_result, score_result = flashrank_instance._pure(
        queries_example,
        contents_example,
        ids_example,
        top_k,
        batch=batch,
    )
    base_reranker_test(contents_result, id_result, score_result, top_k)


@pytest.mark.skipif(is_github_action(), reason="Skipping this test on GitHub Actions")
def test_flashrank_reranker_node():
    top_k = 1
    result_df = FlashRankReranker.run_evaluator(
        project_dir=project_dir, previous_result=previous_result, top_k=top_k
    )
    base_reranker_node_test(result_df, top_k)