Skip to content

Commit

Permalink
Update ui&reference (#221)
Browse files Browse the repository at this point in the history
* update ui button and reference

* update reference display & table match logic

* delete button
  • Loading branch information
aero-xi authored Sep 14, 2024
1 parent 99827df commit b41b43e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 88 deletions.
8 changes: 4 additions & 4 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,21 @@ def _format_rag_response(
if invalid_flag == 0:
run_flag = " ✓ "
ref_sql = doc["metadata"].get("query_code_instruction", None)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
formatted_sql_query = f"<b>{ref_sql}</b>"
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">生成的sql语句为:</span> <pre style="color:grey; font-size: 12px;">{formatted_sql_query}</pre> """
f"""<span style="color:grey; font-size: 14px;">sql查询是否有效:</span> <span style="color:green; font-size: 14px;">{run_flag}</span>"""
)
else:
run_flag = " ✗ "
ref_sql = doc["metadata"].get(
"generated_query_code_instruction", None
)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
formatted_sql_query = f"<b>{ref_sql}</b>"
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">生成的sql语句为:</span> <pre style="color:grey; font-size: 12px;">{formatted_sql_query}</pre> """
f"""<span style="color:grey; font-size: 14px;">sql查询是否有效:</span> <span style="color:red; font-size: 14px;">{run_flag}</span>"""
)
else:
Expand Down
125 changes: 51 additions & 74 deletions src/pai_rag/app/web/tabs/data_analysis_tab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import datetime
from typing import Dict, Any, List
import gradio as gr
import pandas as pd
Expand Down Expand Up @@ -44,35 +43,43 @@ def upload_file_fn(input_file):
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")


def update_setting(input_db: List[Any]):
try:
update_dict = {"analysis_type": "nl2sql"}
for element, value in input_db.items():
update_dict[element.elem_id] = value
# print("db_config:", update_dict)
def respond(input_elements: List[Any]):
update_dict = {}
for element, value in input_elements.items():
update_dict[element.elem_id] = value

if update_dict["analysis_type"] == "datafile":
update_dict["analysis_type"] = "nl2pandas"
else:
update_dict["analysis_type"] = "nl2sql"

# empty input.
if not update_dict["question"]:
yield update_dict["chatbot"]
return

# update snapshot
try:
rag_client.patch_config(update_dict)
return f"[{datetime.datetime.now()}] success!"
except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")

question = update_dict["question"]
chatbot = update_dict["chatbot"]

# def nl2sql_prompt_update(input_prompt):
# try:
# update_dict = {"db_nl2sql_prompt": input_prompt}
# # print('update_dict:', update_dict)
# rag_client.patch_config(update_dict)
# return f"[{datetime.datetime.now()}] success!"
# except RagApiError as api_error:
# raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")

if chatbot is not None:
chatbot.append((question, ""))

def analysis_respond(question, chatbot):
response_gen = rag_client.query_data_analysis(question, stream=True)
content = ""
chatbot.append((question, content))
for resp in response_gen:
chatbot[-1] = (question, resp.result)
try:
response_gen = rag_client.query_data_analysis(question, stream=True)
for resp in response_gen:
chatbot[-1] = (question, resp.result)
yield chatbot
except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
except Exception as e:
raise gr.Error(f"Error: {e}")
finally:
yield chatbot


Expand All @@ -96,7 +103,7 @@ def create_data_analysis_tab() -> Dict[str, Any]:
],
value="datafile",
label="Please choose the data analysis type",
elem_id="data_analysis_type",
elem_id="analysis_type",
)

# datafile
Expand Down Expand Up @@ -169,25 +176,6 @@ def create_data_analysis_tab() -> Dict[str, Any]:
lines=4,
)

update_button = gr.Button(
"Update Settings",
elem_id="update_settings",
variant="primary",
) # 点击功能中更新 analysis_type, nl2sql参数以及prompt

update_info = gr.Textbox(label="Update Info", elem_id="update_info")

# with gr.Row():
# prompt_update_button = gr.Button(
# "prompt update",
# elem_id="prompt_update",
# variant="primary",
# ) # 点击功能中更新 nl2sql prompt

# update_info = gr.Textbox(
# label="Update Info", elem_id="prompt_update_info"
# )

