diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 85a107f0..44a93fe9 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -151,10 +151,10 @@ def _format_rag_response( if invalid_flag == 0: run_flag = " ✓ " ref_sql = doc["metadata"].get("query_code_instruction", None) - formatted_sql_query = f"生成的sql语句为:{ref_sql}" + formatted_sql_query = f"{ref_sql}" content = ( f"""{formatted_table_name} \n""" - f"""{formatted_sql_query} \n""" + f"""生成的sql语句为:
{formatted_sql_query}""" f"""sql查询是否有效: {run_flag}""" ) else: @@ -162,10 +162,10 @@ def _format_rag_response( ref_sql = doc["metadata"].get( "generated_query_code_instruction", None ) - formatted_sql_query = f"生成的sql语句为:{ref_sql}" + formatted_sql_query = f"{ref_sql}" content = ( f"""{formatted_table_name} \n""" - f"""{formatted_sql_query} \n""" + f"""生成的sql语句为:
{formatted_sql_query}""" f"""sql查询是否有效: {run_flag}""" ) else: diff --git a/src/pai_rag/app/web/tabs/data_analysis_tab.py b/src/pai_rag/app/web/tabs/data_analysis_tab.py index 825f1ded..5b28d0c7 100644 --- a/src/pai_rag/app/web/tabs/data_analysis_tab.py +++ b/src/pai_rag/app/web/tabs/data_analysis_tab.py @@ -1,5 +1,4 @@ import os -import datetime from typing import Dict, Any, List import gradio as gr import pandas as pd @@ -44,35 +43,43 @@ def upload_file_fn(input_file): raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") -def update_setting(input_db: List[Any]): - try: - update_dict = {"analysis_type": "nl2sql"} - for element, value in input_db.items(): - update_dict[element.elem_id] = value - # print("db_config:", update_dict) +def respond(input_elements: List[Any]): + update_dict = {} + for element, value in input_elements.items(): + update_dict[element.elem_id] = value + + if update_dict["analysis_type"] == "datafile": + update_dict["analysis_type"] = "nl2pandas" + else: + update_dict["analysis_type"] = "nl2sql" + + # empty input. + if not update_dict["question"]: + yield update_dict["chatbot"] + return + # update snapshot + try: rag_client.patch_config(update_dict) - return f"[{datetime.datetime.now()}] success!" except RagApiError as api_error: raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") + question = update_dict["question"] + chatbot = update_dict["chatbot"] -# def nl2sql_prompt_update(input_prompt): -# try: -# update_dict = {"db_nl2sql_prompt": input_prompt} -# # print('update_dict:', update_dict) -# rag_client.patch_config(update_dict) -# return f"[{datetime.datetime.now()}] success!" -# except RagApiError as api_error: -# raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") - + if chatbot is not None: + chatbot.append((question, "")) -def analysis_respond(question, chatbot): - response_gen = rag_client.query_data_analysis(question, stream=True) - content = "" - chatbot.append((question, content)) - for resp in response_gen: - chatbot[-1] = (question, resp.result) + try: + response_gen = rag_client.query_data_analysis(question, stream=True) + for resp in response_gen: + chatbot[-1] = (question, resp.result) + yield chatbot + except RagApiError as api_error: + raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") + except Exception as e: + raise gr.Error(f"Error: {e}") + finally: yield chatbot @@ -96,7 +103,7 @@ def create_data_analysis_tab() -> Dict[str, Any]: ], value="datafile", label="Please choose the data analysis type", - elem_id="data_analysis_type", + elem_id="analysis_type", ) # datafile @@ -169,25 +176,6 @@ def create_data_analysis_tab() -> Dict[str, Any]: lines=4, ) - update_button = gr.Button( - "Update Settings", - elem_id="update_settings", - variant="primary", - ) # 点击功能中更新 analysis_type, nl2sql参数以及prompt - - update_info = gr.Textbox(label="Update Info", elem_id="update_info") - - # with gr.Row(): - # prompt_update_button = gr.Button( - # "prompt update", - # elem_id="prompt_update", - # variant="primary", - # ) # 点击功能中更新 nl2sql prompt - - # update_info = gr.Textbox( - # label="Update Info", elem_id="prompt_update_info" - # ) - def change_prompt_template(prompt_type): if prompt_type == "general": return { @@ -210,32 +198,6 @@ def change_prompt_template(prompt_type): outputs=[prompt_template], ) - inputs_db = { - dialect, - user, - password, - host, - port, - dbname, - tables, - descriptions, - prompt_template, - } - - update_button.click( - fn=update_setting, - inputs=inputs_db, - outputs=update_info, - api_name="update_info_clk", - ) - - # prompt_update_button.click( - # fn=nl2sql_prompt_update, - # inputs=prompt_template, - # outputs=update_info, - # api_name="update_nl2sql_prompt", - # ) - def data_analysis_type_change(type_value): if type_value == "datafile": return { @@ -255,23 +217,38 @@ def data_analysis_type_change(type_value): ) with gr.Column(scale=6): - chatbot = gr.Chatbot(height=500, elem_id="data_analysis_chatbot") + chatbot = gr.Chatbot(height=500, elem_id="chatbot") question = gr.Textbox(label="Enter your question.", elem_id="question") with gr.Row(): submitBtn = gr.Button("Submit", variant="primary") clearBtn = gr.Button("Clear History", variant="secondary") + chat_args = { + data_analysis_type, + dialect, + user, + password, + host, + port, + dbname, + tables, + descriptions, + prompt_template, + question, + chatbot, + } + submitBtn.click( - fn=analysis_respond, - inputs=[question, chatbot], + fn=respond, + inputs=chat_args, outputs=[chatbot], api_name="analysis_respond_clk", ) # 绑定Textbox提交事件,当按下Enter,调用respond函数 question.submit( - analysis_respond, - inputs=[question, chatbot], + respond, + inputs=chat_args, outputs=[chatbot], api_name="analysis_respond_q", ) diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py index 883058bb..b984d7de 100644 --- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py +++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py @@ -360,6 +360,7 @@ def retrieve_with_metadata( retrieved_nodes = [NodeWithScore(node=sql_only_node, score=1.0)] metadata = {"result": sql_query_str} else: + query_tables = self._get_table_from_sql(self._tables, sql_query_str) try: ( retrieved_nodes, @@ -384,7 +385,9 @@ def retrieve_with_metadata( if self._handle_sql_errors: logger.info(f"async error info: {e}\n") - new_sql_query_str = self._sql_query_modification(sql_query_str) + new_sql_query_str = self._sql_query_modification( + query_tables, sql_query_str + ) # 如果找到table,生成新的sql_query if new_sql_query_str != sql_query_str: @@ -419,7 +422,6 @@ def retrieve_with_metadata( metadata = {} # add query_tables into metadata - query_tables = self._get_table_from_sql(self._tables, sql_query_str) retrieved_nodes[0].metadata["query_tables"] = query_tables return retrieved_nodes, {"sql_query": sql_query_str, **metadata} @@ -453,6 +455,7 @@ async def aretrieve_with_metadata( retrieved_nodes = [NodeWithScore(node=sql_only_node, score=1.0)] metadata: Dict[str, Any] = {} else: + query_tables = self._get_table_from_sql(self._tables, sql_query_str) try: ( retrieved_nodes, @@ -479,7 +482,9 @@ async def aretrieve_with_metadata( if self._handle_sql_errors: logger.info(f"async error info: {e}\n") - new_sql_query_str = self._sql_query_modification(sql_query_str) + new_sql_query_str = self._sql_query_modification( + query_tables, sql_query_str + ) # 如果找到table,生成新的sql_query if new_sql_query_str != sql_query_str: @@ -516,7 +521,6 @@ async def aretrieve_with_metadata( metadata = {} # add query_tables into metadata - query_tables = self._get_table_from_sql(self._tables, sql_query_str) retrieved_nodes[0].metadata["query_tables"] = query_tables return retrieved_nodes, {"sql_query": sql_query_str, **metadata} @@ -528,13 +532,15 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: table_collection.append(table) return table_collection - def _sql_query_modification(self, sql_query_str): - table_pattern = r"FROM\s+(\w+)" - match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL) - if match: - first_table = match.group(1) + def _sql_query_modification(self, query_tables: list, sql_query_str: str): + # table_pattern = r"FROM\s+(\w+)" + # match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL) + # if match: + # 改用已知table匹配,否则match中FROM逻辑也可能匹配到无效的table + if len(query_tables) != 0: + first_table = query_tables[0] new_sql_query_str = f"SELECT * FROM {first_table}" - logger.info(f"use the whole table {first_table} instead if possible") + logger.info(f"use the whole table named {first_table} instead if possible") else: # raise ValueError("No table is matched") new_sql_query_str = sql_query_str