diff --git a/src/pai_rag/app/web/tabs/data_analysis_tab.py b/src/pai_rag/app/web/tabs/data_analysis_tab.py
index d8429a43..4bbf962e 100644
--- a/src/pai_rag/app/web/tabs/data_analysis_tab.py
+++ b/src/pai_rag/app/web/tabs/data_analysis_tab.py
@@ -3,7 +3,11 @@
 import gradio as gr
 import pandas as pd
 from pai_rag.app.web.rag_client import rag_client, RagApiError
-from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, DA_SQL_PROMPTS
+from pai_rag.app.web.ui_constants import (
+    DA_GENERAL_PROMPTS,
+    DA_SQL_PROMPTS,
+    SYN_GENERAL_PROMPTS,
+)
 
 DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")
 
@@ -157,47 +161,82 @@ def create_data_analysis_tab() -> Dict[str, Any]:
                     elem_id="db_descriptions",
                     placeholder='A dict of table descriptions, e.g. {"table_A": "text_description_A", "table_B": "text_description_B"}',
                 )
+                with gr.Column(visible=True):
+                    with gr.Tab("Nl2sql Prompt"):
+                        sql_prompt_type = gr.Radio(
+                            [
+                                "general",
+                                "sql",
+                                "custom",
+                            ],
+                            value="general",
+                            label="\N{rocket} Please choose the nl2sql prompt template",
+                            elem_id="nl2sql_prompt_type",
+                        )
 
-                prompt_type = gr.Radio(
-                    [
-                        "general",
-                        "sql",
-                        "custom",
-                    ],
-                    value="general",
-                    label="\N{rocket} Please choose the prompt template type",
-                    elem_id="nl2sql_prompt_type",
-                )
+                        db_nl2sql_prompt = gr.Textbox(
+                            label="nl2sql template",
+                            elem_id="db_nl2sql_prompt",
+                            value=DA_GENERAL_PROMPTS,
+                            lines=4,
+                        )
 
-                db_nl2sql_prompt = gr.Textbox(
-                    label="Prompt template",
-                    elem_id="db_nl2sql_prompt",
-                    value=DA_GENERAL_PROMPTS,
-                    lines=4,
-                )
+                    with gr.Tab("Synthesize Prompt"):
+                        syn_prompt_type = gr.Radio(
+                            [
+                                "general",
+                                "custom",
+                            ],
+                            value="general",
+                            label="\N{rocket} Please choose the synthesizer prompt template",
+                            elem_id="synthesizer_prompt_type",
+                        )
+
+                        synthesizer_prompt = gr.Textbox(
+                            label="synthesizer template",
+                            elem_id="synthesizer_prompt",
+                            value=SYN_GENERAL_PROMPTS,
+                            lines=4,
+                        )
 
-            def change_prompt_template(prompt_type):
+            def change_sql_prompt_template(prompt_type):
                 if prompt_type == "general":
                     return {
                         db_nl2sql_prompt: gr.update(
-                            value=DA_GENERAL_PROMPTS, interactive=False
+                            value=DA_GENERAL_PROMPTS, interactive=True
                         )
                     }
                 elif prompt_type == "sql":
                     return {
                         db_nl2sql_prompt: gr.update(
-                            value=DA_SQL_PROMPTS, interactive=False
+                            value=DA_SQL_PROMPTS, interactive=True
                         )
                     }
                 else:
                     return {db_nl2sql_prompt: gr.update(value="", interactive=True)}
 
-            prompt_type.input(
-                fn=change_prompt_template,
-                inputs=prompt_type,
+            sql_prompt_type.input(
+                fn=change_sql_prompt_template,
+                inputs=sql_prompt_type,
                 outputs=[db_nl2sql_prompt],
             )
 
+            def change_syn_prompt_template(prompt_type):
+                if prompt_type == "general":
+                    return {
+                        synthesizer_prompt: gr.update(
+                            value=SYN_GENERAL_PROMPTS, interactive=True
+                        )
+                    }
+                else:
+                    return {synthesizer_prompt: gr.update(value="", interactive=True)}
+
+            syn_prompt_type.input(
+                fn=change_syn_prompt_template,
+                inputs=syn_prompt_type,
+                outputs=[synthesizer_prompt],
+            )
+
             def data_analysis_type_change(type_value):
                 if type_value == "datafile":
                     return {
@@ -234,6 +273,7 @@ def data_analysis_type_change(type_value):
                     tables,
                     descriptions,
                     db_nl2sql_prompt,
+                    synthesizer_prompt,
                     question,
                     chatbot,
                 }
@@ -282,6 +322,8 @@ def data_analysis_type_change(type_value):
             database.elem_id: database,
             tables.elem_id: tables,
             descriptions.elem_id: descriptions,
-            prompt_type.elem_id: prompt_type,
+            sql_prompt_type.elem_id: sql_prompt_type,
             db_nl2sql_prompt.elem_id: db_nl2sql_prompt,
+            syn_prompt_type.elem_id: syn_prompt_type,
+            synthesizer_prompt.elem_id: synthesizer_prompt,
         }
