chore(wren-ai-service): update eval #766

Merged · 27 commits · Oct 18, 2024

4 changes: 2 additions & 2 deletions wren-ai-service/Justfile
@@ -16,8 +16,8 @@ prep:
predict dataset pipeline='ask':
@poetry run python -u eval/prediction.py --file {{dataset}} --pipeline {{pipeline}}

eval prediction_result:
@poetry run python -u eval/evaluation.py --file {{prediction_result}}
eval prediction_result semantics='--no-semantics':
@poetry run python -u eval/evaluation.py --file {{prediction_result}} {{semantics}}

demo:
poetry run streamlit run demo/app.py
14 changes: 10 additions & 4 deletions wren-ai-service/eval/README.md
@@ -17,16 +17,16 @@ The dataset curation process is used to prepare the evaluation dataset for the W

## Eval Dataset Preparation (if using the Spider 1.0 dataset)

```cli
just prep
```

This command will do two things:
1. Download the Spider 1.0 dataset into `wren-ai-service/tools/dev/spider1.0`; there are two folders inside: `database` and `spider_data`.
   - `database`: contains the test data. It's downloaded from [this repo](https://github.com/taoyds/test-suite-sql-eval).
   - `spider_data`: contains table schemas, ground truths (question-SQL pairs), etc. For more information, please refer to [this repo](https://github.com/taoyds/spider).
2. Prepare the evaluation datasets and put them in `wren-ai-service/eval/dataset`. The file name of an eval dataset for Spider looks like this: `spider_<db_name>_eval_dataset.toml`

```cli
just prep
```

## Evaluation Dataset Schema

- dataset_id(UUID)
@@ -58,6 +58,12 @@ The evaluation process is used to assess the prediction results of the Wren AI s
just eval <prediction-result>
```

Note: If you would like to enable LLM-based semantics comparison between SQLs to improve the accuracy metric, fill in your OpenAI API key in the `.env` file under `wren-ai-service/eval` and add `--semantics` to the end of the command, like the following:

```cli
just eval <prediction-result> --semantics
```
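
For reference, a minimal `.env` under `wren-ai-service/eval` presumably only needs the variable read via `os.getenv("OPENAI_API_KEY")` (the value below is a placeholder):

```
# wren-ai-service/eval/.env
OPENAI_API_KEY=<your-openai-api-key>
```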

The evaluation results will be presented on Langfuse as follows:

![shallow_trace_example](../docs/imgs/shallow_trace_example.png)
36 changes: 36 additions & 0 deletions wren-ai-service/eval/add_samples_to_toml.py
@@ -0,0 +1,36 @@
import argparse

import tomlkit

from eval.utils import (
get_next_few_items_circular,
)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--toml", type=str, help="The toml file name", required=True)
args = parser.parse_args()

if args.toml:
# read toml
with open(f"eval/dataset/{args.toml}", "r") as f:
doc = tomlkit.parse(f.read())

# get the list of question-sql pairs for generating sample values
ground_truth_list = [
{"question": element["question"], "sql": element["sql"]}
for element in doc["eval_dataset"]
]

# use utils.get_next_few_items_circular to put n samples into each eval dataset entry
new_dataset = []
for i, element in enumerate(doc["eval_dataset"]):
samples = get_next_few_items_circular(ground_truth_list, i)
element["samples"] = samples
new_dataset.append(element)

# write toml
doc["eval_dataset"] = new_dataset

with open(f"eval/dataset/added_samples_{args.toml}", "w") as f:
f.write(tomlkit.dumps(doc, sort_keys=True))
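
The helper `get_next_few_items_circular` is imported from `eval.utils` and not shown in this diff; as a rough sketch of what it presumably does (the signature and the sample count `n` below are assumptions, not the actual implementation):

```python
def get_next_few_items_circular(items: list, index: int, n: int = 5) -> list:
    # Hypothetical sketch: take the n items that follow `index`, wrapping around
    # the list so that every entry receives the same number of samples.
    return [items[(index + offset) % len(items)] for offset in range(1, n + 1)]
```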
2 changes: 1 addition & 1 deletion wren-ai-service/eval/data_curation/app.py
@@ -16,7 +16,6 @@
WREN_IBIS_ENDPOINT,
get_contexts_from_sqls,
get_data_from_wren_engine_with_sqls,
get_openai_client,
get_question_sql_pairs,
is_sql_valid,
prettify_sql,
@@ -25,6 +24,7 @@
from eval.utils import (
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
prepare_duckdb_init_sql,
prepare_duckdb_session_sql,
)
9 changes: 0 additions & 9 deletions wren-ai-service/eval/data_curation/utils.py
@@ -33,15 +33,6 @@
ddl_converter = DDLConverter()


def get_openai_client(
api_key: str = os.getenv("OPENAI_API_KEY"), timeout: float = TIMEOUT_SECONDS
) -> AsyncClient:
return AsyncClient(
api_key=api_key,
timeout=timeout,
)


async def is_sql_valid(
sql: str,
data_source: str,
24 changes: 19 additions & 5 deletions wren-ai-service/eval/evaluation.py
@@ -21,6 +21,10 @@
def formatter(prediction: dict, meta: dict) -> dict:
retrieval_context = [str(context) for context in prediction["retrieval_context"]]
context = [str(context) for context in prediction["context"]]
enable_spider_metrics = "spider" in meta.get("evaluation_dataset", "").lower()
enable_rewrite = any(
dataset in meta.get("evaluation_dataset", "").lower() for dataset in ["spider"]
)

return {
"input": prediction["input"],
@@ -32,6 +36,8 @@ def formatter(prediction: dict, meta: dict) -> dict:
"trace_id": prediction["trace_id"],
"trace_url": prediction["trace_url"],
"catalog": meta.get("catalog", None),
"enable_spider_metrics": enable_spider_metrics,
"enable_rewrite": enable_rewrite,
},
}

@@ -44,8 +50,14 @@ def parse_args() -> Tuple[str]:
type=str,
help="Eval the prediction result in the outputs/predictions directory",
)
args = parser.parse_args()
return f"outputs/predictions/{args.file}"
parser.add_argument(
"--semantics",
"-S",
default=False,
action=argparse.BooleanOptionalAction,
help="Whether use the LLM(OpenAI's gpt-4o-mini) to help check semantics of sqls to improve accuracy metrics",
)
return parser.parse_args()


class Evaluator:
@@ -122,17 +134,19 @@ def _average_score(self, meta: dict) -> None:


if __name__ == "__main__":
path = parse_args()
args = parse_args()

dotenv.load_dotenv()
utils.load_env_vars()

predicted_file = parse_toml(path)
predicted_file = parse_toml(f"outputs/predictions/{args.file}")
meta = predicted_file["meta"]
predictions = predicted_file["predictions"]

dataset = parse_toml(meta["evaluation_dataset"])
metrics = pipelines.metrics_initiator(meta["pipeline"], dataset["mdl"])
metrics = pipelines.metrics_initiator(
meta["pipeline"], dataset["mdl"], args.semantics
)

evaluator = Evaluator(**metrics)
evaluator.eval(meta, predictions)
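
A side note on the new flag: `argparse.BooleanOptionalAction` (Python 3.9+) derives a paired `--no-semantics` option from the single `--semantics` definition, which is what the Justfile's `--no-semantics` default relies on. A minimal sketch of the resulting behavior:

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction registers both --semantics and --no-semantics automatically.
parser.add_argument(
    "--semantics", "-S", default=False, action=argparse.BooleanOptionalAction
)

print(parser.parse_args([]).semantics)                  # False (default)
print(parser.parse_args(["--semantics"]).semantics)     # True
print(parser.parse_args(["--no-semantics"]).semantics)  # False
```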
20 changes: 20 additions & 0 deletions wren-ai-service/eval/metrics/__init__.py
@@ -0,0 +1,20 @@
from .accuracy import AccuracyMetric, AccuracyMultiCandidateMetric
from .answer_relevancy import AnswerRelevancyMetric
from .context_precision import ContextualPrecisionMetric
from .context_recall import ContextualRecallMetric
from .context_relevancy import ContextualRelevancyMetric
from .faithfulness import FaithfulnessMetric
from .spider.exact_match import ExactMatchAccuracy
from .spider.exec_match import ExecutionAccuracy

__all__ = [
"AccuracyMetric",
"AccuracyMultiCandidateMetric",
"AnswerRelevancyMetric",
"ContextualPrecisionMetric",
"ContextualRecallMetric",
"ContextualRelevancyMetric",
"FaithfulnessMetric",
"ExactMatchAccuracy",
"ExecutionAccuracy",
]
201 changes: 201 additions & 0 deletions wren-ai-service/eval/metrics/accuracy.py
@@ -0,0 +1,201 @@
import asyncio
import re
import traceback

import orjson
import pandas as pd
from deepeval.evaluate import TestResult
from deepeval.metrics import BaseMetric
from deepeval.test_case import LLMTestCase

from eval.utils import get_data_from_wren_engine, get_openai_client


class AccuracyMetric(BaseMetric):
def __init__(self, engine_config: dict, enable_semantics_comparison: bool = False):
self.threshold = 0
self.score = 0
self._engine_config = engine_config
self._enable_semantics_comparison = enable_semantics_comparison
if self._enable_semantics_comparison:
self._openai_client = get_openai_client()

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

def _is_subset(self, expected: pd.DataFrame, actual: pd.DataFrame) -> bool:
if not set(expected.columns).issubset(set(actual.columns)):
return False

common_columns = sorted(expected.columns)

expected_sorted = expected[common_columns]
actual_sorted = actual[common_columns]
# Ensure that the data types are the same
actual_sorted = actual_sorted.astype(expected_sorted.dtypes.to_dict())

merged = pd.merge(
actual_sorted,
expected_sorted,
on=common_columns,
how="left",
indicator=True,
)
return all(merged["_merge"] == "both")

def _count_partial_matches(
self, expected: pd.DataFrame, actual: pd.DataFrame
) -> int:
intersection = set(expected.columns).intersection(set(actual.columns))
common_columns = sorted(intersection)
if not common_columns:
return 0

expected_sorted = expected[common_columns]
actual_sorted = actual[common_columns]
# Ensure that the data types are the same
actual_sorted = actual_sorted.astype(expected_sorted.dtypes.to_dict())

merged = pd.merge(
actual_sorted,
expected_sorted,
on=common_columns,
how="left",
indicator=True,
)
if all(merged["_merge"] == "both"):
return len(intersection) / len(expected.columns)
else:
return 0

def _rewrite_sql(self, sql: str) -> str:
# Pattern to match double quotes after WHERE clause, including multiple occurrences
pattern = r'(WHERE\s+.*?)(")(.+?)(")(.*)$'
replacement = r"\1'\3'\5"

# Apply the replacement repeatedly until no more changes
new_sql = re.sub(pattern, replacement, sql, flags=re.IGNORECASE | re.DOTALL)
while new_sql != sql:
sql = new_sql
new_sql = re.sub(pattern, replacement, sql, flags=re.IGNORECASE | re.DOTALL)

return sql
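# Illustrative example: _rewrite_sql('SELECT name FROM singer WHERE country = "France"')
# returns "SELECT name FROM singer WHERE country = 'France'"; double-quoted string
# literals after the WHERE clause become single-quoted so the engine does not
# parse them as identifiers.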

async def _retrieve_data(self, sql: str) -> pd.DataFrame:
response = await get_data_from_wren_engine(sql=sql, **self._engine_config)

df = pd.DataFrame(**response)
sorted_columns = sorted(df.columns)
return df[sorted_columns].sort_values(by=sorted_columns)

async def _check_sql_semantics(self, expected_sql: str, actual_sql: str):
_system_prompt = (
"### TASK ### \n"
+ "You are a great data anlyst, please carefully check the semantics of two given SQLs if they are the same. \n"
+ "The output should be a JSON format with the following schema: \n"
+ "{ \n"
+ ' "reasoning": <REASONING_STRING> \n'
+ ' "same": <BOOL> \n'
+ "}"
)

_user_prompt = (
"### QUESTION ### \n"
+ f"Expected SQL: {expected_sql} \n"
+ f"Actual SQL: {actual_sql} \n"
+ "\n"
+ "Please think step by step"
)

response = await self._openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": _system_prompt},
{"role": "user", "content": _user_prompt},
],
response_format={"type": "json_object"},
)

print(
f"response of _check_sql_semantics: {response.choices[0].message.content}"
)

return 1 if orjson.loads(response.choices[0].message.content)["same"] else 0

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
try:
enable_rewrite = test_case.additional_metadata.get("enable_rewrite", False)
rewritten_expected_output = test_case.expected_output

if enable_rewrite:
rewritten_expected_output = self._rewrite_sql(test_case.expected_output)

expected_dataset = await self._retrieve_data(rewritten_expected_output)
actual_dataset = await self._retrieve_data(test_case.actual_output)

print(f"expected columns: {set(expected_dataset.columns)}")
print(f"actual columns: {set(actual_dataset.columns)}")

if expected_dataset.equals(actual_dataset) or self._is_subset(
expected_dataset, actual_dataset
):
self.success = True
self.score = 1
return self.score

self.score = self._count_partial_matches(expected_dataset, actual_dataset)
# use llm to check sql semantics
if self.score == 0 and self._enable_semantics_comparison:
# TODO: we may need to upload the sql semantics result to langfuse
print(f"before _check_sql_semantics: {self.score}")
print(f"expected sql: {rewritten_expected_output}")
print(f"actual sql: {test_case.actual_output}")
self.score = await self._check_sql_semantics(
rewritten_expected_output, test_case.actual_output
)
print(f"after _check_sql_semantics: {self.score}")
except Exception as e:
self.error = f"Error occurred while evaluating the metric: {e}"
traceback.print_exc()

# reached only when the exact-match/subset checks above did not pass
self.success = False
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "Accuracy(column-based)"


class AccuracyMultiCandidateMetric(BaseMetric):
def __init__(self):
self.threshold = 0
self.score = 0
self._questions = {}

def collect(self, test_case: LLMTestCase, result: TestResult):
for metric in result.metrics_data:
if metric.name != "Accuracy(column-based)":
continue

# fall back to 0 when metric.error is set and metric.score is None
self._questions[test_case.input] = (
self._questions.get(test_case.input, 0) or metric.score or 0
)

def measure(self):
if not self._questions:
return 0
self.score = sum(self._questions.values()) / len(self._questions)
self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "Accuracy(question-based)"