chore(wren-ai-service): update eval (#766)
* refine readme

* add samples to sql generation

* refine prompt

* fix few shot bugs

* update

* refactor metrics

* fix

* fix f401 issue

* fix

* fix: replace double quotes with single quotes for queries with multiple WHERE clauses

* add check_sql_semantics by llm

* allow opt out semantics comparison

* disable semantics comparison by default

* update readme

* update readme

* feat: add a flag to enable sql rewrite

* feat: use another condition to enable exact and exec match metrics

* make semantics option default

* update add_samples_to_toml

* refactor

* refine

* refine prompt

* fix test bugs

* fix bug

* add TODO comment

* fix: leave an empty dict in additional metadata in the test case to avoid errors

---------

Co-authored-by: Pao-Sheng Wang <paooap.oappao@gmail.com>
Co-authored-by: Aster Sun <imastr114@gmail.com>
3 people authored Oct 18, 2024
1 parent 0b2e9af commit cf31416
Showing 22 changed files with 543 additions and 333 deletions.
4 changes: 2 additions & 2 deletions wren-ai-service/Justfile
@@ -16,8 +16,8 @@ prep:
predict dataset pipeline='ask':
@poetry run python -u eval/prediction.py --file {{dataset}} --pipeline {{pipeline}}

eval prediction_result:
@poetry run python -u eval/evaluation.py --file {{prediction_result}}
eval prediction_result semantics='--no-semantics':
@poetry run python -u eval/evaluation.py --file {{prediction_result}} {{semantics}}

demo:
poetry run streamlit run demo/app.py
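With the updated recipe, the optional `semantics` argument (default `--no-semantics`) is forwarded straight to the evaluation script, so a `just eval` call expands to roughly the following (a sketch; `<prediction-result>` is a placeholder):

```cli
poetry run python -u eval/evaluation.py --file <prediction-result> --semantics
```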
14 changes: 10 additions & 4 deletions wren-ai-service/eval/README.md
@@ -17,16 +17,16 @@ The dataset curation process is used to prepare the evaluation dataset for the Wren AI service

## Eval Dataset Preparation (if using the Spider 1.0 dataset)

```cli
just prep
```

This command does two things (see the layout sketch below):
1. Downloads the Spider 1.0 dataset into `wren-ai-service/tools/dev/spider1.0`, which contains two folders: `database` and `spider_data`.
   - `database`: contains the test data, downloaded from [this repo](https://github.com/taoyds/test-suite-sql-eval).
   - `spider_data`: contains the table schemas, ground truths (question-SQL pairs), etc. For more information, please refer to [this repo](https://github.com/taoyds/spider).
2. Prepares the evaluation datasets and puts them in `wren-ai-service/eval/dataset`. The file name of a Spider eval dataset looks like this: `spider_<db_name>_eval_dataset.toml`.

```cli
just prep
```
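Roughly, the resulting layout looks like this (a sketch based on the description above; exact contents may differ):

```cli
wren-ai-service/
├── tools/dev/spider1.0/
│   ├── database/      # test data from test-suite-sql-eval
│   └── spider_data/   # table schemas and question-SQL ground truths
└── eval/dataset/
    └── spider_<db_name>_eval_dataset.toml
```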

## Evaluation Dataset Schema

- dataset_id(UUID)
@@ -58,6 +58,12 @@ The evaluation process is used to assess the prediction results of the Wren AI service
just eval <prediction-result>
```

Note: If you would like to enable LLM-based semantics comparison between SQLs to improve the accuracy metric, fill in your OpenAI API key in the `.env` file in `wren-ai-service/eval` and append `--semantics` to the command as follows:

```cli
just eval <prediction-result> --semantics
```
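The `.env` file only needs the standard OpenAI key variable read by the evaluation utilities (a minimal sketch; the value shown is a placeholder):

```cli
OPENAI_API_KEY=sk-...
```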

The evaluation results will be presented on Langfuse as follows:

![shallow_trace_example](../docs/imgs/shallow_trace_example.png)
36 changes: 36 additions & 0 deletions wren-ai-service/eval/add_samples_to_toml.py
@@ -0,0 +1,36 @@
import argparse

import tomlkit

from eval.utils import (
get_next_few_items_circular,
)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--toml", type=str, help="The toml file name", required=True)
args = parser.parse_args()

if args.toml:
# read toml
with open(f"eval/dataset/{args.toml}", "r") as f:
doc = tomlkit.parse(f.read())

# get the list of question-sql pairs for generating sample values
ground_truth_list = [
{"question": element["question"], "sql": element["sql"]}
for element in doc["eval_dataset"]
]

# use utils.get_next_few_items_circular to put n samples into each element of the eval dataset
new_dataset = []
for i, element in enumerate(doc["eval_dataset"]):
samples = get_next_few_items_circular(ground_truth_list, i)
element["samples"] = samples
new_dataset.append(element)

# write toml
doc["eval_dataset"] = new_dataset

with open(f"eval/dataset/added_samples_{args.toml}", "w") as f:
f.write(tomlkit.dumps(doc, sort_keys=True))
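A hedged usage sketch for the new script, assuming it is invoked the same way as the other eval scripts in this service (the TOML file name is a placeholder); it writes the augmented dataset to `eval/dataset/added_samples_<original-file-name>`:

```cli
poetry run python -u eval/add_samples_to_toml.py --toml spider_<db_name>_eval_dataset.toml
```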
2 changes: 1 addition & 1 deletion wren-ai-service/eval/data_curation/app.py
@@ -16,7 +16,6 @@
WREN_IBIS_ENDPOINT,
get_contexts_from_sqls,
get_data_from_wren_engine_with_sqls,
get_openai_client,
get_question_sql_pairs,
is_sql_valid,
prettify_sql,
@@ -25,6 +24,7 @@
from eval.utils import (
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
prepare_duckdb_init_sql,
prepare_duckdb_session_sql,
)
9 changes: 0 additions & 9 deletions wren-ai-service/eval/data_curation/utils.py
@@ -33,15 +33,6 @@
ddl_converter = DDLConverter()


def get_openai_client(
api_key: str = os.getenv("OPENAI_API_KEY"), timeout: float = TIMEOUT_SECONDS
) -> AsyncClient:
return AsyncClient(
api_key=api_key,
timeout=timeout,
)


async def is_sql_valid(
sql: str,
data_source: str,
24 changes: 19 additions & 5 deletions wren-ai-service/eval/evaluation.py
@@ -21,6 +21,10 @@
def formatter(prediction: dict, meta: dict) -> dict:
retrieval_context = [str(context) for context in prediction["retrieval_context"]]
context = [str(context) for context in prediction["context"]]
enable_spider_metrics = "spider" in meta.get("evaluation_dataset", "").lower()
enable_rewrite = any(
dataset in meta.get("evaluation_dataset", "").lower() for dataset in ["spider"]
)

return {
"input": prediction["input"],
@@ -32,6 +36,8 @@ def formatter(prediction: dict, meta: dict) -> dict:
"trace_id": prediction["trace_id"],
"trace_url": prediction["trace_url"],
"catalog": meta.get("catalog", None),
"enable_spider_metrics": enable_spider_metrics,
"enable_rewrite": enable_rewrite,
},
}

@@ -44,8 +50,14 @@ def parse_args() -> Tuple[str]:
type=str,
help="Eval the prediction result in the outputs/predictions directory",
)
args = parser.parse_args()
return f"outputs/predictions/{args.file}"
parser.add_argument(
"--semantics",
"-S",
default=False,
action=argparse.BooleanOptionalAction,
help="Whether use the LLM(OpenAI's gpt-4o-mini) to help check semantics of sqls to improve accuracy metrics",
)
return parser.parse_args()


class Evaluator:
@@ -122,17 +134,19 @@ def _average_score(self, meta: dict) -> None:


if __name__ == "__main__":
path = parse_args()
args = parse_args()

dotenv.load_dotenv()
utils.load_env_vars()

predicted_file = parse_toml(path)
predicted_file = parse_toml(f"outputs/predictions/{args.file}")
meta = predicted_file["meta"]
predictions = predicted_file["predictions"]

dataset = parse_toml(meta["evaluation_dataset"])
metrics = pipelines.metrics_initiator(meta["pipeline"], dataset["mdl"])
metrics = pipelines.metrics_initiator(
meta["pipeline"], dataset["mdl"], args.semantics
)

evaluator = Evaluator(**metrics)
evaluator.eval(meta, predictions)
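Because `--semantics` is registered with `argparse.BooleanOptionalAction`, argparse also accepts `--no-semantics`, which is why the Justfile recipe can pass the literal string `--no-semantics` as its default. A minimal standalone sketch (not part of the commit) illustrating that behavior:

```python
import argparse

# BooleanOptionalAction (Python 3.9+) registers both --semantics and
# --no-semantics for the same destination.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--semantics", "-S", default=False, action=argparse.BooleanOptionalAction
)

print(parser.parse_args(["--semantics"]).semantics)     # True
print(parser.parse_args(["--no-semantics"]).semantics)  # False
print(parser.parse_args([]).semantics)                  # False (default)
```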
20 changes: 20 additions & 0 deletions wren-ai-service/eval/metrics/__init__.py
@@ -0,0 +1,20 @@
from .accuracy import AccuracyMetric, AccuracyMultiCandidateMetric
from .answer_relevancy import AnswerRelevancyMetric
from .context_precision import ContextualPrecisionMetric
from .context_recall import ContextualRecallMetric
from .context_relevancy import ContextualRelevancyMetric
from .faithfulness import FaithfulnessMetric
from .spider.exact_match import ExactMatchAccuracy
from .spider.exec_match import ExecutionAccuracy

__all__ = [
"AccuracyMetric",
"AccuracyMultiCandidateMetric",
"AnswerRelevancyMetric",
"ContextualPrecisionMetric",
"ContextualRecallMetric",
"ContextualRelevancyMetric",
"FaithfulnessMetric",
"ExactMatchAccuracy",
"ExecutionAccuracy",
]
201 changes: 201 additions & 0 deletions wren-ai-service/eval/metrics/accuracy.py
@@ -0,0 +1,201 @@
import asyncio
import re
import traceback

import orjson
import pandas as pd
from deepeval.evaluate import TestResult
from deepeval.metrics import BaseMetric
from deepeval.test_case import LLMTestCase

from eval.utils import get_data_from_wren_engine, get_openai_client


class AccuracyMetric(BaseMetric):
def __init__(self, engine_config: dict, enable_semantics_comparison: bool = False):
self.threshold = 0
self.score = 0
self._engine_config = engine_config
self._enable_semantics_comparison = enable_semantics_comparison
if self._enable_semantics_comparison:
self._openai_client = get_openai_client()

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

def _is_subset(self, expected: pd.DataFrame, actual: pd.DataFrame) -> bool:
if not set(expected.columns).issubset(set(actual.columns)):
return False

common_columns = sorted(expected.columns)

expected_sorted = expected[common_columns]
actual_sorted = actual[common_columns]
# Ensure that the data types are the same
actual_sorted = actual_sorted.astype(expected_sorted.dtypes.to_dict())

merged = pd.merge(
actual_sorted,
expected_sorted,
on=common_columns,
how="left",
indicator=True,
)
return all(merged["_merge"] == "both")

def _count_partial_matches(
self, expected: pd.DataFrame, actual: pd.DataFrame
) -> int:
intersection = set(expected.columns).intersection(set(actual.columns))
common_columns = sorted(intersection)
if not common_columns:
return 0

expected_sorted = expected[common_columns]
actual_sorted = actual[common_columns]
# Ensure that the data types are the same
actual_sorted = actual_sorted.astype(expected_sorted.dtypes.to_dict())

merged = pd.merge(
actual_sorted,
expected_sorted,
on=common_columns,
how="left",
indicator=True,
)
if all(merged["_merge"] == "both"):
return len(intersection) / len(expected.columns)
else:
return 0

def _rewrite_sql(self, sql: str) -> str:
# Pattern to match double quotes after WHERE clause, including multiple occurrences
pattern = r'(WHERE\s+.*?)(")(.+?)(")(.*)$'
replacement = r"\1'\3'\5"

# Apply the replacement repeatedly until no more changes
new_sql = re.sub(pattern, replacement, sql, flags=re.IGNORECASE | re.DOTALL)
while new_sql != sql:
sql = new_sql
new_sql = re.sub(pattern, replacement, sql, flags=re.IGNORECASE | re.DOTALL)

return sql

async def _retrieve_data(self, sql: str) -> pd.DataFrame:
response = await get_data_from_wren_engine(sql=sql, **self._engine_config)

df = pd.DataFrame(**response)
sorted_columns = sorted(df.columns)
return df[sorted_columns].sort_values(by=sorted_columns)

async def _check_sql_semantics(self, expected_sql: str, actual_sql: str):
_system_prompt = (
"### TASK ### \n"
+ "You are a great data anlyst, please carefully check the semantics of two given SQLs if they are the same. \n"
+ "The output should be a JSON format with the following schema: \n"
+ "{ \n"
+ ' "reasoning": <REASONING_STRING> \n'
+ ' "same": <BOOL> \n'
+ "}"
)

_user_prompt = (
"### QUESTION ### \n"
+ f"Expected SQL: {expected_sql} \n"
+ f"Actual SQL: {actual_sql} \n"
+ "\n"
+ "Please think step by step"
)

response = await self._openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": _system_prompt},
{"role": "user", "content": _user_prompt},
],
response_format={"type": "json_object"},
)

print(
f"response of _check_sql_semantics: {response.choices[0].message.content}"
)

return 1 if orjson.loads(response.choices[0].message.content)["same"] else 0

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
try:
enable_rewrite = test_case.additional_metadata.get("enable_rewrite", False)
rewritten_expected_output = test_case.expected_output

if enable_rewrite:
rewritten_expected_output = self._rewrite_sql(test_case.expected_output)

expected_dataset = await self._retrieve_data(rewritten_expected_output)
actual_dataset = await self._retrieve_data(test_case.actual_output)

print(f"expected columns: {set(expected_dataset.columns)}")
print(f"actual columns: {set(actual_dataset.columns)}")

if expected_dataset.equals(actual_dataset) or self._is_subset(
expected_dataset, actual_dataset
):
self.success = True
self.score = 1
return self.score

self.score = self._count_partial_matches(expected_dataset, actual_dataset)
# use llm to check sql semantics
if self.score == 0 and self._enable_semantics_comparison:
# TODO: we may need to upload the sql semantics result to langfuse
print(f"before _check_sql_semantics: {self.score}")
print(f"expected sql: {rewritten_expected_output}")
print(f"actual sql: {test_case.actual_output}")
self.score = await self._check_sql_semantics(
rewritten_expected_output, test_case.actual_output
)
print(f"after _check_sql_semantics: {self.score}")
except Exception as e:
self.error = f"Error occurred while evaluating the metric: {e}"
traceback.print_exc()

# if none of the above checks passed
self.success = False
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "Accuracy(column-based)"


class AccuracyMultiCandidateMetric(BaseMetric):
def __init__(self):
self.threshold = 0
self.score = 0
self._questions = {}

def collect(self, test_case: LLMTestCase, result: TestResult):
for metric in result.metrics_data:
if metric.name != "Accuracy(column-based)":
continue

# `or 0` guards against a missing score when metric.error exists
self._questions[test_case.input] = (
self._questions.get(test_case.input, 0) or metric.score or 0
)

def measure(self):
if not self._questions:
return 0
self.score = sum(self._questions.values()) / len(self._questions)
self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "Accuracy(question-based)"