Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VoyageAI Reranker module #809

Merged
merged 15 commits into from
Oct 8, 2024
1 change: 1 addition & 0 deletions autorag/nodes/passagereranker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .time_reranker import TimeReranker
from .upr import Upr
from .openvino import OpenVINOReranker
from .voyageai import VoyageAIReranker
109 changes: 109 additions & 0 deletions autorag/nodes/passagereranker/voyageai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import os
from typing import List, Tuple
import pandas as pd
import voyageai

from autorag.nodes.passagereranker.base import BasePassageReranker
from autorag.utils.util import result_to_dataframe, get_event_loop, process_batch


class VoyageAIReranker(BasePassageReranker):
def __init__(self, project_dir: str, *args, **kwargs):
super().__init__(project_dir)
api_key = kwargs.pop("api_key", None)
api_key = os.getenv("VOYAGE_API_KEY", None) if api_key is None else api_key
if api_key is None:
raise KeyError(
"Please set the API key for VoyageAI rerank in the environment variable VOYAGE_API_KEY "
"or directly set it on the config YAML file."
)

self.voyage_client = voyageai.AsyncClient(api_key=api_key)

def __del__(self):
del self.voyage_client
super().__del__()

@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
queries, contents, scores, ids = self.cast_to_run(previous_result)
top_k = kwargs.pop("top_k")
batch = kwargs.pop("batch", 8)
model = kwargs.pop("model", "rerank-2")
truncation = kwargs.pop("truncation", True)
return self._pure(queries, contents, ids, top_k, model, batch, truncation)

