From 9d796812bdddb7a8a3737a560dc54c7d0d28c73f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=86=E9=80=8A?= Date: Mon, 18 Nov 2024 11:16:36 +0800 Subject: [PATCH] Fix v1 api bug --- src/pai_rag/app/api/v1/chat.py | 4 ++-- src/pai_rag/core/rag_application.py | 14 +++++++++----- src/pai_rag/core/rag_service.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/pai_rag/app/api/v1/chat.py b/src/pai_rag/app/api/v1/chat.py index 9eaad722..dbe7ea2a 100644 --- a/src/pai_rag/app/api/v1/chat.py +++ b/src/pai_rag/app/api/v1/chat.py @@ -67,7 +67,7 @@ async def aquery_retrieval(query: RetrievalQuery): @router_v1.post("/query/agent") async def aquery_agent(query: RagQuery): - response = await rag_service.aquery_agent(query) + response = await rag_service.aquery_agent_v1(query) if not query.stream: return response else: @@ -264,7 +264,7 @@ async def upload_datasheet( @router_v1.post("/query/data_analysis") async def aquery_analysis(query: RagQuery): - response = await rag_service.aquery_analysis(query) + response = await rag_service.aquery_analysis_v1(query) if not query.stream: return response else: diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index ad3ce45b..56257bf0 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -197,7 +197,7 @@ async def aquery( intent = await intent_router.aselect(str_or_query_bundle=new_question) logger.info(f"[IntentDetection] Routing query to {intent}.") if intent == Intents.TOOL: - return await self.aquery_agent(query) + return await self.aquery_agent(query, sse_version=sse_version) elif intent == Intents.WEBSEARCH: chat_type = RagChatType.WEB elif intent == Intents.NL2SQL: @@ -260,7 +260,9 @@ async def aquery( sse_version=sse_version, ) - async def aquery_agent(self, query: RagQuery) -> RagResponse: + async def aquery_agent( + self, query: RagQuery, sse_version: SseVersion = SseVersion.V0 + ) -> RagResponse: """Query answer from RAG App via web search asynchronously. Generate answer from agent's achat interface. @@ -277,7 +279,7 @@ async def aquery_agent(self, query: RagQuery) -> RagResponse: agent = resolve_agent(self.config) if query.stream: response = await agent.astream_chat(query.question) - return event_generator_async(response) + return event_generator_async(response, sse_version=sse_version) else: response = await agent.achat(query.question) return RagResponse(answer=response.response) @@ -306,7 +308,9 @@ async def aload_agent_config(self, agent_cfg_path: str): else: return f"The agent config path {agent_cfg_path} not exists." - async def aquery_analysis(self, query: RagQuery): + async def aquery_analysis( + self, query: RagQuery, sse_version: SseVersion = SseVersion.V0 + ): """Query answer from RAG App asynchronously. Generate answer from Data Analysis interface. @@ -361,4 +365,4 @@ async def aquery_analysis(self, query: RagQuery): if not query.stream: return RagResponse(answer=response.response, **result_info) else: - return event_generator_async(response=response, extra_info=result_info) + return event_generator_async(response=response, sse_version=sse_version) diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py index 66f9ae4a..69a3d978 100644 --- a/src/pai_rag/core/rag_service.py +++ b/src/pai_rag/core/rag_service.py @@ -180,6 +180,13 @@ async def aquery_agent(self, query: RagQuery) -> RagResponse: logger.error(traceback.format_exc()) raise UserInputError(f"Query RAG Agent failed: {ex}") + async def aquery_agent_v1(self, query: RagQuery) -> RagResponse: + try: + return await self.rag.aquery_agent(query, sse_version=SseVersion.V1) + except Exception as ex: + logger.error(traceback.format_exc()) + raise UserInputError(f"Query RAG Agent failed: {ex}") + async def aload_agent_config(self, agent_cfg_path: str): try: return await self.rag.aload_agent_config(agent_cfg_path) @@ -194,5 +201,12 @@ async def aquery_analysis(self, query: RagQuery): logger.error(traceback.format_exc()) raise UserInputError(f"Query Analysis failed: {ex}") + async def aquery_analysis_v1(self, query: RagQuery): + try: + return await self.rag.aquery_analysis(query, sse_version=SseVersion.V1) + except Exception as ex: + logger.error(traceback.format_exc()) + raise UserInputError(f"Query Analysis failed: {ex}") + rag_service = RagService()