Add cli command execution for Evaluate initialize and start_trial #50

Merged
7 commits merged on Jan 24, 2024

Changes from 6 commits
20 changes: 11 additions & 9 deletions README.md
@@ -45,35 +45,37 @@ evaluate_dataset.to_parquet('your/path/to/evaluate_dataset.parquet')
 
 ### Evaluate your data to various RAG modules
 ```python
-from autorag import Evaluator
+from autorag.evaluator import Evaluator
 
-evaluator = Evaluator(qa_path='your/path/to/qa.parquet', corpus_path='your/path/to/corpus.parquet')
+evaluator = Evaluator(qa_data_path='your/path/to/qa.parquet', corpus_data_path='your/path/to/corpus.parquet')
 evaluator.start_trial('your/path/to/config.yaml')
 ```
 or you can use command line interface
 ```bash
-autorag evaluate --config your/path/to/default_config.yaml
+autorag evaluate --config your/path/to/default_config.yaml --qa_data_path your/path/to/qa.parquet --corpus_data_path your/path/to/corpus.parquet
 ```
 
 ### Evaluate your custom RAG pipeline
 
 ```python
+from autorag.evaluate import evaluate_retrieval, evaluate_generation
+
 @evaluate
 def your_custom_rag_pipeline(query: str) -> str:
     # your custom rag pipeline
     return answer
 
 
 # also, you can evaluate each RAG module one by one
-@evaluate_retrieval
-def your_retrieval_module(query: str, top_k: int = 5) -> List[uuid.UUID]:
+@evaluate_retrieval(retrieval_gt=retrieval_gt, metrics=['retrieval_f1', 'retrieval_recall', 'retrieval_precision'])
+def your_retrieval_module(query: str, top_k: int = 5) -> tuple[list[list[str]], list[list[str]], list[list[float]]]:
     # your custom retrieval module
-    return retrieved_ids
+    return retrieved_contents, retrieved_ids, scores
 
-@evaluate_generation
-def your_llm_module(prompt: str) -> str:
+@evaluate_generation(generation_gt=generation_gt, metrics=['bleu', 'rouge'])
+def your_llm_module(prompt: str) -> list[str]:
     # your custom llm module
-    return answer
+    return answers
 ```
 
 ### Config yaml file
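Aside on the decorator signatures above: `retrieval_gt`, `generation_gt`, and the return annotations are all nested lists with one inner list per evaluated query. The sketch below only illustrates those shapes as implied by the README's type hints; the variable names and sample values are hypothetical and not part of this PR.

```python
# Hypothetical data, shown only to illustrate the nested-list shapes implied by the
# README's type hints (one inner list per evaluated query).
retrieval_gt = [['doc-1', 'doc-7'], ['doc-3']]         # ground-truth passage ids per query
generation_gt = [['Seoul is the capital of Korea.']]   # reference answers per query


def shape_example(query: str, top_k: int = 5):
    retrieved_contents = [['passage text A', 'passage text B']]  # list[list[str]]
    retrieved_ids = [['doc-1', 'doc-2']]                         # list[list[str]]
    scores = [[0.92, 0.71]]                                      # list[list[float]]
    return retrieved_contents, retrieved_ids, scores
```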
11 changes: 7 additions & 4 deletions autorag/__init__.py
@@ -1,16 +1,19 @@
-__version__ = '0.0.1'
-
+import os
 import logging
 import logging.config
 import sys
 
 from rich.logging import RichHandler
 
-from .evaluator import Evaluator
-
 from llama_index import OpenAIEmbedding
 from llama_index.llms import OpenAI
 
+root_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+version_path = os.path.join(root_path, 'VERSION')
+
+with open(version_path, 'r') as f:
+    __version__ = f.read().strip()
+
 embedding_models = {
     'openai': OpenAIEmbedding(),
 }
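The `autorag/__init__.py` change replaces the hardcoded `__version__` string with a read of a top-level VERSION file at import time. A common alternative for single-sourcing the version, sketched below as an assumption rather than something this PR does, is to query the installed distribution metadata (the distribution name `AutoRAG` is assumed here).

```python
# Sketch of an alternative version lookup, not part of this PR.
from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version('AutoRAG')  # assumes the distribution is published as 'AutoRAG'
except PackageNotFoundError:
    # e.g. running from a source checkout that has not been pip-installed
    __version__ = '0.0.0'
```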
1 change: 1 addition & 0 deletions autorag/evaluate/__init__.py
@@ -1 +1,2 @@
 from .retrieval import evaluate_retrieval
+from .generation import evaluate_generation
27 changes: 26 additions & 1 deletion autorag/evaluator.py
@@ -5,6 +5,7 @@
 from datetime import datetime
 from typing import List, Dict
 
+import click
 import pandas as pd
 import yaml

@@ -14,7 +15,6 @@
 from autorag.schema.node import module_type_exists
 from autorag.utils import cast_qa_dataset, cast_corpus_dataset
 
-
 logger = logging.getLogger("AutoRAG")


@@ -132,3 +132,28 @@ def _load_node_lines(yaml_path: str) -> Dict[str, List[Node]]:
         node_line_dict[node_line['node_line_name']] = list(
             map(lambda node: Node.from_dict(node), node_line['nodes']))
     return node_line_dict
+
+
+@click.group()
+def cli():
+    pass
+
+
+@click.command()
+@click.option('--config', '-c', help='Path to config yaml file. Must be yaml or yml file.', type=str)
+@click.option('--qa_data_path', help='Path to QA dataset. Must be parquet file.', type=str)
+@click.option('--corpus_data_path', help='Path to corpus dataset. Must be parquet file.', type=str)
+def evaluate(config, qa_data_path, corpus_data_path):
+    if not config.endswith('.yaml') and not config.endswith('.yml'):
+        raise ValueError(f"Config file {config} is not a yaml or yml file.")
+    if not os.path.exists(config):
+        raise ValueError(f"Config file {config} does not exist.")
+    evaluator = Evaluator(qa_data_path, corpus_data_path)
+    evaluator.start_trial(config)
+    logger.info('Evaluation complete.')
+
+
+cli.add_command(evaluate)
+
+if __name__ == '__main__':
+    cli()
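Since `evaluate` is registered on a regular `click` group, it can also be driven in-process with click's own test helper instead of spawning the installed console script. The snippet below is a sketch using `click.testing.CliRunner`; the paths are placeholders and this is not code from the PR.

```python
# In-process invocation of the CLI defined in autorag/evaluator.py (sketch, placeholder paths).
from click.testing import CliRunner

from autorag.evaluator import cli

runner = CliRunner()
result = runner.invoke(cli, [
    'evaluate',
    '--config', 'your/path/to/config.yaml',
    '--qa_data_path', 'your/path/to/qa.parquet',
    '--corpus_data_path', 'your/path/to/corpus.parquet',
])
print(result.exit_code, result.output)
```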
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -40,3 +40,6 @@ exclude = ["tests"]
 pythonpath = ["."]
 testpaths = ["tests"]
 addopts = ["--import-mode=importlib"] # default is prepend
+
+[project.entry-points.console_scripts]
+autorag = "autorag.evaluator:cli"
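This `console_scripts` entry point is what makes the bare `autorag` command used in the README and in the test below available after installation: the packaging backend generates a small launcher script that imports `autorag.evaluator:cli` and calls it. The exact wrapper differs between installers, but it behaves roughly like this sketch.

```python
# Approximate behaviour of the generated 'autorag' console script (illustrative only).
import sys

from autorag.evaluator import cli

if __name__ == '__main__':
    sys.exit(cli())
```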
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ sacrebleu # for bleu score
 evaluate # for meteor and other scores
 rouge_score # for rouge score
 rich # for pretty logging
+click # for cli
18 changes: 17 additions & 1 deletion tests/autorag/test_evaluator.py
@@ -1,11 +1,12 @@
 import os.path
 import pathlib
 import shutil
+import subprocess
 
 import pandas as pd
 import pytest
 
-from autorag import Evaluator
+from autorag.evaluator import Evaluator
 from autorag.nodes.retrieval import bm25
 from autorag.nodes.retrieval.run import run_retrieval_node
 from autorag.schema import Node
@@ -106,3 +107,18 @@ def test_start_trial(evaluator):
     assert trial_summary_df['best_module_name'][0] == 'bm25'
     assert trial_summary_df['best_module_params'][0] == {'top_k': 50}
     assert trial_summary_df['best_execution_time'][0] > 0
+
+
+def test_evaluator_cli(evaluator):
+    result = subprocess.run(['autorag', 'evaluate', '--config', os.path.join(resource_dir, 'simple.yaml'),
+                             '--qa_data_path', os.path.join(resource_dir, 'qa_data_sample.parquet'),
+                             '--corpus_data_path', os.path.join(resource_dir, 'corpus_data_sample.parquet')])
+    assert result.returncode == 0
+    # check if the files are created
+    assert os.path.exists(os.path.join(os.getcwd(), '0'))
+    assert os.path.exists(os.path.join(os.getcwd(), 'data'))
+    assert os.path.exists(os.path.join(os.getcwd(), 'resources'))
+    assert os.path.exists(os.path.join(os.getcwd(), 'trial.json'))
+    assert os.path.exists(os.path.join(os.getcwd(), '0', 'retrieve_node_line'))
+    assert os.path.exists(os.path.join(os.getcwd(), '0', 'retrieve_node_line', 'retrieval'))
+    assert os.path.exists(os.path.join(os.getcwd(), '0', 'retrieve_node_line', 'retrieval', 'bm25=>top_k_50.parquet'))
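Note that `test_evaluator_cli` as written leaves its trial artifacts ('0', 'data', 'resources', 'trial.json') in whatever directory pytest was started from, because the CLI writes to the current working directory. A possible refinement, sketched below and not part of this PR, runs the subprocess from a temporary directory via pytest's `tmp_path` and `monkeypatch` fixtures; it assumes the same module-level `resource_dir` and `evaluator` fixture used by the existing tests.

```python
# Sketch of a cwd-isolated variant of the CLI test; assumes it lives in the same test module,
# so `resource_dir` and the `evaluator` fixture are already defined there.
import os
import subprocess


def test_evaluator_cli_isolated(evaluator, tmp_path, monkeypatch):
    monkeypatch.chdir(tmp_path)  # the child process inherits this working directory
    result = subprocess.run(['autorag', 'evaluate',
                             '--config', os.path.join(resource_dir, 'simple.yaml'),
                             '--qa_data_path', os.path.join(resource_dir, 'qa_data_sample.parquet'),
                             '--corpus_data_path', os.path.join(resource_dir, 'corpus_data_sample.parquet')])
    assert result.returncode == 0
    # trial artifacts now land in tmp_path instead of the repo checkout
    assert (tmp_path / 'trial.json').exists()
    assert (tmp_path / '0' / 'retrieve_node_line' / 'retrieval').exists()
```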