Skip to content

Commit

Permalink
增加data analysis prompt透出 (#216)
Browse files Browse the repository at this point in the history
* add db ref

* add reference

* add nl2sql prompt

* use one button

* Update create table with cache

* add fault tolerance to custom prompt

---------

Co-authored-by: 陆逊 <luxun.fy@alibaba-inc.com>
  • Loading branch information
aero-xi and moria97 authored Sep 13, 2024
1 parent 22e5fc8 commit 0d9eb8d
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 59 deletions.
17 changes: 8 additions & 9 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def _format_rag_response(
filename = doc["metadata"].get("file_name", None)
ref_table = doc["metadata"].get("query_tables", None)
invalid_flag = doc["metadata"].get("invalid_flag", 0)
ref_table = doc["metadata"].get("query_tables", None)
invalid_flag = doc["metadata"].get("invalid_flag", 0)
file_url = doc["metadata"].get("file_url", None)
media_url = doc.get("metadata", {}).get("image_url", None)
if media_url and doc["text"] == "":
Expand Down Expand Up @@ -150,12 +152,10 @@ def _format_rag_response(
run_flag = " ✓ "
ref_sql = doc["metadata"].get("query_code_instruction", None)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
# content = f"""{formatted_table_name} \n\n {formatted_sql_query}"""
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> """
"""\n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query} \n sql查询是否有效:</span>"""
f"""<span style="color:green; font-size: 14px;">{run_flag}"""
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">sql查询是否有效:</span> <span style="color:green; font-size: 14px;">{run_flag}</span>"""
)
else:
run_flag = " ✗ "
Expand All @@ -164,10 +164,9 @@ def _format_rag_response(
)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> """
"""\n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query} \n sql查询是否有效:</span>"""
f"""<span style="color:red; font-size: 14px;">{run_flag}"""
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query}</span> \n"""
f"""<span style="color:grey; font-size: 14px;">sql查询是否有效:</span> <span style="color:red; font-size: 14px;">{run_flag}</span>"""
)
else:
content = ""
Expand Down
4 changes: 0 additions & 4 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ def respond(input_elements: List[Any]):
for element, value in input_elements.items():
update_dict[element.elem_id] = value

if update_dict["retrieval_mode"] == "data_analysis":
update_dict["retrieval_mode"] = "hybrid"
update_dict["synthesizer_type"] = "SimpleSummarize"

# empty input.
if not update_dict["question"]:
yield update_dict["chatbot"]
Expand Down
133 changes: 103 additions & 30 deletions src/pai_rag/app/web/tabs/data_analysis_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import gradio as gr
import pandas as pd
from pai_rag.app.web.rag_client import rag_client, RagApiError
from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, DA_SQL_PROMPTS


DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")
Expand Down Expand Up @@ -43,19 +44,29 @@ def upload_file_fn(input_file):
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")


def connect_database(input_db: List[Any]):
def update_setting(input_db: List[Any]):
try:
update_dict = {"analysis_type": "nl2sql"}
for element, value in input_db.items():
update_dict[element.elem_id] = value
# print("db_config:", update_dict)

rag_client.patch_config(update_dict)
return f"[{datetime.datetime.now()}] Connect database success!"
return f"[{datetime.datetime.now()}] success!"
except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")


# def nl2sql_prompt_update(input_prompt):
# try:
# update_dict = {"db_nl2sql_prompt": input_prompt}
# # print('update_dict:', update_dict)
# rag_client.patch_config(update_dict)
# return f"[{datetime.datetime.now()}] success!"
# except RagApiError as api_error:
# raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")


def analysis_respond(question, chatbot):
response_gen = rag_client.query_data_analysis(question, stream=True)
content = ""
Expand Down Expand Up @@ -84,7 +95,7 @@ def create_data_analysis_tab() -> Dict[str, Any]:
"database",
],
value="datafile",
label="Please choose data analysis type",
label="Please choose the data analysis type",
elem_id="data_analysis_type",
)

