Skip to content

Commit

Permalink
Personal/xi/nl2sql op1 (#254)
Browse files Browse the repository at this point in the history
* add syn prompt

* add data_sample

* update parse by lstrip

* update test
  • Loading branch information
aero-xi authored Oct 23, 2024
1 parent b4464f7 commit 68842d4
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 38 deletions.
90 changes: 66 additions & 24 deletions src/pai_rag/app/web/tabs/data_analysis_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import gradio as gr
import pandas as pd
from pai_rag.app.web.rag_client import rag_client, RagApiError
from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, DA_SQL_PROMPTS
from pai_rag.app.web.ui_constants import (
DA_GENERAL_PROMPTS,
DA_SQL_PROMPTS,
SYN_GENERAL_PROMPTS,
)


DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")
Expand Down Expand Up @@ -157,47 +161,82 @@ def create_data_analysis_tab() -> Dict[str, Any]:
elem_id="db_descriptions",
placeholder='A dict of table descriptions, e.g. {"table_A": "text_description_A", "table_B": "text_description_B"}',
)
with gr.Column(visible=True):
with gr.Tab("Nl2sql Prompt"):
sql_prompt_type = gr.Radio(
[
"general",
"sql",
"custom",
],
value="general",
label="\N{rocket} Please choose the nl2sql prompt template",
elem_id="nl2sql_prompt_type",
)

prompt_type = gr.Radio(
[
"general",
"sql",
"custom",
],
value="general",
label="\N{rocket} Please choose the prompt template type",
elem_id="nl2sql_prompt_type",
)
db_nl2sql_prompt = gr.Textbox(
label="nl2sql template",
elem_id="db_nl2sql_prompt",
value=DA_GENERAL_PROMPTS,
lines=4,
)

db_nl2sql_prompt = gr.Textbox(
label="Prompt template",
elem_id="db_nl2sql_prompt",
value=DA_GENERAL_PROMPTS,
lines=4,
)
with gr.Tab("Synthesize Prompt"):
syn_prompt_type = gr.Radio(
[
"general",
"custom",
],
value="general",
label="\N{rocket} Please choose the synthesizer prompt template",
elem_id="synthesizer_prompt_type",
)

synthesizer_prompt = gr.Textbox(
label="synthesizer template",
elem_id="synthesizer_prompt",
value=SYN_GENERAL_PROMPTS,
lines=4,
)

def change_prompt_template(prompt_type):
def change_sql_prompt_template(prompt_type):
if prompt_type == "general":
return {
db_nl2sql_prompt: gr.update(
value=DA_GENERAL_PROMPTS, interactive=False
value=DA_GENERAL_PROMPTS, interactive=True
)
}
elif prompt_type == "sql":
return {
db_nl2sql_prompt: gr.update(
value=DA_SQL_PROMPTS, interactive=False
value=DA_SQL_PROMPTS, interactive=True
)
}
else:
return {db_nl2sql_prompt: gr.update(value="", interactive=True)}

prompt_type.input(
fn=change_prompt_template,
inputs=prompt_type,
sql_prompt_type.input(
fn=change_sql_prompt_template,
inputs=sql_prompt_type,
outputs=[db_nl2sql_prompt],
)

def change_syn_prompt_template(prompt_type):
    """Build the Gradio update for the synthesizer prompt textbox.

    Selecting "general" loads the built-in synthesizer template; any other
    selection empties the textbox. The textbox is left editable in both cases.
    """
    template_text = SYN_GENERAL_PROMPTS if prompt_type == "general" else ""
    return {synthesizer_prompt: gr.update(value=template_text, interactive=True)}

syn_prompt_type.input(
fn=change_syn_prompt_template,
inputs=syn_prompt_type,
outputs=[synthesizer_prompt],
)

def data_analysis_type_change(type_value):
if type_value == "datafile":
return {
Expand Down Expand Up @@ -234,6 +273,7 @@ def data_analysis_type_change(type_value):
tables,
descriptions,
db_nl2sql_prompt,
synthesizer_prompt,
question,
chatbot,
}
Expand Down Expand Up @@ -282,6 +322,8 @@ def data_analysis_type_change(type_value):
database.elem_id: database,
tables.elem_id: tables,
descriptions.elem_id: descriptions,
prompt_type.elem_id: prompt_type,
sql_prompt_type.elem_id: sql_prompt_type,
db_nl2sql_prompt.elem_id: db_nl2sql_prompt,
syn_prompt_type.elem_id: syn_prompt_type,
synthesizer_prompt.elem_id: synthesizer_prompt,
}
1 change: 1 addition & 0 deletions src/pai_rag/app/web/ui_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
SYN_GENERAL_PROMPTS = "给定一个输入问题,根据查询代码指令以及查询结果生成最终回复,生成的回复语言需要与输入问题的语言保持一致。\n\n输入问题: {query_str} \n\nSQL 或 Python 查询代码指令(可选): {query_code_instruction}\n\n 查询结果: {query_output}\n\n 最终回复: "