def _pure(
self,
queries: List[str],
contents_list: List[List[str]],
ids_list: List[List[str]],
top_k: int,
model: str = "rerank-2",
batch: int = 8,
truncation: bool = True,
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
"""
Rerank a list of contents with VoyageAI rerank models.
You can get the API key from https://docs.voyageai.com/docs/api-key-and-installation and set it in the environment variable VOYAGE_API_KEY.

:param queries: The list of queries to use for reranking
:param contents_list: The list of lists of contents to rerank
:param ids_list: The list of lists of ids retrieved from the initial ranking
:param top_k: The number of passages to be retrieved
:param model: The model name for VoyageAI rerank.
You can choose between "rerank-2" and "rerank-2-lite".
Default is "rerank-2".
:param batch: The number of queries to be processed in a batch
:param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.
:return: Tuple of lists containing the reranked contents, ids, and scores
"""
tasks = [
voyageai_rerank_pure(
self.voyage_client, model, query, contents, ids, top_k, truncation
)
for query, contents, ids in zip(queries, contents_list, ids_list)
]
loop = get_event_loop()
results = loop.run_until_complete(process_batch(tasks, batch))

content_result, id_result, score_result = zip(*results)

return list(content_result), list(id_result), list(score_result)


async def voyageai_rerank_pure(
voyage_client: voyageai.AsyncClient,
model: str,
query: str,
documents: List[str],
ids: List[str],
top_k: int,
truncation: bool = True,
) -> Tuple[List[str], List[str], List[float]]:
"""
Rerank a list of contents with VoyageAI rerank models.

:param voyage_client: The Voyage Client to use for reranking
:param model: The model name for VoyageAI rerank
:param query: The query to use for reranking
:param documents: The list of contents to rerank
:param ids: The list of ids corresponding to the documents
:param top_k: The number of passages to be retrieved
:param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.
:return: Tuple of lists containing the reranked contents, ids, and scores
"""
rerank_results = await voyage_client.rerank(
model=model,
query=query,
documents=documents,
top_k=top_k,
truncation=truncation,
)
reranked_scores: List[float] = list(
map(lambda x: x.relevance_score, rerank_results.results)
)
indices = list(map(lambda x: x.index, rerank_results.results))
reranked_contents: List[str] = list(map(lambda i: documents[i], indices))
reranked_ids: List[str] = list(map(lambda i: ids[i], indices))
return reranked_contents, reranked_ids, reranked_scores
2 changes: 2 additions & 0 deletions autorag/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def get_support_modules(module_name: str) -> Callable:
"TimeReranker": ("autorag.nodes.passagereranker", "TimeReranker"),
"openvino_reranker": ("autorag.nodes.passagereranker", "OpenVINOReranker"),
"OpenVINOReranker": ("autorag.nodes.passagereranker", "OpenVINOReranker"),
"voyageai_reranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"),
"VoyageAIReranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"),
# passage_filter
"pass_passage_filter": ("autorag.nodes.passagefilter", "PassPassageFilter"),
"similarity_threshold_cutoff": (
Expand Down
8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.nodes.passagereranker.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ autorag.nodes.passagereranker.upr module
:undoc-members:
:show-inheritance:

autorag.nodes.passagereranker.voyageai module
---------------------------------------------

.. automodule:: autorag.nodes.passagereranker.voyageai
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

Expand Down
1 change: 1 addition & 0 deletions docs/source/nodes/passage_reranker/passage_reranker.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ flag_embedding_reranker.md
flag_embedding_llm_reranker.md
time_reranker.md
openvino_reranker.md
voyageai_reranker.md
```
52 changes: 52 additions & 0 deletions docs/source/nodes/passage_reranker/voyageai_reranker.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
---
myst:
html_meta:
title: AutoRAG - VoyageAI Reranker
description: Learn about voyage ai reranker module in AutoRAG
keywords: AutoRAG,RAG,Advanced RAG,Reranker,VoyageAI
---
# voyageai_reranker

The `voyageai reranker` module is a reranker from [VoyageAI](https://www.voyageai.com/).
It supports powerful and fast reranker for passage retrieval.

## Before Usage

At first, you need to get the VoyageAI API key from [here](https://docs.voyageai.com/docs/api-key-and-installation).

Next, you can set your VoyageAI API key in the environment variable "VOYAGE_API_KEY".

```bash
export VOYAGE_API_KEY=your_voyageai_api_key
```

Or, you can set your VoyageAI API key in the config.yaml file directly.

```yaml
- module_type: voyageai_reranker
api_key: your_voyageai_api_key
```

## **Module Parameters**

- **model** : The type of model you want to use for reranking. Default is "rerank-2" and you can change
it to "rerank-2-lite"
- **api_key** : The voyageai api key.
- **truncation** : Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. Default is True.

## **Example config.yaml**

```yaml
- module_type: voyageai_reranker
api_key: your_voyageai_api_key
model: rerank-2
```

### Supported Model Names

You can see the supported model names [here](https://docs.voyageai.com/docs/reranker).

| Model Name |
|:-------------:|
| rerank-2 |
| rerank-2-lite |
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ FlagEmbedding # for flag embedding reranker
llmlingua # for longllmlingua
peft
optimum[openvino,nncf] # for openvino reranker
voyageai # for voyageai reranker

### API server ###
quart
Expand Down
1 change: 1 addition & 0 deletions sample_config/rag/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ node_lines:
- module_type: flag_embedding_llm_reranker
- module_type: time_reranker
- module_type: openvino_reranker
- module_type: voyageai_reranker
- node_type: passage_filter
strategy:
metrics: [ retrieval_f1, retrieval_recall, retrieval_precision ]
Expand Down
95 changes: 95 additions & 0 deletions tests/autorag/nodes/passagereranker/test_voyageai_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from unittest.mock import patch

import pytest

from collections import namedtuple
import voyageai
from voyageai.object.reranking import RerankingObject, RerankingResult
from voyageai.api_resources import VoyageResponse

import autorag
from autorag.nodes.passagereranker import VoyageAIReranker
from tests.autorag.nodes.passagereranker.test_passage_reranker_base import (
queries_example,
contents_example,
ids_example,
base_reranker_test,
project_dir,
previous_result,
base_reranker_node_test,
)


async def mock_voyageai_reranker_pure(
self,
query,
documents,
model,
top_k,
truncation,
):
mock_documents = ["Document 1 content", "Document 2 content", "Document 3 content"]

# Mock response data
mock_response_data = [
{"index": 1, "relevance_score": 0.8},
{"index": 2, "relevance_score": 0.2},
{"index": 0, "relevance_score": 0.1},
]

# Mock usage data
mock_usage = {"total_tokens": 100}

# Create a mock VoyageResponse object
mock_response = VoyageResponse()
mock_response.data = [
namedtuple("MockData", d.keys())(*d.values()) for d in mock_response_data
]
mock_response.usage = namedtuple("MockUsage", mock_usage.keys())(
*mock_usage.values()
)

# Create an instance of RerankingObject using the mock data
object = RerankingObject(documents=mock_documents, response=mock_response)

if top_k == 1:
object.results = [
RerankingResult(index=1, document="nodonggunn", relevance_score=0.8)
]
return object


@pytest.fixture
def voyageai_reranker_instance():
return VoyageAIReranker(project_dir, api_key="mock_api_key")


@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure)
def test_voyageai_reranker(voyageai_reranker_instance):
top_k = 3
contents_result, id_result, score_result = voyageai_reranker_instance._pure(
queries_example, contents_example, ids_example, top_k
)
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure)
def test_voyageai_reranker_batch_one(voyageai_reranker_instance):
top_k = 1
batch = 1
contents_result, id_result, score_result = voyageai_reranker_instance._pure(
queries_example, contents_example, ids_example, top_k, batch=batch
)
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure)
def test_voyageai_reranker_node():
top_k = 1
result_df = VoyageAIReranker.run_evaluator(
project_dir=project_dir,
previous_result=previous_result,
top_k=top_k,
api_key="mock_api_key",
)
base_reranker_node_test(result_df, top_k)