Expand Down Expand Up @@ -115,38 +126,90 @@ def create_data_analysis_tab() -> Dict[str, Any]:

# database
with gr.Column(visible=(data_analysis_type.value == "database")) as db_col:
dialect = gr.Textbox(
label="Dialect", elem_id="db_dialect", value="mysql"
)
user = gr.Textbox(label="Username", elem_id="db_username")
password = gr.Textbox(
label="Password", elem_id="db_password", type="password"
)
host = gr.Textbox(label="Host", elem_id="db_host")
port = gr.Textbox(label="Port", elem_id="db_port", value=3306)
dbname = gr.Textbox(label="DBname", elem_id="db_name")
tables = gr.Textbox(
label="Tables",
elem_id="db_tables",
placeholder="List db tables, separated by commas, e.g. table_A, table_B, ... , using all tables if blank",
)
with gr.Row():
dialect = gr.Textbox(
label="Dialect", elem_id="db_dialect", value="mysql"
)
port = gr.Textbox(label="Port", elem_id="db_port", value=3306)
host = gr.Textbox(label="Host", elem_id="db_host")
with gr.Row():
user = gr.Textbox(label="Username", elem_id="db_username")
password = gr.Textbox(
label="Password", elem_id="db_password", type="password"
)
with gr.Row():
dbname = gr.Textbox(label="DBname", elem_id="db_name")
tables = gr.Textbox(
label="Tables",
elem_id="db_tables",
placeholder="List db tables, separated by commas, e.g. table_A, table_B, ... , using all tables if blank",
)
descriptions = gr.Textbox(
label="Descriptions",
lines=5,
lines=3,
elem_id="db_descriptions",
placeholder='A dict of table descriptions, e.g. {"table_A": "text_description_A", "table_B": "text_description_B"}',
)

connect_db_button = gr.Button(
"Connect Database",
elem_id="connect_db_button",
variant="primary",
) # 点击功能中更新analysis_type
prompt_type = gr.Radio(
[
"general",
"sql",
"custom",
],
value="general",
label="\N{rocket} Please choose the prompt template type",
elem_id="nl2sql_prompt_type",
)

