diff --git a/autorag/nodes/passagereranker/__init__.py b/autorag/nodes/passagereranker/__init__.py index c0b0cea56..6643bd1ef 100644 --- a/autorag/nodes/passagereranker/__init__.py +++ b/autorag/nodes/passagereranker/__init__.py @@ -12,3 +12,4 @@ from .time_reranker import TimeReranker from .upr import Upr from .openvino import OpenVINOReranker +from .voyageai import VoyageAIReranker diff --git a/autorag/nodes/passagereranker/voyageai.py b/autorag/nodes/passagereranker/voyageai.py new file mode 100644 index 000000000..2868189d2 --- /dev/null +++ b/autorag/nodes/passagereranker/voyageai.py @@ -0,0 +1,109 @@ +import os +from typing import List, Tuple +import pandas as pd +import voyageai + +from autorag.nodes.passagereranker.base import BasePassageReranker +from autorag.utils.util import result_to_dataframe, get_event_loop, process_batch + + +class VoyageAIReranker(BasePassageReranker): + def __init__(self, project_dir: str, *args, **kwargs): + super().__init__(project_dir) + api_key = kwargs.pop("api_key", None) + api_key = os.getenv("VOYAGE_API_KEY", None) if api_key is None else api_key + if api_key is None: + raise KeyError( + "Please set the API key for VoyageAI rerank in the environment variable VOYAGE_API_KEY " + "or directly set it on the config YAML file." + ) + + self.voyage_client = voyageai.AsyncClient(api_key=api_key) + + def __del__(self): + del self.voyage_client + super().__del__() + + @result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"]) + def pure(self, previous_result: pd.DataFrame, *args, **kwargs): + queries, contents, scores, ids = self.cast_to_run(previous_result) + top_k = kwargs.pop("top_k") + batch = kwargs.pop("batch", 8) + model = kwargs.pop("model", "rerank-2") + truncation = kwargs.pop("truncation", True) + return self._pure(queries, contents, ids, top_k, model, batch, truncation) + + def _pure( + self, + queries: List[str], + contents_list: List[List[str]], + ids_list: List[List[str]], + top_k: int, + model: str = "rerank-2", + batch: int = 8, + truncation: bool = True, + ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: + """ + Rerank a list of contents with VoyageAI rerank models. + You can get the API key from https://docs.voyageai.com/docs/api-key-and-installation and set it in the environment variable VOYAGE_API_KEY. + + :param queries: The list of queries to use for reranking + :param contents_list: The list of lists of contents to rerank + :param ids_list: The list of lists of ids retrieved from the initial ranking + :param top_k: The number of passages to be retrieved + :param model: The model name for VoyageAI rerank. + You can choose between "rerank-2" and "rerank-2-lite". + Default is "rerank-2". + :param batch: The number of queries to be processed in a batch + :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. + :return: Tuple of lists containing the reranked contents, ids, and scores + """ + tasks = [ + voyageai_rerank_pure( + self.voyage_client, model, query, contents, ids, top_k, truncation + ) + for query, contents, ids in zip(queries, contents_list, ids_list) + ] + loop = get_event_loop() + results = loop.run_until_complete(process_batch(tasks, batch)) + + content_result, id_result, score_result = zip(*results) + + return list(content_result), list(id_result), list(score_result) + + +async def voyageai_rerank_pure( + voyage_client: voyageai.AsyncClient, + model: str, + query: str, + documents: List[str], + ids: List[str], + top_k: int, + truncation: bool = True, +) -> Tuple[List[str], List[str], List[float]]: + """ + Rerank a list of contents with VoyageAI rerank models. + + :param voyage_client: The Voyage Client to use for reranking + :param model: The model name for VoyageAI rerank + :param query: The query to use for reranking + :param documents: The list of contents to rerank + :param ids: The list of ids corresponding to the documents + :param top_k: The number of passages to be retrieved + :param truncation: Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. + :return: Tuple of lists containing the reranked contents, ids, and scores + """ + rerank_results = await voyage_client.rerank( + model=model, + query=query, + documents=documents, + top_k=top_k, + truncation=truncation, + ) + reranked_scores: List[float] = list( + map(lambda x: x.relevance_score, rerank_results.results) + ) + indices = list(map(lambda x: x.index, rerank_results.results)) + reranked_contents: List[str] = list(map(lambda i: documents[i], indices)) + reranked_ids: List[str] = list(map(lambda i: ids[i], indices)) + return reranked_contents, reranked_ids, reranked_scores diff --git a/autorag/support.py b/autorag/support.py index db4d94c06..0a4cb3a0e 100644 --- a/autorag/support.py +++ b/autorag/support.py @@ -119,6 +119,8 @@ def get_support_modules(module_name: str) -> Callable: "TimeReranker": ("autorag.nodes.passagereranker", "TimeReranker"), "openvino_reranker": ("autorag.nodes.passagereranker", "OpenVINOReranker"), "OpenVINOReranker": ("autorag.nodes.passagereranker", "OpenVINOReranker"), + "voyageai_reranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"), + "VoyageAIReranker": ("autorag.nodes.passagereranker", "VoyageAIReranker"), # passage_filter "pass_passage_filter": ("autorag.nodes.passagefilter", "PassPassageFilter"), "similarity_threshold_cutoff": ( diff --git a/docs/source/api_spec/autorag.nodes.passagereranker.rst b/docs/source/api_spec/autorag.nodes.passagereranker.rst index 9ed368bfc..82a85721f 100644 --- a/docs/source/api_spec/autorag.nodes.passagereranker.rst +++ b/docs/source/api_spec/autorag.nodes.passagereranker.rst @@ -132,6 +132,14 @@ autorag.nodes.passagereranker.upr module :undoc-members: :show-inheritance: +autorag.nodes.passagereranker.voyageai module +--------------------------------------------- + +.. automodule:: autorag.nodes.passagereranker.voyageai + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/docs/source/nodes/passage_reranker/passage_reranker.md b/docs/source/nodes/passage_reranker/passage_reranker.md index f39efcd46..c6e6eeab9 100644 --- a/docs/source/nodes/passage_reranker/passage_reranker.md +++ b/docs/source/nodes/passage_reranker/passage_reranker.md @@ -71,4 +71,5 @@ flag_embedding_reranker.md flag_embedding_llm_reranker.md time_reranker.md openvino_reranker.md +voyageai_reranker.md ``` diff --git a/docs/source/nodes/passage_reranker/voyageai_reranker.md b/docs/source/nodes/passage_reranker/voyageai_reranker.md new file mode 100644 index 000000000..acb3a3f34 --- /dev/null +++ b/docs/source/nodes/passage_reranker/voyageai_reranker.md @@ -0,0 +1,52 @@ +--- +myst: + html_meta: + title: AutoRAG - VoyageAI Reranker + description: Learn about voyage ai reranker module in AutoRAG + keywords: AutoRAG,RAG,Advanced RAG,Reranker,VoyageAI +--- +# voyageai_reranker + +The `voyageai reranker` module is a reranker from [VoyageAI](https://www.voyageai.com/). +It supports powerful and fast reranker for passage retrieval. + +## Before Usage + +At first, you need to get the VoyageAI API key from [here](https://docs.voyageai.com/docs/api-key-and-installation). + +Next, you can set your VoyageAI API key in the environment variable "VOYAGE_API_KEY". + +```bash +export VOYAGE_API_KEY=your_voyageai_api_key +``` + +Or, you can set your VoyageAI API key in the config.yaml file directly. + +```yaml +- module_type: voyageai_reranker + api_key: your_voyageai_api_key +``` + +## **Module Parameters** + +- **model** : The type of model you want to use for reranking. Default is "rerank-2" and you can change + it to "rerank-2-lite" +- **api_key** : The voyageai api key. +- **truncation** : Whether to truncate the input to satisfy the 'context length limit' on the query and the documents. Default is True. + +## **Example config.yaml** + +```yaml +- module_type: voyageai_reranker + api_key: your_voyageai_api_key + model: rerank-2 +``` + +### Supported Model Names + +You can see the supported model names [here](https://docs.voyageai.com/docs/reranker). + +| Model Name | +|:-------------:| +| rerank-2 | +| rerank-2-lite | diff --git a/requirements.txt b/requirements.txt index a32faea29..cb3818bbd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ FlagEmbedding # for flag embedding reranker llmlingua # for longllmlingua peft optimum[openvino,nncf] # for openvino reranker +voyageai # for voyageai reranker ### API server ### quart diff --git a/sample_config/rag/full.yaml b/sample_config/rag/full.yaml index d516b5d85..65cbe5672 100644 --- a/sample_config/rag/full.yaml +++ b/sample_config/rag/full.yaml @@ -74,6 +74,7 @@ node_lines: - module_type: flag_embedding_llm_reranker - module_type: time_reranker - module_type: openvino_reranker + - module_type: voyageai_reranker - node_type: passage_filter strategy: metrics: [ retrieval_f1, retrieval_recall, retrieval_precision ] diff --git a/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py new file mode 100644 index 000000000..cdcf90aef --- /dev/null +++ b/tests/autorag/nodes/passagereranker/test_voyageai_reranker.py @@ -0,0 +1,95 @@ +from unittest.mock import patch + +import pytest + +from collections import namedtuple +import voyageai +from voyageai.object.reranking import RerankingObject, RerankingResult +from voyageai.api_resources import VoyageResponse + +import autorag +from autorag.nodes.passagereranker import VoyageAIReranker +from tests.autorag.nodes.passagereranker.test_passage_reranker_base import ( + queries_example, + contents_example, + ids_example, + base_reranker_test, + project_dir, + previous_result, + base_reranker_node_test, +) + + +async def mock_voyageai_reranker_pure( + self, + query, + documents, + model, + top_k, + truncation, +): + mock_documents = ["Document 1 content", "Document 2 content", "Document 3 content"] + + # Mock response data + mock_response_data = [ + {"index": 1, "relevance_score": 0.8}, + {"index": 2, "relevance_score": 0.2}, + {"index": 0, "relevance_score": 0.1}, + ] + + # Mock usage data + mock_usage = {"total_tokens": 100} + + # Create a mock VoyageResponse object + mock_response = VoyageResponse() + mock_response.data = [ + namedtuple("MockData", d.keys())(*d.values()) for d in mock_response_data + ] + mock_response.usage = namedtuple("MockUsage", mock_usage.keys())( + *mock_usage.values() + ) + + # Create an instance of RerankingObject using the mock data + object = RerankingObject(documents=mock_documents, response=mock_response) + + if top_k == 1: + object.results = [ + RerankingResult(index=1, document="nodonggunn", relevance_score=0.8) + ] + return object + + +@pytest.fixture +def voyageai_reranker_instance(): + return VoyageAIReranker(project_dir, api_key="mock_api_key") + + +@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure) +def test_voyageai_reranker(voyageai_reranker_instance): + top_k = 3 + contents_result, id_result, score_result = voyageai_reranker_instance._pure( + queries_example, contents_example, ids_example, top_k + ) + base_reranker_test(contents_result, id_result, score_result, top_k) + + +@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure) +def test_voyageai_reranker_batch_one(voyageai_reranker_instance): + top_k = 1 + batch = 1 + contents_result, id_result, score_result = voyageai_reranker_instance._pure( + queries_example, contents_example, ids_example, top_k, batch=batch + ) + base_reranker_test(contents_result, id_result, score_result, top_k) + + +@patch.object(voyageai.AsyncClient, "rerank", mock_voyageai_reranker_pure) +def test_voyageai_reranker_node(): + top_k = 1 + result_df = VoyageAIReranker.run_evaluator( + project_dir=project_dir, + previous_result=previous_result, + top_k=top_k, + api_key="mock_api_key", + ) + base_reranker_node_test(result_df, top_k)