diff --git a/src/pai_rag/app/web/ui_constants.py b/src/pai_rag/app/web/ui_constants.py
index 0dd1ee22..1d67db52 100644
--- a/src/pai_rag/app/web/ui_constants.py
+++ b/src/pai_rag/app/web/ui_constants.py
@@ -8,6 +8,7 @@
 
 DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
 DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
+SYN_GENERAL_PROMPTS = "给定一个输入问题,根据查询代码指令以及查询结果生成最终回复,生成的回复语言需要与输入问题的语言保持一致。\n\n输入问题: {query_str} \n\nSQL 或 Python 查询代码指令(可选): {query_code_instruction}\n\n 查询结果: {query_output}\n\n 最终回复: "
 
 
 # WELCOME_MESSAGE = """
diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index 5cd8b213..d1c1c4f2 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -103,6 +103,7 @@ class ViewModel(BaseModel):
     db_tables: str = None
     db_descriptions: str = None
     db_nl2sql_prompt: str = None
+    synthesizer_prompt: str = None
 
     # postprocessor
     reranker_type: str = "no-reranker"  # no-reranker / model-based-reranker
@@ -214,6 +215,7 @@ def from_app_config(config: RagConfig):
                 config.data_analysis.descriptions, ensure_ascii=False
             )
             view_model.db_nl2sql_prompt = config.data_analysis.nl2sql_prompt
+            view_model.synthesizer_prompt = config.data_analysis.synthesizer_prompt
 
         return view_model
 
@@ -286,6 +288,7 @@ def to_app_config(self):
         config["data_analysis"]["port"] = self.db_port
         config["data_analysis"]["database"] = self.database
         config["data_analysis"]["nl2sql_prompt"] = self.db_nl2sql_prompt
+        config["data_analysis"]["synthesizer_prompt"] = self.synthesizer_prompt
 
         # string to list
         if self.db_tables:
@@ -499,6 +502,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
         settings["db_tables"] = {"value": self.db_tables}
         settings["db_descriptions"] = {"value": self.db_descriptions}
         settings["db_nl2sql_prompt"] = {"value": self.db_nl2sql_prompt}
+        settings["synthesizer_prompt"] = {"value": self.synthesizer_prompt}
 
         print(settings)
         return settings
diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_config.py b/src/pai_rag/integrations/data_analysis/data_analysis_config.py
index 287ba770..c587f7a1 100644
--- a/src/pai_rag/integrations/data_analysis/data_analysis_config.py
+++ b/src/pai_rag/integrations/data_analysis/data_analysis_config.py
@@ -2,7 +2,7 @@
 from typing import Dict, List, Literal
 from pydantic import BaseModel
 
-from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS
+from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, SYN_GENERAL_PROMPTS
 
 
 class DataAnalysisType(str, Enum):
@@ -18,6 +18,7 @@ class BaseAnalysisConfig(BaseModel):
 
     type: DataAnalysisType
     nl2sql_prompt: str = DA_GENERAL_PROMPTS
+    synthesizer_prompt: str = SYN_GENERAL_PROMPTS
 
 
 class PandasAnalysisConfig(BaseAnalysisConfig):
@@ -41,4 +42,4 @@ class MysqlAnalysisConfig(SqlAnalysisConfig):
     user: str
     password: str
     host: str
-    port: str = 3306
+    port: int = 3306
diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py
index dbc029c9..581988dc 100644
--- a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py
+++ b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py
@@ -76,7 +76,11 @@ def __init__(
         )
         self._synthesizer = DataAnalysisSynthesizer(
             llm=self._llm,
-            response_synthesis_prompt=DEFAULT_RESPONSE_SYNTHESIS_PROMPT,
+            response_synthesis_prompt=(
+                PromptTemplate(analysis_config.synthesizer_prompt)
+                if analysis_config.synthesizer_prompt
+                else DEFAULT_RESPONSE_SYNTHESIS_PROMPT
+            ),
         )
 
         super().__init__(callback_manager=callback_manager)
diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
index 017f6968..8ac1ecdd 100644
--- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
+++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
@@ -241,7 +241,13 @@ def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str
         sql_result_start = response.find("SQLResult:")
         if sql_result_start != -1:
             response = response[:sql_result_start]
