From 947503ab59da0b0ddb96b0d0ccaf551adf95b3f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=86=E9=80=8A?= Date: Fri, 6 Sep 2024 18:53:15 +0800 Subject: [PATCH] Fix session_id bug --- src/pai_rag/app/web/rag_client.py | 16 ++++------------ src/pai_rag/app/web/tabs/agent_tab.py | 3 +-- src/pai_rag/app/web/tabs/data_analysis_tab.py | 3 +-- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 17ac7052..6bd43753 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -198,18 +198,14 @@ def query_search( raise RagApiError(code=r.status_code, msg=r.text) if not stream: response = dotdict(json.loads(r.text)) - yield self._format_rag_response( - text, response, session_id=session_id, stream=stream - ) + yield self._format_rag_response(text, response, stream=stream) else: full_content = "" for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True): chunk_response = dotdict(json.loads(chunk)) full_content += chunk_response.delta chunk_response.delta = full_content - yield self._format_rag_response( - text, chunk_response, session_id=session_id, stream=stream - ) + yield self._format_rag_response(text, chunk_response, stream=stream) def query_data_analysis( self, @@ -228,18 +224,14 @@ def query_data_analysis( raise RagApiError(code=r.status_code, msg=r.text) if not stream: response = dotdict(json.loads(r.text)) - yield self._format_rag_response( - text, response, session_id=session_id, stream=stream - ) + yield self._format_rag_response(text, response, stream=stream) else: full_content = "" for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True): chunk_response = dotdict(json.loads(chunk)) full_content += chunk_response.delta chunk_response.delta = full_content - yield self._format_rag_response( - text, chunk_response, session_id=session_id, stream=stream - ) + yield self._format_rag_response(text, chunk_response, stream=stream) def query_llm( self, diff --git a/src/pai_rag/app/web/tabs/agent_tab.py b/src/pai_rag/app/web/tabs/agent_tab.py index 99707d2f..75fe2fb9 100644 --- a/src/pai_rag/app/web/tabs/agent_tab.py +++ b/src/pai_rag/app/web/tabs/agent_tab.py @@ -27,9 +27,8 @@ def respond(agent_question, agent_chatbot): def clear_history(chatbot): + rag_client.clear_history() chatbot = [] - global current_session_id - current_session_id = None return chatbot diff --git a/src/pai_rag/app/web/tabs/data_analysis_tab.py b/src/pai_rag/app/web/tabs/data_analysis_tab.py index 447462b0..b972ab22 100644 --- a/src/pai_rag/app/web/tabs/data_analysis_tab.py +++ b/src/pai_rag/app/web/tabs/data_analysis_tab.py @@ -66,9 +66,8 @@ def analysis_respond(question, chatbot): def clear_history(chatbot): + rag_client.clear_history() chatbot = [] - global current_session_id - current_session_id = None return chatbot