# WELCOME_MESSAGE = """
Expand Down
4 changes: 4 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class ViewModel(BaseModel):
db_tables: str = None
db_descriptions: str = None
db_nl2sql_prompt: str = None
synthesizer_prompt: str = None

# postprocessor
reranker_type: str = "no-reranker" # no-reranker / model-based-reranker
Expand Down Expand Up @@ -214,6 +215,7 @@ def from_app_config(config: RagConfig):
config.data_analysis.descriptions, ensure_ascii=False
)
view_model.db_nl2sql_prompt = config.data_analysis.nl2sql_prompt
view_model.synthesizer_prompt = config.data_analysis.synthesizer_prompt

return view_model

Expand Down Expand Up @@ -286,6 +288,7 @@ def to_app_config(self):
config["data_analysis"]["port"] = self.db_port
config["data_analysis"]["database"] = self.database
config["data_analysis"]["nl2sql_prompt"] = self.db_nl2sql_prompt
config["data_analysis"]["synthesizer_prompt"] = self.synthesizer_prompt

# string to list
if self.db_tables:
Expand Down Expand Up @@ -499,6 +502,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["db_tables"] = {"value": self.db_tables}
settings["db_descriptions"] = {"value": self.db_descriptions}
settings["db_nl2sql_prompt"] = {"value": self.db_nl2sql_prompt}
settings["synthesizer_prompt"] = {"value": self.synthesizer_prompt}

print(settings)
return settings
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, List, Literal
from pydantic import BaseModel

from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS
from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, SYN_GENERAL_PROMPTS


class DataAnalysisType(str, Enum):
Expand All @@ -18,6 +18,7 @@ class BaseAnalysisConfig(BaseModel):

type: DataAnalysisType
nl2sql_prompt: str = DA_GENERAL_PROMPTS
synthesizer_prompt: str = SYN_GENERAL_PROMPTS


class PandasAnalysisConfig(BaseAnalysisConfig):
Expand All @@ -41,4 +42,4 @@ class MysqlAnalysisConfig(SqlAnalysisConfig):
user: str
password: str
host: str
port: str = 3306
port: int
3 changes: 2 additions & 1 deletion src/pai_rag/integrations/data_analysis/data_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def __init__(
)
self._synthesizer = DataAnalysisSynthesizer(
llm=self._llm,
response_synthesis_prompt=DEFAULT_RESPONSE_SYNTHESIS_PROMPT,
response_synthesis_prompt=PromptTemplate(analysis_config.synthesizer_prompt)
or DEFAULT_RESPONSE_SYNTHESIS_PROMPT,
)
super().__init__(callback_manager=callback_manager)

Expand Down
39 changes: 28 additions & 11 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip("```").strip().strip(";").strip()
return response.strip().strip("```").strip().strip(";").strip().lstrip("sql")