-        return response.strip().strip("```").strip().strip(";").strip()
+        sql = response.strip().strip("```").strip().strip(";").strip()
+        # Drop a leading "sql" code-fence language tag as a whole token.
+        # str.lstrip("sql") strips any leading 's', 'q' or 'l' characters and
+        # would corrupt statements such as "select ..." into "elect ...".
+        if sql.split(None, 1)[:1] == ["sql"]:
+            sql = sql[len("sql") :].lstrip()
+        return sql
 
 
 def get_sql_info(sql_config: SqlAnalysisConfig):
@@ -285,8 +291,8 @@ def inspect_db_connection(
         engine = create_engine(
             database_url,
             echo=False,
-            pool_size=5,
-            max_overflow=10,
+            pool_size=10,
+            max_overflow=20,
             pool_timeout=30,
             pool_recycle=360,
             poolclass=QueuePool,
@@ -339,6 +345,7 @@ class MyNLSQLRetriever(BaseRetriever, PromptMixin):
     def __init__(
         self,
         sql_database: SQLDatabase,
+        dialect: str,
         text_to_sql_prompt: Optional[BasePromptTemplate] = None,
         context_query_kwargs: Optional[dict] = None,
         tables: Optional[Union[List[str], List[Table]]] = None,
@@ -362,7 +369,7 @@ def __init__(
             sql_database, tables, context_query_kwargs, table_retriever
         )
         self._tables = tables
-        self._tables = tables
+        self._dialect = dialect
         self._context_str_prefix = context_str_prefix
         self._llm = llm or llm_from_settings_or_context(Settings, service_context)
         self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_TMPL
@@ -393,8 +400,9 @@ def from_config(
             nl2sql_prompt_tmpl = DEFAULT_TEXT_TO_SQL_TMPL
 
         sql_database, tables, table_descriptions = get_sql_info(sql_config)
         return cls(
             sql_database=sql_database,
+            dialect=sql_config.type,
             llm=llm,
             text_to_sql_prompt=nl2sql_prompt_tmpl,
             tables=tables,
@@ -681,11 +689,11 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
 
     def _get_table_context(self, query_bundle: QueryBundle) -> str:
         """Get table context.
-
-        Get tables schema + optional context as a single string.
-
+        Get tables schema + optional context + data sample as a single string.
         """
-        table_schema_objs = self._get_tables(query_bundle.query_str)
+        table_schema_objs = self._get_tables(
+            query_bundle.query_str
+        )  # get a list of SQLTableSchema, e.g. [SQLTableSchema(table_name='has_pet', context_str=None),]
         context_strs = []
         if self._context_str_prefix is not None:
             context_strs = [self._context_str_prefix]
@@ -693,13 +701,32 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str:
         for table_schema_obj in table_schema_objs:
             table_info = self._sql_database.get_single_table_info(
                 table_schema_obj.table_name
-            )
+            )  # get ddl info
+            data_sample = self._get_data_sample(
+                table_schema_obj.table_name
+            )  # get data sample
+            table_info_with_sample = table_info + "\ndata_sample: " + data_sample
 
             if table_schema_obj.context_str:
                 table_opt_context = " The table description is: "
                 table_opt_context += table_schema_obj.context_str
-                table_info += table_opt_context
+                table_info_with_sample += table_opt_context
 
-            context_strs.append(table_info)
+            context_strs.append(table_info_with_sample)
 
         return "\n\n".join(context_strs)
+
+    def _get_data_sample(self, table: str, sample_n: int = 3) -> str:
+        """Return up to ``sample_n`` randomly ordered rows of ``table`` via ``run_sql``."""
+        # The random-ordering function is dialect specific.
+        if self._dialect == "mysql":
+            sql_str = f"SELECT * FROM {table} ORDER BY RAND() LIMIT {sample_n};"
+        elif self._dialect in ("sqlite", "postgresql"):
+            sql_str = f"Select * FROM {table} ORDER BY RANDOM() LIMIT {sample_n};"
+        else:
+            # Unsupported dialect: skip sampling instead of reaching run_sql
+            # with an unbound ``sql_str`` (UnboundLocalError).
+            return ""
+        table_sample, _ = self._sql_database.run_sql(sql_str)
+
+        return table_sample
diff --git a/tests/integrations/test_nl2sql_retriever.py b/tests/integrations/test_nl2sql_retriever.py
index 012c59bf..8f9d3c81 100644
--- a/tests/integrations/test_nl2sql_retriever.py
+++ b/tests/integrations/test_nl2sql_retriever.py
@@ -101,6 +101,7 @@ def test_nl2sql_retriever(db_connection):
     sql_database, db_tables, table_descriptions = db_connection
     nl2sql_retriever = MyNLSQLRetriever(
         sql_database=sql_database,
+        dialect="sqlite",
         text_to_sql_prompt=DEFAULT_TEXT_TO_SQL_TMPL,
         tables=db_tables,
     )