connection_info = gr.Textbox(
label="Connection Info", elem_id="db_connection_info"
prompt_template = gr.Textbox(
label="prompt template",
elem_id="db_nl2sql_prompt",
value=DA_GENERAL_PROMPTS,
lines=4,
)

update_button = gr.Button(
"Update Settings",
elem_id="update_settings",
variant="primary",
) # 点击功能中更新 analysis_type, nl2sql参数以及prompt

update_info = gr.Textbox(label="Update Info", elem_id="update_info")

# with gr.Row():
# prompt_update_button = gr.Button(
# "prompt update",
# elem_id="prompt_update",
# variant="primary",
# ) # 点击功能中更新 nl2sql prompt

# update_info = gr.Textbox(
# label="Update Info", elem_id="prompt_update_info"
# )

def change_prompt_template(prompt_type):
    """Build the textbox update matching the selected prompt template type.

    ``"general"`` and ``"sql"`` load the corresponding built-in template and
    make the prompt textbox read-only. Any other value clears the textbox and
    makes it editable so that a custom prompt can be entered.

    Args:
        prompt_type: The value received from the prompt-type selector.

    Returns:
        A dict mapping the ``prompt_template`` component to its update.
    """
    builtin_templates = (
        ("general", DA_GENERAL_PROMPTS),
        ("sql", DA_SQL_PROMPTS),
    )
    for type_name, template_text in builtin_templates:
        if prompt_type == type_name:
            # Built-in templates are shown but cannot be edited.
            locked_update = gr.update(value=template_text, interactive=False)
            return {prompt_template: locked_update}

    # Not a built-in type: start from an empty, editable textbox.
    editable_update = gr.update(value="", interactive=True)
    return {prompt_template: editable_update}

prompt_type.input(
fn=change_prompt_template,
inputs=prompt_type,
outputs=[prompt_template],
)

inputs_db = {
dialect,
user,
Expand All @@ -156,15 +219,23 @@ def create_data_analysis_tab() -> Dict[str, Any]:
dbname,
tables,
descriptions,
prompt_template,
}

connect_db_button.click(
fn=connect_database,
update_button.click(
fn=update_setting,
inputs=inputs_db,
outputs=connection_info,
api_name="connect_db",
outputs=update_info,
api_name="update_info_clk",
)

# prompt_update_button.click(
# fn=nl2sql_prompt_update,
# inputs=prompt_template,
# outputs=update_info,
# api_name="update_nl2sql_prompt",
# )

def data_analysis_type_change(type_value):
if type_value == "datafile":
return {
Expand Down Expand Up @@ -234,4 +305,6 @@ def data_analysis_type_change(type_value):
dbname.elem_id: dbname,
tables.elem_id: tables,
descriptions.elem_id: descriptions,
prompt_type.elem_id: prompt_type,
prompt_template.elem_id: prompt_template,
}
3 changes: 3 additions & 0 deletions src/pai_rag/app/web/ui_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"

# Data-analysis (NL2SQL) prompt template for general questions: asks for a
# syntactically correct query in the given dialect that selects only the
# relevant columns. Contains the placeholders {dialect}, {schema} and
# {query_str}.
# NOTE(review): presumably the placeholders are filled in by the data-analysis
# backend at query time — confirm against the consumer of this template.
DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
# Data-analysis (NL2SQL) prompt template for questions that already contain a
# SQL statement: asks to extract that statement and validate/optimize it
# against the schema and dialect. Contains the placeholders {schema} (used
# twice), {dialect} and {query_str}.
DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "

PROMPT_MAP = {
SIMPLE_PROMPTS: "Simple",
GENERAL_PROMPTS: "General",
Expand Down
5 changes: 5 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class ViewModel(BaseModel):
db_name: str = None
db_tables: str = None
db_descriptions: str = None
db_nl2sql_prompt: str = None

# postprocessor
reranker_type: str = (
Expand Down Expand Up @@ -376,6 +377,8 @@ def from_app_config(config):
else:
view_model.db_descriptions = None

view_model.db_nl2sql_prompt = config["data_analysis"].get("nl2sql_prompt", None)

reranker_type = config["postprocessor"].get(
"reranker_type", "simple-weighted-reranker"
)
Expand Down Expand Up @@ -549,6 +552,7 @@ def to_app_config(self):
config["data_analysis"]["host"] = self.db_host
config["data_analysis"]["port"] = self.db_port
config["data_analysis"]["dbname"] = self.db_name
config["data_analysis"]["nl2sql_prompt"] = self.db_nl2sql_prompt
# string to list
if self.db_tables:
# 去掉首位空格和末尾逗号
Expand Down Expand Up @@ -815,5 +819,6 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["db_name"] = {"value": self.db_name}
settings["db_tables"] = {"value": self.db_tables}
settings["db_descriptions"] = {"value": self.db_descriptions}
settings["db_nl2sql_prompt"] = {"value": self.db_nl2sql_prompt}

return settings
1 change: 1 addition & 0 deletions src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ persist_path = "localdata/storage"

[rag.data_analysis]
analysis_type = "nl2pandas"
nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "

[rag.data_loader]
type = "local"
Expand Down
1 change: 1 addition & 0 deletions src/pai_rag/config/settings_multi_modal.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ persist_path = "localdata/storage"

[rag.data_analysis]
analysis_type = "nl2pandas"
nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "

[rag.data_loader]
type = "Local" # [Local, Oss]
Expand Down
Loading

0 comments on commit 0d9eb8d

Please sign in to comment.