diff --git a/README.md b/README.md
index 0bbef2eea..22765f0c8 100644
--- a/README.md
+++ b/README.md
@@ -45,19 +45,21 @@ evaluate_dataset.to_parquet('your/path/to/evaluate_dataset.parquet')
 ### Evaluate your data to various RAG modules
 ```python
-from autorag import Evaluator
+from autorag.evaluator import Evaluator
 
-evaluator = Evaluator(qa_path='your/path/to/qa.parquet', corpus_path='your/path/to/corpus.parquet')
+evaluator = Evaluator(qa_data_path='your/path/to/qa.parquet', corpus_data_path='your/path/to/corpus.parquet')
 evaluator.start_trial('your/path/to/config.yaml')
 ```
 or you can use command line interface
 ```bash
-autorag evaluate --config your/path/to/default_config.yaml
+autorag evaluate --config your/path/to/default_config.yaml --qa_data_path your/path/to/qa.parquet --corpus_data_path your/path/to/corpus.parquet
 ```
 ### Evaluate your custom RAG pipeline
 ```python
+from autorag.evaluate import evaluate_retrieval, evaluate_generation
+
 @evaluate
 def your_custom_rag_pipeline(query: str) -> str:
     # your custom rag pipeline
@@ -65,15 +67,15 @@ def your_custom_rag_pipeline(query: str) -> str:
 # also, you can evaluate each RAG module one by one
-@evaluate_retrieval
-def your_retrieval_module(query: str, top_k: int = 5) -> List[uuid.UUID]:
+@evaluate_retrieval(retrieval_gt=retrieval_gt, metrics=['retrieval_f1', 'retrieval_recall', 'retrieval_precision'])
+def your_retrieval_module(query: str, top_k: int = 5) -> tuple[list[list[str]], list[list[str]], list[list[float]]]:
     # your custom retrieval module
-    return retrieved_ids
+    return retrieved_contents, retrieved_ids, scores
 
-@evaluate_generation
-def your_llm_module(prompt: str) -> str:
+@evaluate_generation(generation_gt=generation_gt, metrics=['bleu', 'rouge'])
+def your_llm_module(prompt: str) -> list[str]:
     # your custom llm module
-    return answer
+    return answers
 ```
 ### Config yaml file
diff --git a/autorag/__init__.py b/autorag/__init__.py
index 779538d4d..09f871a78 100644
--- a/autorag/__init__.py
+++ b/autorag/__init__.py
@@ -1,16 +1,19 @@
-__version__ = '0.0.1'
-
+import os
 import logging
 import logging.config
 import sys
 from rich.logging import RichHandler
-from .evaluator import Evaluator
-
 from llama_index import OpenAIEmbedding
 from llama_index.llms import OpenAI
+root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+version_path = os.path.join(root_path, 'VERSION')
+
+with open(version_path, 'r') as f:
+    __version__ = f.read().strip()
+
 embedding_models = {
     'openai': OpenAIEmbedding(),
 }
diff --git a/autorag/evaluate/__init__.py b/autorag/evaluate/__init__.py
index 48ceb6831..6d01de8e7 100644
--- a/autorag/evaluate/__init__.py
+++ b/autorag/evaluate/__init__.py
@@ -1 +1,2 @@
 from .retrieval import evaluate_retrieval
+from .generation import evaluate_generation
diff --git a/autorag/evaluator.py b/autorag/evaluator.py
index 7eb76f04d..42358c179 100644
--- a/autorag/evaluator.py
+++ b/autorag/evaluator.py
@@ -5,6 +5,7 @@ from datetime import datetime
 from typing import List, Dict
+import click
 import pandas as pd
 import yaml
@@ -14,7 +15,6 @@ from autorag.schema.node import module_type_exists
 from autorag.utils import cast_qa_dataset, cast_corpus_dataset
-
 logger = logging.getLogger("AutoRAG")
@@ -132,3 +132,28 @@ def _load_node_lines(yaml_path: str) -> Dict[str, List[Node]]:
         node_line_dict[node_line['node_line_name']] = list(
             map(lambda node: Node.from_dict(node), node_line['nodes']))
     return node_line_dict
+
+
+@click.group()
+def cli():
+    pass
+
+
+@click.command()
+@click.option('--config', '-c', help='Path to config yaml file. Must be yaml or yml file.', type=str)
+@click.option('--qa_data_path', help='Path to QA dataset. Must be parquet file.', type=str)
+@click.option('--corpus_data_path', help='Path to corpus dataset. Must be parquet file.', type=str)
+def evaluate(config, qa_data_path, corpus_data_path):
+    if not config.endswith('.yaml') and not config.endswith('.yml'):
+        raise ValueError(f"Config file {config} is not a yaml file.")
+    if not os.path.exists(config):
+        raise ValueError(f"Config file {config} does not exist.")
+    evaluator = Evaluator(qa_data_path, corpus_data_path)
+    evaluator.start_trial(config)
+    logger.info('Evaluation complete.')
+
+
+cli.add_command(evaluate)
+
+if __name__ == '__main__':
+    cli()
diff --git a/pyproject.toml b/pyproject.toml
index 57d8a9651..18435347e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,3 +40,6 @@ exclude = ["tests"]
 pythonpath = ["."]
 testpaths = ["tests"]
 addopts = ["--import-mode=importlib"] # default is prepend
+
+[project.entry-points.console_scripts]
+autorag = "autorag.evaluator:cli"
diff --git a/requirements.txt b/requirements.txt
index a30cbda51..5ab75f2e9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ sacrebleu # for bleu score
 evaluate # for meteor and other scores
 rouge_score # for rouge score
 rich # for pretty logging
+click # for cli
diff --git a/tests/autorag/test_evaluator.py b/tests/autorag/test_evaluator.py
index 65b637d65..ba1869fbf 100644
--- a/tests/autorag/test_evaluator.py
+++ b/tests/autorag/test_evaluator.py
@@ -1,11 +1,12 @@
 import os.path
 import pathlib
 import shutil
+import subprocess
 import pandas as pd
 import pytest
-from autorag import Evaluator
+from autorag.evaluator import Evaluator
 from autorag.nodes.retrieval import bm25
 from autorag.nodes.retrieval.run import run_retrieval_node
 from autorag.schema import Node
@@ -106,3 +107,18 @@ def test_start_trial(evaluator):
     assert trial_summary_df['best_module_name'][0] == 'bm25'
     assert trial_summary_df['best_module_params'][0] == {'top_k': 50}
     assert trial_summary_df['best_execution_time'][0] > 0
+
+
+def test_evaluator_cli(evaluator):
+    result = subprocess.run(['autorag', 'evaluate', '--config', os.path.join(resource_dir, 'simple.yaml'),
+                             '--qa_data_path', os.path.join(resource_dir, 'qa_data_sample.parquet'),
+                             '--corpus_data_path', os.path.join(resource_dir, 'corpus_data_sample.parquet')])
+    assert result.returncode == 0
+    # check if the files are created
+    assert os.path.exists(os.path.join(os.getcwd(), '0'))
+    assert os.path.exists(os.path.join(os.getcwd(), 'data'))
+    assert os.path.exists(os.path.join(os.getcwd(), 'resources'))
+    assert os.path.exists(os.path.join(os.getcwd(), 'trial.json'))
+    assert os.path.exists(os.path.join(os.getcwd(), '0', 'retrieve_node_line'))
+    assert os.path.exists(os.path.join(os.getcwd(), '0', 'retrieve_node_line', 'retrieval'))
+    assert os.path.exists(os.path.join(os.getcwd(), '0', 'retrieve_node_line', 'retrieval', 'bm25=>top_k_50.parquet'))
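
The new `test_evaluator_cli` exercises the command through the installed `autorag` console script. If installing the entry point in CI ever proves awkward, the same `cli` group added in `autorag/evaluator.py` can be driven in-process with click's own test runner. A minimal sketch, assuming the same `resource_dir` and sample files the existing test module uses (`resource_dir` below is a placeholder, not the module's real value):

```python
# Sketch only (not part of the diff): drive the new click command in-process with
# click.testing.CliRunner instead of spawning the installed `autorag` entry point.
import os

from click.testing import CliRunner

from autorag.evaluator import cli

resource_dir = 'tests/resources'  # placeholder; reuse the test module's actual resource_dir


def test_evaluator_cli_in_process():
    runner = CliRunner()
    result = runner.invoke(cli, [
        'evaluate',
        '--config', os.path.join(resource_dir, 'simple.yaml'),
        '--qa_data_path', os.path.join(resource_dir, 'qa_data_sample.parquet'),
        '--corpus_data_path', os.path.join(resource_dir, 'corpus_data_sample.parquet'),
    ])
    # CliRunner catches exceptions inside the command; surface the output on failure.
    assert result.exit_code == 0, result.output
```

The trade-off is that the in-process run shares the test process's working directory and environment, whereas the subprocess-based test in the diff also verifies that the console_scripts entry point itself is wired up correctly.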