def change_prompt_template(prompt_type):
if prompt_type == "general":
return {
Expand All @@ -210,32 +198,6 @@ def change_prompt_template(prompt_type):
outputs=[prompt_template],
)

inputs_db = {
dialect,
user,
password,
host,
port,
dbname,
tables,
descriptions,
prompt_template,
}

update_button.click(
fn=update_setting,
inputs=inputs_db,
outputs=update_info,
api_name="update_info_clk",
)

# prompt_update_button.click(
# fn=nl2sql_prompt_update,
# inputs=prompt_template,
# outputs=update_info,
# api_name="update_nl2sql_prompt",
# )

def data_analysis_type_change(type_value):
if type_value == "datafile":
return {
Expand All @@ -255,23 +217,38 @@ def data_analysis_type_change(type_value):
)

with gr.Column(scale=6):
chatbot = gr.Chatbot(height=500, elem_id="data_analysis_chatbot")
chatbot = gr.Chatbot(height=500, elem_id="chatbot")
question = gr.Textbox(label="Enter your question.", elem_id="question")
with gr.Row():
submitBtn = gr.Button("Submit", variant="primary")
clearBtn = gr.Button("Clear History", variant="secondary")

chat_args = {
data_analysis_type,
dialect,
user,
password,
host,
port,
dbname,
tables,
descriptions,
prompt_template,
question,
chatbot,
}

submitBtn.click(
fn=analysis_respond,
inputs=[question, chatbot],
fn=respond,
inputs=chat_args,
outputs=[chatbot],
api_name="analysis_respond_clk",
)

# 绑定Textbox提交事件,当按下Enter,调用respond函数
question.submit(
analysis_respond,
inputs=[question, chatbot],
respond,
inputs=chat_args,
outputs=[chatbot],
api_name="analysis_respond_q",
)
Expand Down
26 changes: 16 additions & 10 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def retrieve_with_metadata(
retrieved_nodes = [NodeWithScore(node=sql_only_node, score=1.0)]
metadata = {"result": sql_query_str}
else:
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
try:
(
retrieved_nodes,
Expand All @@ -384,7 +385,9 @@ def retrieve_with_metadata(
if self._handle_sql_errors:
logger.info(f"async error info: {e}\n")

new_sql_query_str = self._sql_query_modification(sql_query_str)
new_sql_query_str = self._sql_query_modification(
query_tables, sql_query_str
)

# 如果找到table,生成新的sql_query
if new_sql_query_str != sql_query_str:
Expand Down Expand Up @@ -419,7 +422,6 @@ def retrieve_with_metadata(
metadata = {}

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
Expand Down Expand Up @@ -453,6 +455,7 @@ async def aretrieve_with_metadata(
retrieved_nodes = [NodeWithScore(node=sql_only_node, score=1.0)]
metadata: Dict[str, Any] = {}
else:
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
try:
(
retrieved_nodes,
Expand All @@ -479,7 +482,9 @@ async def aretrieve_with_metadata(
if self._handle_sql_errors:
logger.info(f"async error info: {e}\n")

new_sql_query_str = self._sql_query_modification(sql_query_str)
new_sql_query_str = self._sql_query_modification(
query_tables, sql_query_str
)

# 如果找到table,生成新的sql_query
if new_sql_query_str != sql_query_str:
Expand Down Expand Up @@ -516,7 +521,6 @@ async def aretrieve_with_metadata(
metadata = {}

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
Expand All @@ -528,13 +532,15 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection.append(table)
return table_collection

def _sql_query_modification(self, sql_query_str):
table_pattern = r"FROM\s+(\w+)"
match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL)
if match:
first_table = match.group(1)
def _sql_query_modification(self, query_tables: list, sql_query_str: str):
# table_pattern = r"FROM\s+(\w+)"
# match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL)
# if match:
# 改用已知table匹配,否则match中FROM逻辑也可能匹配到无效的table
if len(query_tables) != 0:
first_table = query_tables[0]
new_sql_query_str = f"SELECT * FROM {first_table}"
logger.info(f"use the whole table {first_table} instead if possible")
logger.info(f"use the whole table named {first_table} instead if possible")
else:
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
Expand Down

0 comments on commit b41b43e

Please sign in to comment.