From 521cf6407461d00ade4ab5cbe1b014a9b2b7ad0b Mon Sep 17 00:00:00 2001
From: wwxxzz
Date: Fri, 13 Sep 2024 14:04:21 +0800
Subject: [PATCH] Add eval tab (#214)

---
 src/pai_rag/app/web/rag_client.py             |  2 +-
 src/pai_rag/app/web/tabs/eval_tab.py          |  1 -
 src/pai_rag/app/web/webui.py                  |  7 ++++---
 src/pai_rag/config/settings.toml              |  4 ++++
 .../modules/evaluation/batch_eval_runner.py   | 15 ++++++---------
 5 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 989a4189..8aa3c2ed 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -468,7 +468,7 @@ def evaluate_for_response_stage(self):
         response = dotdict(json.loads(r.text))
         if r.status_code != HTTPStatus.OK:
             raise RagApiError(code=r.status_code, msg=response.message)
-        print("evaluate_for_response_stage response", response)
+        return response
 
 
 rag_client = RagWebClient()
diff --git a/src/pai_rag/app/web/tabs/eval_tab.py b/src/pai_rag/app/web/tabs/eval_tab.py
index 6b39e4d7..8b9d0c87 100644
--- a/src/pai_rag/app/web/tabs/eval_tab.py
+++ b/src/pai_rag/app/web/tabs/eval_tab.py
@@ -44,7 +44,6 @@ def eval_response_stage():
         response_res = rag_client.evaluate_for_response_stage()
     except RagApiError as api_error:
         raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
-
     formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     pd_results = {
         "Metrics": ["Faithfulness", "Correctness", "Similarity", "LastModified"],
diff --git a/src/pai_rag/app/web/webui.py b/src/pai_rag/app/web/webui.py
index 748001bb..cf08699a 100644
--- a/src/pai_rag/app/web/webui.py
+++ b/src/pai_rag/app/web/webui.py
@@ -8,6 +8,7 @@
 from pai_rag.app.web.tabs.chat_tab import create_chat_tab
 from pai_rag.app.web.tabs.agent_tab import create_agent_tab
 from pai_rag.app.web.tabs.data_analysis_tab import create_data_analysis_tab
+from pai_rag.app.web.tabs.eval_tab import create_evaluation_tab
 from pai_rag.app.web.element_manager import elem_manager
 from pai_rag.app.web.ui_constants import (
     DEFAULT_CSS_STYPE,
@@ -56,15 +57,15 @@ def make_homepage():
         with gr.Tab("\N{fire} Chat"):
             chat_elements = create_chat_tab()
             elem_manager.add_elems(chat_elements)
-        # with gr.Tab("\N{rocket} Evaluation"):
-        #     eval_elements = create_evaluation_tab()
-        #     elem_manager.add_elems(eval_elements)
         with gr.Tab("\N{rocket} Agent"):
             agent_elements = create_agent_tab()
             elem_manager.add_elems(agent_elements)
         with gr.Tab("\N{bar chart} Data Analysis"):
             analysis_elements = create_data_analysis_tab()
             elem_manager.add_elems(analysis_elements)
+        with gr.Tab("\N{rocket} Evaluation"):
+            eval_elements = create_evaluation_tab()
+            elem_manager.add_elems(eval_elements)
         homepage.load(
             resume_ui, outputs=elem_manager.get_elem_list(), concurrency_limit=None
         )
diff --git a/src/pai_rag/config/settings.toml b/src/pai_rag/config/settings.toml
index 18ae99aa..cc7314b1 100644
--- a/src/pai_rag/config/settings.toml
+++ b/src/pai_rag/config/settings.toml
@@ -69,6 +69,10 @@ name = "qwen-turbo"
 
 [rag.llm.function_calling_llm]
 source = ""
 
+[rag.llm.multi_modal]
+enable = false
+source = ""
+
 [rag.llm_chat_engine]
 type = "SimpleChatEngine"
diff --git a/src/pai_rag/modules/evaluation/batch_eval_runner.py b/src/pai_rag/modules/evaluation/batch_eval_runner.py
index 4e9aeeb6..0aa26d79 100644
--- a/src/pai_rag/modules/evaluation/batch_eval_runner.py
+++ b/src/pai_rag/modules/evaluation/batch_eval_runner.py
@@ -7,7 +7,6 @@
 from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
 from llama_index.core.evaluation.base import BaseEvaluator, EvaluationResult
 from pai_rag.integrations.evaluation.retrieval.evaluator import MyRetrievalEvalResult
-from fastapi.concurrency import run_in_threadpool
 
 
 async def eval_response_worker(
@@ -277,14 +276,12 @@ async def aevaluate_queries(
             response_jobs.append(response_worker(self.semaphore, query_engine, query))
         responses = await self.asyncio_mod.gather(*response_jobs)
 
-        return await run_in_threadpool(
-            lambda: self.aevaluate_responses(
-                queries=queries,
-                node_ids=node_ids,
-                responses=responses,
-                references=reference_answers,
-                **eval_kwargs_lists,
-            )
+        return await self.aevaluate_responses(
+            queries=queries,
+            node_ids=node_ids,
+            responses=responses,
+            references=reference_answers,
+            **eval_kwargs_lists,
         )
 
     async def aevaluate_queries_for_retrieval(
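
Reviewer note (not part of the patch): the batch_eval_runner.py change fixes a real bug, not just style. `run_in_threadpool` executes its callable in a worker thread and returns whatever the callable returns; because `aevaluate_responses` is an `async def`, the old lambda only *created* a coroutine object that was never awaited, so callers got a raw coroutine instead of evaluation results. A minimal self-contained sketch of the two patterns (the `aevaluate_responses` stand-in below is hypothetical, not the repo's implementation):

```python
import asyncio

from fastapi.concurrency import run_in_threadpool


async def aevaluate_responses(queries):
    # Hypothetical stand-in for BatchEvalRunner.aevaluate_responses.
    await asyncio.sleep(0)
    return {q: "ok" for q in queries}


async def old_pattern(queries):
    # Pre-patch shape: the lambda runs in a worker thread, but it only
    # builds a coroutine object; nothing awaits it, so the caller receives
    # the coroutine itself rather than the evaluation results.
    return await run_in_threadpool(lambda: aevaluate_responses(queries))


async def new_pattern(queries):
    # Patched shape: await the coroutine directly on the event loop.
    return await aevaluate_responses(queries)


print(asyncio.run(new_pattern(["q1"])))  # {'q1': 'ok'}
print(asyncio.run(old_pattern(["q1"])))  # <coroutine ...> plus a "never awaited" RuntimeWarning
```

Dropping the wrapper is also what makes the `fastapi.concurrency` import unused, hence its deletion in the same file.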
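
On the rag_client.py change: `evaluate_for_response_stage` previously printed the parsed response and returned nothing; returning it is what lets the new eval tab populate its metrics table. A hedged sketch of how the tab-side code can consume the returned value, mirroring the `pd_results` structure visible in the eval_tab.py hunk (the response key names below are illustrative assumptions, not taken from the patch):

```python
from datetime import datetime

import pandas as pd


def format_eval_results(response: dict) -> pd.DataFrame:
    # Same Metrics column as eval_tab.py; the .get() keys are assumed names
    # for the scores returned by the evaluation API.
    formatted_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return pd.DataFrame(
        {
            "Metrics": ["Faithfulness", "Correctness", "Similarity", "LastModified"],
            "Value": [
                response.get("faithfulness"),
                response.get("correctness"),
                response.get("similarity"),
                formatted_time,
            ],
        }
    )
```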
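
On the settings.toml addition: the new `[rag.llm.multi_modal]` block follows the same Dynaconf-style layout as the surrounding sections. A small sketch of reading it, assuming a direct Dynaconf load of the settings file (pai_rag resolves its configuration internally; the path here is an assumption):

```python
from dynaconf import Dynaconf

# Assumed settings file path for illustration only.
settings = Dynaconf(settings_files=["src/pai_rag/config/settings.toml"])

mm = settings.rag.llm.multi_modal
if mm.enable:
    print("multi-modal LLM enabled, source:", mm.source or "<unset>")
else:
    print("multi-modal LLM disabled")  # default, since enable = false
```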