Add VoyageAI Reranker module #809

Merged
merged 15 commits on Oct 8, 2024
1 change: 1 addition & 0 deletions autorag/nodes/passagereranker/__init__.py
@@ -11,3 +11,4 @@
from .tart.tart import Tart
from .time_reranker import TimeReranker
from .upr import Upr
from .voyageai import VoyageAIReranker
106 changes: 106 additions & 0 deletions autorag/nodes/passagereranker/voyageai.py
@@ -0,0 +1,106 @@
import os
from typing import List, Tuple
import pandas as pd
import voyageai

from autorag.nodes.passagereranker.base import BasePassageReranker
from autorag.utils.util import result_to_dataframe


class VoyageAIReranker(BasePassageReranker):
def __init__(self, project_dir: str, *args, **kwargs):
super().__init__(project_dir)
api_key = kwargs.pop("api_key", None)
api_key = os.getenv("VOYAGE_API_KEY", None) if api_key is None else api_key
if api_key is None:
raise KeyError(
"Please set the API key for VoyageAI rerank in the environment variable VOYAGE_API_KEY "
"or directly set it on the config YAML file."
)

self.voyage_client = voyageai.Client(api_key=api_key)

def __del__(self):
del self.voyage_client
super().__del__()

@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
queries, contents, scores, ids = self.cast_to_run(previous_result)
top_k = kwargs.pop("top_k")
model = kwargs.pop("model", "rerank-2")
truncation = kwargs.pop("truncation", True)
return self._pure(queries, contents, ids, top_k, model, truncation)

def _pure(
self,
queries: List[str],
contents_list: List[List[str]],
ids_list: List[List[str]],
top_k: int,
model: str = "rerank-2",
truncation: bool = True,
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
"""
Rerank a list of contents with VoyageAI rerank models.
You can get the API key from https://docs.voyageai.com/docs/api-key-and-installation and set it in the environment variable VOYAGE_API_KEY.

:param queries: The list of queries to use for reranking
:param contents_list: The list of lists of contents to rerank
:param ids_list: The list of lists of ids retrieved from the initial ranking
:param top_k: The number of passages to be retrieved
:param model: The model name for VoyageAI rerank.
You can choose between "rerank-2" and "rerank-2-lite".
Default is "rerank-2".
:param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.
:return: Tuple of lists containing the reranked contents, ids, and scores
"""
content_result, id_result, score_result = zip(
*[
voyageai_rerank_pure(
self.voyage_client, model, query, document, ids, top_k, truncation
)
for query, document, ids in zip(queries, contents_list, ids_list)
]
)

return content_result, id_result, score_result


def voyageai_rerank_pure(
voyage_client: voyageai.Client,
model: str,
query: str,
documents: List[str],
ids: List[str],
top_k: int,
truncation: bool = True,
) -> Tuple[List[str], List[str], List[float]]:
"""
Rerank a list of contents with VoyageAI rerank models.

:param voyage_client: The Voyage Client to use for reranking
:param model: The model name for VoyageAI rerank
:param query: The query to use for reranking
:param documents: The list of contents to rerank
:param ids: The list of ids corresponding to the documents
:param top_k: The number of passages to be retrieved
:param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents.
:return: Tuple of lists containing the reranked contents, ids, and scores
"""
rerank_results = voyage_client.rerank(
model=model,
query=query,
documents=documents,
top_k=top_k,
truncation=truncation,
)
reranked_scores: List[float] = list(
map(lambda x: x.relevance_score, rerank_results.results)
)
reranked_contents: List[str] = list(
map(lambda x: x.document, rerank_results.results)
)
indices = list(map(lambda x: x.index, rerank_results.results))
reranked_ids: List[str] = list(map(lambda i: ids[i], indices))
return reranked_contents, reranked_ids, reranked_scores
2 changes: 2 additions & 0 deletions autorag/support.py
@@ -117,6 +117,8 @@ def get_support_modules(module_name: str) -> Callable:
),
"time_reranker": ("autorag.nodes.passagereranker", "TimeReranker"),
"TimeReranker": ("autorag.nodes.passagereranker", "TimeReranker"),
"voyageai_reranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"),
"VoyageAIReranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"),
# passage_filter
"pass_passage_filter": ("autorag.nodes.passagefilter", "PassPassageFilter"),
"similarity_threshold_cutoff": (
8 changes: 8 additions & 0 deletions docs/source/api_spec/autorag.nodes.passagereranker.rst
@@ -124,6 +124,14 @@ autorag.nodes.passagereranker.upr module
:undoc-members:
:show-inheritance:

autorag.nodes.passagereranker.voyageai module
---------------------------------------------

.. automodule:: autorag.nodes.passagereranker.voyageai
:members:
:undoc-members:
:show-inheritance:

Module contents
---------------

1 change: 1 addition & 0 deletions docs/source/nodes/passage_reranker/passage_reranker.md
@@ -70,4 +70,5 @@ sentence_transformer_reranker.md
flag_embedding_reranker.md
flag_embedding_llm_reranker.md
time_reranker.md
voyageai_reranker.md
```
52 changes: 52 additions & 0 deletions docs/source/nodes/passage_reranker/voyageai_reranker.md
@@ -0,0 +1,52 @@
---
myst:
html_meta:
title: AutoRAG - VoyageAI Reranker
description: Learn about voyage ai reranker module in AutoRAG
keywords: AutoRAG,RAG,Advanced RAG,Reranker,VoyageAI
---
# voyageai_reranker

The `voyageai_reranker` module is a reranker from [VoyageAI](https://www.voyageai.com/).
It provides a powerful and fast reranker for passage retrieval.

## Before Usage

First, obtain a VoyageAI API key from [here](https://docs.voyageai.com/docs/api-key-and-installation).

Then, set your VoyageAI API key in the environment variable `VOYAGE_API_KEY`.

```bash
export VOYAGE_API_KEY=your_voyageai_api_key
```

Or, you can set your VoyageAI API key in the config.yaml file directly.

```yaml
- module_type: voyageai_reranker
api_key: your_voyageai_api_key
```
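
As a quick sanity check before running a pipeline, here is a minimal sketch that only verifies a key is available, the same check the module performs at initialization (it assumes the `voyageai` package is installed):

```python
import os

import voyageai

# The module reads VOYAGE_API_KEY when api_key is not set in the config YAML.
api_key = os.getenv("VOYAGE_API_KEY")
if api_key is None:
    raise KeyError("Set VOYAGE_API_KEY or pass api_key in the config YAML.")

# The same client type the module constructs internally.
client = voyageai.Client(api_key=api_key)
```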

## **Module Parameters**

- **model** : The model to use for reranking. Default is "rerank-2", and you can change it to "rerank-2-lite".
- **api_key** : The VoyageAI API key. If it is not set here, the `VOYAGE_API_KEY` environment variable is used.
- **truncation** : Whether to truncate the input to satisfy the context length limit on the query and the documents. Default is True.

## **Example config.yaml**

```yaml
- module_type: voyageai_reranker
api_key: your_voyageai_api_key
model: rerank-2
```
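
For reference, here is a minimal sketch of running the reranker outside a full pipeline, mirroring how the test suite exercises the internal `_pure` method; `./project` is a placeholder project directory, the queries and documents are illustrative, and the call issues real VoyageAI API requests:

```python
from autorag.nodes.passagereranker import VoyageAIReranker

# project_dir is a placeholder; the API key is read from VOYAGE_API_KEY
# because no api_key argument is passed here.
reranker = VoyageAIReranker(project_dir="./project")

queries = ["What is passage reranking?"]
contents = [["Reranking reorders retrieved passages by relevance.", "An unrelated passage."]]
ids = [["doc-1", "doc-2"]]

contents_out, ids_out, scores_out = reranker._pure(
    queries, contents, ids, top_k=1, model="rerank-2"
)
print(contents_out, ids_out, scores_out)
```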

### Supported Model Names

You can see the supported model names [here](https://docs.voyageai.com/docs/reranker).

| Model Name |
|:-------------:|
| rerank-2 |
| rerank-2-lite |
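
Internally, the module issues one `rerank` call per query through the VoyageAI Python client, roughly as sketched below; the query and documents here are illustrative only:

```python
import voyageai

client = voyageai.Client()  # reads VOYAGE_API_KEY from the environment

result = client.rerank(
    model="rerank-2",
    query="How do rerankers work?",
    documents=["A passage about rerankers.", "An unrelated passage."],
    top_k=2,
    truncation=True,
)

for item in result.results:
    # item.index points back into the original documents list,
    # which is how the module maps scores back to the retrieved ids.
    print(item.index, item.relevance_score, item.document)
```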
1 change: 1 addition & 0 deletions requirements.txt
@@ -25,6 +25,7 @@ sentence-transformers # for sentence transformer reranker
FlagEmbedding # for flag embedding reranker
llmlingua # for longllmlingua
peft
voyageai # for voyageai reranker

### LlamaIndex ###
llama-index>=0.11.0
1 change: 1 addition & 0 deletions sample_config/rag/full.yaml
@@ -73,6 +73,7 @@ node_lines:
- module_type: flag_embedding_reranker
- module_type: flag_embedding_llm_reranker
- module_type: time_reranker
- module_type: voyageai_reranker
- node_type: passage_filter
strategy:
metrics: [ retrieval_f1, retrieval_recall, retrieval_precision ]
68 changes: 68 additions & 0 deletions tests/autorag/nodes/passagereranker/test_voyageai_reranker.py
@@ -0,0 +1,68 @@
from unittest.mock import patch

import pytest

import autorag
from autorag.nodes.passagereranker import VoyageAIReranker
from tests.autorag.nodes.passagereranker.test_passage_reranker_base import (
queries_example,
contents_example,
ids_example,
base_reranker_test,
project_dir,
previous_result,
base_reranker_node_test,
)


def mock_voyageai_reranker_pure(
voyage_client, model, query, documents, ids, top_k, truncation
):
if query == queries_example[0]:
return (
[documents[1], documents[2], documents[0]][:top_k],
[ids[1], ids[2], ids[0]][:top_k],
[0.8, 0.2, 0.1][:top_k],
)
elif query == queries_example[1]:
return (
[documents[1], documents[0], documents[2]][:top_k],
[ids[1], ids[0], ids[2]][:top_k],
[0.8, 0.2, 0.1][:top_k],
)
else:
raise ValueError(f"Unexpected query: {query}")


@pytest.fixture
def voyageai_reranker_instance():
return VoyageAIReranker(project_dir, api_key="mock_api_key")


@patch.object(
autorag.nodes.passagereranker.voyageai,
"voyageai_rerank_pure",
mock_voyageai_reranker_pure,
)
def test_voyageai_reranker(voyageai_reranker_instance):
top_k = 3
contents_result, id_result, score_result = voyageai_reranker_instance._pure(
queries_example, contents_example, ids_example, top_k
)
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(
autorag.nodes.passagereranker.voyageai,
"voyageai_rerank_pure",
mock_voyageai_reranker_pure,
)
def test_voyageai_reranker_node():
top_k = 1
result_df = VoyageAIReranker.run_evaluator(
project_dir=project_dir,
previous_result=previous_result,
top_k=top_k,
api_key="mock_api_key",
)
base_reranker_node_test(result_df, top_k)