def get_sql_info(sql_config: SqlAnalysisConfig):
Expand Down Expand Up @@ -285,8 +285,8 @@ def inspect_db_connection(
engine = create_engine(
database_url,
echo=False,
pool_size=5,
max_overflow=10,
pool_size=10,
max_overflow=20,
pool_timeout=30,
pool_recycle=360,
poolclass=QueuePool,
Expand Down Expand Up @@ -339,6 +339,7 @@ class MyNLSQLRetriever(BaseRetriever, PromptMixin):
def __init__(
self,
sql_database: SQLDatabase,
dialect: str,
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
context_query_kwargs: Optional[dict] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
Expand All @@ -362,7 +363,7 @@ def __init__(
sql_database, tables, context_query_kwargs, table_retriever
)
self._tables = tables
self._tables = tables
self._dialect = dialect
self._context_str_prefix = context_str_prefix
self._llm = llm or llm_from_settings_or_context(Settings, service_context)
self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_TMPL
Expand Down Expand Up @@ -393,8 +394,10 @@ def from_config(
nl2sql_prompt_tmpl = DEFAULT_TEXT_TO_SQL_TMPL

sql_database, tables, table_descriptions = get_sql_info(sql_config)
# print("tmp_test:", sql_config.type)
return cls(
sql_database=sql_database,
dialect=sql_config.type,
llm=llm,
text_to_sql_prompt=nl2sql_prompt_tmpl,
tables=tables,
Expand Down Expand Up @@ -681,25 +684,39 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:

def _get_table_context(self, query_bundle: QueryBundle) -> str:
"""Get table context.
Get tables schema + optional context as a single string.
Get tables schema + optional context + data sample as a single string.
"""
table_schema_objs = self._get_tables(query_bundle.query_str)
table_schema_objs = self._get_tables(
query_bundle.query_str
) # get a list of SQLTableSchema, e.g. [SQLTableSchema(table_name='has_pet', context_str=None),]
context_strs = []
if self._context_str_prefix is not None:
context_strs = [self._context_str_prefix]

for table_schema_obj in table_schema_objs:
table_info = self._sql_database.get_single_table_info(
table_schema_obj.table_name
)
) # get ddl info
data_sample = self._get_data_sample(
table_schema_obj.table_name
) # get data sample
table_info_with_sample = table_info + "\ndata_sample: " + data_sample

if table_schema_obj.context_str:
table_opt_context = " The table description is: "
table_opt_context += table_schema_obj.context_str
table_info += table_opt_context
table_info_with_sample += table_opt_context

context_strs.append(table_info)
context_strs.append(table_info_with_sample)

return "\n\n".join(context_strs)

def _get_data_sample(self, table: str, sample_n: int = 3) -> str:
    """Fetch a small random sample of rows from ``table``.

    Args:
        table: Name of the table to sample. It is interpolated into the SQL
            text verbatim, so it must be a trusted identifier.
        sample_n: Maximum number of rows to return.

    Returns:
        The first element of the pair returned by
        ``self._sql_database.run_sql`` for the sampling query.

    Raises:
        ValueError: If ``self._dialect`` is not one of "mysql", "sqlite" or
            "postgresql".
    """
    # Randomly sample each table; the random-ordering function is
    # spelled differently depending on the SQL dialect.
    if self._dialect == "mysql":
        sql_str = f"SELECT * FROM {table} ORDER BY RAND() LIMIT {sample_n};"
    elif self._dialect in ("sqlite", "postgresql"):
        sql_str = f"Select * FROM {table} ORDER BY RANDOM() LIMIT {sample_n};"
    else:
        # Fail with a clear message instead of hitting an unbound
        # ``sql_str`` further down.
        raise ValueError(
            f"Unsupported SQL dialect for data sampling: {self._dialect!r}"
        )

    table_sample, _ = self._sql_database.run_sql(sql_str)
    return table_sample
1 change: 1 addition & 0 deletions tests/integrations/test_nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_nl2sql_retriever(db_connection):
sql_database, db_tables, table_descriptions = db_connection
nl2sql_retriever = MyNLSQLRetriever(
sql_database=sql_database,
dialect="sqlite",
text_to_sql_prompt=DEFAULT_TEXT_TO_SQL_TMPL,
tables=db_tables,
)
Expand Down

0 comments on commit 68842d4

Please sign in to comment.