From 0d9eb8d0dac253c8649e161904abfd5248971111 Mon Sep 17 00:00:00 2001 From: aero-xi Date: Fri, 13 Sep 2024 17:08:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0data=20analysis=20prompt?= =?UTF-8?q?=E9=80=8F=E5=87=BA=20(#216)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add db ref * add reference * add nl2sql prompt * use one button * Update create table with cache * add fault tolerance to custom prompt --------- Co-authored-by: 陆逊 --- src/pai_rag/app/web/rag_client.py | 17 ++- src/pai_rag/app/web/tabs/chat_tab.py | 4 - src/pai_rag/app/web/tabs/data_analysis_tab.py | 133 ++++++++++++++---- src/pai_rag/app/web/ui_constants.py | 3 + src/pai_rag/app/web/view_model.py | 5 + src/pai_rag/config/settings.toml | 1 + src/pai_rag/config/settings_multi_modal.toml | 1 + .../data_analysis/nl2sql_retriever.py | 90 ++++++++++-- .../modules/dataanalysis/data_analysis.py | 42 +++++- 9 files changed, 237 insertions(+), 59 deletions(-) diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 8aa3c2ed..85a107f0 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -114,6 +114,8 @@ def _format_rag_response( filename = doc["metadata"].get("file_name", None) ref_table = doc["metadata"].get("query_tables", None) invalid_flag = doc["metadata"].get("invalid_flag", 0) + ref_table = doc["metadata"].get("query_tables", None) + invalid_flag = doc["metadata"].get("invalid_flag", 0) file_url = doc["metadata"].get("file_url", None) media_url = doc.get("metadata", {}).get("image_url", None) if media_url and doc["text"] == "": @@ -150,12 +152,10 @@ def _format_rag_response( run_flag = " ✓ " ref_sql = doc["metadata"].get("query_code_instruction", None) formatted_sql_query = f"生成的sql语句为:{ref_sql}" - # content = f"""{formatted_table_name} \n\n {formatted_sql_query}""" content = ( - f"""{formatted_table_name} """ - """\n""" - f"""{formatted_sql_query} \n sql查询是否有效:""" - f"""{run_flag}""" + f"""{formatted_table_name} \n""" + f"""{formatted_sql_query} \n""" + f"""sql查询是否有效: {run_flag}""" ) else: run_flag = " ✗ " @@ -164,10 +164,9 @@ def _format_rag_response( ) formatted_sql_query = f"生成的sql语句为:{ref_sql}" content = ( - f"""{formatted_table_name} """ - """\n""" - f"""{formatted_sql_query} \n sql查询是否有效:""" - f"""{run_flag}""" + f"""{formatted_table_name} \n""" + f"""{formatted_sql_query} \n""" + f"""sql查询是否有效: {run_flag}""" ) else: content = "" diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py index 6c077876..c671b315 100644 --- a/src/pai_rag/app/web/tabs/chat_tab.py +++ b/src/pai_rag/app/web/tabs/chat_tab.py @@ -25,10 +25,6 @@ def respond(input_elements: List[Any]): for element, value in input_elements.items(): update_dict[element.elem_id] = value - if update_dict["retrieval_mode"] == "data_analysis": - update_dict["retrieval_mode"] = "hybrid" - update_dict["synthesizer_type"] = "SimpleSummarize" - # empty input. if not update_dict["question"]: yield update_dict["chatbot"] diff --git a/src/pai_rag/app/web/tabs/data_analysis_tab.py b/src/pai_rag/app/web/tabs/data_analysis_tab.py index b972ab22..825f1ded 100644 --- a/src/pai_rag/app/web/tabs/data_analysis_tab.py +++ b/src/pai_rag/app/web/tabs/data_analysis_tab.py @@ -4,6 +4,7 @@ import gradio as gr import pandas as pd from pai_rag.app.web.rag_client import rag_client, RagApiError +from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, DA_SQL_PROMPTS DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true") @@ -43,7 +44,7 @@ def upload_file_fn(input_file): raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") -def connect_database(input_db: List[Any]): +def update_setting(input_db: List[Any]): try: update_dict = {"analysis_type": "nl2sql"} for element, value in input_db.items(): @@ -51,11 +52,21 @@ def connect_database(input_db: List[Any]): # print("db_config:", update_dict) rag_client.patch_config(update_dict) - return f"[{datetime.datetime.now()}] Connect database success!" + return f"[{datetime.datetime.now()}] success!" except RagApiError as api_error: raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") +# def nl2sql_prompt_update(input_prompt): +# try: +# update_dict = {"db_nl2sql_prompt": input_prompt} +# # print('update_dict:', update_dict) +# rag_client.patch_config(update_dict) +# return f"[{datetime.datetime.now()}] success!" +# except RagApiError as api_error: +# raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}") + + def analysis_respond(question, chatbot): response_gen = rag_client.query_data_analysis(question, stream=True) content = "" @@ -84,7 +95,7 @@ def create_data_analysis_tab() -> Dict[str, Any]: "database", ], value="datafile", - label="Please choose data analysis type", + label="Please choose the data analysis type", elem_id="data_analysis_type", ) @@ -115,38 +126,90 @@ def create_data_analysis_tab() -> Dict[str, Any]: # database with gr.Column(visible=(data_analysis_type.value == "database")) as db_col: - dialect = gr.Textbox( - label="Dialect", elem_id="db_dialect", value="mysql" - ) - user = gr.Textbox(label="Username", elem_id="db_username") - password = gr.Textbox( - label="Password", elem_id="db_password", type="password" - ) - host = gr.Textbox(label="Host", elem_id="db_host") - port = gr.Textbox(label="Port", elem_id="db_port", value=3306) - dbname = gr.Textbox(label="DBname", elem_id="db_name") - tables = gr.Textbox( - label="Tables", - elem_id="db_tables", - placeholder="List db tables, separated by commas, e.g. table_A, table_B, ... , using all tables if blank", - ) + with gr.Row(): + dialect = gr.Textbox( + label="Dialect", elem_id="db_dialect", value="mysql" + ) + port = gr.Textbox(label="Port", elem_id="db_port", value=3306) + host = gr.Textbox(label="Host", elem_id="db_host") + with gr.Row(): + user = gr.Textbox(label="Username", elem_id="db_username") + password = gr.Textbox( + label="Password", elem_id="db_password", type="password" + ) + with gr.Row(): + dbname = gr.Textbox(label="DBname", elem_id="db_name") + tables = gr.Textbox( + label="Tables", + elem_id="db_tables", + placeholder="List db tables, separated by commas, e.g. table_A, table_B, ... , using all tables if blank", + ) descriptions = gr.Textbox( label="Descriptions", - lines=5, + lines=3, elem_id="db_descriptions", placeholder='A dict of table descriptions, e.g. {"table_A": "text_description_A", "table_B": "text_description_B"}', ) - connect_db_button = gr.Button( - "Connect Database", - elem_id="connect_db_button", - variant="primary", - ) # 点击功能中更新analysis_type + prompt_type = gr.Radio( + [ + "general", + "sql", + "custom", + ], + value="general", + label="\N{rocket} Please choose the prompt template type", + elem_id="nl2sql_prompt_type", + ) - connection_info = gr.Textbox( - label="Connection Info", elem_id="db_connection_info" + prompt_template = gr.Textbox( + label="prompt template", + elem_id="db_nl2sql_prompt", + value=DA_GENERAL_PROMPTS, + lines=4, ) + update_button = gr.Button( + "Update Settings", + elem_id="update_settings", + variant="primary", + ) # 点击功能中更新 analysis_type, nl2sql参数以及prompt + + update_info = gr.Textbox(label="Update Info", elem_id="update_info") + + # with gr.Row(): + # prompt_update_button = gr.Button( + # "prompt update", + # elem_id="prompt_update", + # variant="primary", + # ) # 点击功能中更新 nl2sql prompt + + # update_info = gr.Textbox( + # label="Update Info", elem_id="prompt_update_info" + # ) + + def change_prompt_template(prompt_type): + if prompt_type == "general": + return { + prompt_template: gr.update( + value=DA_GENERAL_PROMPTS, interactive=False + ) + } + elif prompt_type == "sql": + return { + prompt_template: gr.update( + value=DA_SQL_PROMPTS, interactive=False + ) + } + else: + return {prompt_template: gr.update(value="", interactive=True)} + + prompt_type.input( + fn=change_prompt_template, + inputs=prompt_type, + outputs=[prompt_template], + ) + inputs_db = { dialect, user, @@ -156,15 +219,23 @@ def create_data_analysis_tab() -> Dict[str, Any]: dbname, tables, descriptions, + prompt_template, } - connect_db_button.click( - fn=connect_database, + update_button.click( + fn=update_setting, inputs=inputs_db, - outputs=connection_info, - api_name="connect_db", + outputs=update_info, + api_name="update_info_clk", ) + # prompt_update_button.click( + # fn=nl2sql_prompt_update, + # inputs=prompt_template, + # outputs=update_info, + # api_name="update_nl2sql_prompt", + # ) + def data_analysis_type_change(type_value): if type_value == "datafile": return { @@ -234,4 +305,6 @@ def data_analysis_type_change(type_value): dbname.elem_id: dbname, tables.elem_id: tables, descriptions.elem_id: descriptions, + prompt_type.elem_id: prompt_type, + prompt_template.elem_id: prompt_template, } diff --git a/src/pai_rag/app/web/ui_constants.py b/src/pai_rag/app/web/ui_constants.py index f582ba6b..65ac7689 100644 --- a/src/pai_rag/app/web/ui_constants.py +++ b/src/pai_rag/app/web/ui_constants.py @@ -3,6 +3,9 @@ EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}" ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}" +DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " +DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " + PROMPT_MAP = { SIMPLE_PROMPTS: "Simple", GENERAL_PROMPTS: "General", diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py index 67f8d868..085ce61a 100644 --- a/src/pai_rag/app/web/view_model.py +++ b/src/pai_rag/app/web/view_model.py @@ -153,6 +153,7 @@ class ViewModel(BaseModel): db_name: str = None db_tables: str = None db_descriptions: str = None + db_nl2sql_prompt: str = None # postprocessor reranker_type: str = ( @@ -376,6 +377,8 @@ def from_app_config(config): else: view_model.db_descriptions = None + view_model.db_nl2sql_prompt = config["data_analysis"].get("nl2sql_prompt", None) + reranker_type = config["postprocessor"].get( "reranker_type", "simple-weighted-reranker" ) @@ -549,6 +552,7 @@ def to_app_config(self): config["data_analysis"]["host"] = self.db_host config["data_analysis"]["port"] = self.db_port config["data_analysis"]["dbname"] = self.db_name + config["data_analysis"]["nl2sql_prompt"] = self.db_nl2sql_prompt # string to list if self.db_tables: # 去掉首位空格和末尾逗号 @@ -815,5 +819,6 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]: settings["db_name"] = {"value": self.db_name} settings["db_tables"] = {"value": self.db_tables} settings["db_descriptions"] = {"value": self.db_descriptions} + settings["db_nl2sql_prompt"] = {"value": self.db_nl2sql_prompt} return settings diff --git a/src/pai_rag/config/settings.toml b/src/pai_rag/config/settings.toml index 3e94e6c4..2088cb79 100644 --- a/src/pai_rag/config/settings.toml +++ b/src/pai_rag/config/settings.toml @@ -27,6 +27,7 @@ persist_path = "localdata/storage" [rag.data_analysis] analysis_type = "nl2pandas" +nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " [rag.data_loader] type = "local" diff --git a/src/pai_rag/config/settings_multi_modal.toml b/src/pai_rag/config/settings_multi_modal.toml index 9a44e809..3ef69135 100644 --- a/src/pai_rag/config/settings_multi_modal.toml +++ b/src/pai_rag/config/settings_multi_modal.toml @@ -27,6 +27,7 @@ persist_path = "localdata/storage" [rag.data_analysis] analysis_type = "nl2pandas" +nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: " [rag.data_loader] type = "Local" # [Local, Oss] diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py index 3ae73bbb..52a55db4 100644 --- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py +++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py @@ -126,6 +126,7 @@ def retrieve_with_metadata( query_bundle.query_str = self._limit_check(query_bundle.query_str) logger.info(f"Limited SQL query: {query_bundle.query_str}") + # set timeout to 10s # set timeout to 10s signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(10) # start @@ -133,9 +134,11 @@ def retrieve_with_metadata( raw_response_str, metadata = self._sql_database.run_sql( query_bundle.query_str ) - except TimeoutError: - logger.info("SQL Query Timed Out (>10s)") - raw_response_str = "SQL Query Timed Out (>10s)" + except (TimeoutError, NotImplementedError) as error: + logger.info("Invalid SQL or SQL Query Timed Out (>10s)") + raise error + # raw_response_str = "Invalid SQL or SQL Query Timed Out (>10s)" + # metadata = {"result": {e}, "col_keys": []} finally: signal.alarm(0) # cancel @@ -262,6 +265,7 @@ def __init__( sql_database, tables, context_query_kwargs, table_retriever ) self._tables = tables + self._tables = tables self._context_str_prefix = context_str_prefix self._llm = llm or llm_from_settings_or_context(Settings, service_context) self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT @@ -363,12 +367,24 @@ def retrieve_with_metadata( metadata, ) = self._sql_retriever.retrieve_with_metadata(sql_query_str) retrieved_nodes[0].metadata["invalid_flag"] = 0 + retrieved_nodes[0].metadata["invalid_flag"] = 0 logger.info( f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" ) # 如果生成的sql语句执行后无结果,待bad case补充 # if retrieved_nodes[0].metadata["query_output"] == "": + # new_sql_query_str = self._sql_query_modification(sql_query_str) + # ( + # retrieved_nodes, + # metadata, + # ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str) + # logger.info( + # f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" + # ) + # 如果生成的sql语句执行后无结果,待bad case补充 + # if retrieved_nodes[0].metadata["query_output"] == "": + # new_sql_query_str = self._sql_query_modification(sql_query_str) # ( # retrieved_nodes, @@ -391,6 +407,10 @@ def retrieve_with_metadata( retrieved_nodes[0].metadata[ "generated_query_code_instruction" ] = sql_query_str + retrieved_nodes[0].metadata["invalid_flag"] = 1 + retrieved_nodes[0].metadata[ + "generated_query_code_instruction" + ] = sql_query_str logger.info( f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" ) @@ -405,6 +425,10 @@ def retrieve_with_metadata( query_tables = self._get_table_from_sql(self._tables, sql_query_str) retrieved_nodes[0].metadata["query_tables"] = query_tables + # add query_tables into metadata + query_tables = self._get_table_from_sql(self._tables, sql_query_str) + retrieved_nodes[0].metadata["query_tables"] = query_tables + return retrieved_nodes, {"sql_query": sql_query_str, **metadata} async def aretrieve_with_metadata( @@ -442,6 +466,7 @@ async def aretrieve_with_metadata( metadata, ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str) retrieved_nodes[0].metadata["invalid_flag"] = 0 + retrieved_nodes[0].metadata["invalid_flag"] = 0 logger.info( f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" ) @@ -463,17 +488,40 @@ async def aretrieve_with_metadata( logger.info(f"async error info: {e}\n") new_sql_query_str = self._sql_query_modification(sql_query_str) - ( - retrieved_nodes, - metadata, - ) = await self._sql_retriever.aretrieve_with_metadata(new_sql_query_str) - retrieved_nodes[0].metadata["invalid_flag"] = 1 - retrieved_nodes[0].metadata[ - "generated_query_code_instruction" - ] = sql_query_str - logger.info( - f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" - ) + + # 如果找到table,生成新的sql_query + if new_sql_query_str != sql_query_str: + ( + retrieved_nodes, + metadata, + ) = await self._sql_retriever.aretrieve_with_metadata( + new_sql_query_str + ) + retrieved_nodes[0].metadata["invalid_flag"] = 1 + retrieved_nodes[0].metadata[ + "generated_query_code_instruction" + ] = sql_query_str + logger.info( + f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" + ) + # 没有找到table,新旧sql_query一样,不再通过_sql_retriever执行,直接retrieved_nodes + else: + logger.info(f"[{new_sql_query_str}] is not even a SQL") + retrieved_nodes = [ + NodeWithScore( + node=TextNode( + text=new_sql_query_str, + metadata={ + "query_code_instruction": new_sql_query_str, + "generated_query_code_instruction": sql_query_str, + "query_output": "", + "invalid_flag": 1, + }, + ), + score=1.0, + ), + ] + metadata = {} # err_node = TextNode(text=f"Error: {e!s}") # logger.info(f"async error_node info: {err_node}\n") # retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)] @@ -484,6 +532,10 @@ async def aretrieve_with_metadata( query_tables = self._get_table_from_sql(self._tables, sql_query_str) retrieved_nodes[0].metadata["query_tables"] = query_tables + # add query_tables into metadata + query_tables = self._get_table_from_sql(self._tables, sql_query_str) + retrieved_nodes[0].metadata["query_tables"] = query_tables + return retrieved_nodes, {"sql_query": sql_query_str, **metadata} def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: @@ -493,6 +545,13 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: table_collection.append(table) return table_collection + def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: + table_collection = list() + for table in table_list: + if table.lower() in sql_query.lower(): + table_collection.append(table) + return table_collection + def _sql_query_modification(self, sql_query_str): table_pattern = r"FROM\s+(\w+)" match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL) @@ -504,6 +563,9 @@ def _sql_query_modification(self, sql_query_str): # raise ValueError("No table is matched") new_sql_query_str = sql_query_str logger.info("No table is matched") + # raise ValueError("No table is matched") + new_sql_query_str = sql_query_str + logger.info("No table is matched") return new_sql_query_str diff --git a/src/pai_rag/modules/dataanalysis/data_analysis.py b/src/pai_rag/modules/dataanalysis/data_analysis.py index 1af3b669..c7e3142f 100644 --- a/src/pai_rag/modules/dataanalysis/data_analysis.py +++ b/src/pai_rag/modules/dataanalysis/data_analysis.py @@ -1,3 +1,4 @@ +import functools import logging import os import glob @@ -7,6 +8,7 @@ from sqlalchemy.engine import URL from sqlalchemy.pool import QueuePool from llama_index.core import SQLDatabase +from llama_index.core.prompts import PromptTemplate from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG @@ -36,6 +38,11 @@ def _create_new_instance(self, new_params: Dict[str, Any]): llm = new_params["LlmModule"] embed_model = new_params["EmbeddingModule"] data_analysis_type = config.get("analysis_type", "nl2pandas") + nl2sql_prompt = config.get("nl2sql_prompt", None) + if nl2sql_prompt: + nl2sql_prompt = PromptTemplate(nl2sql_prompt) + else: + nl2sql_prompt = DEFAULT_TEXT_TO_SQL_TMPL if data_analysis_type == "nl2pandas": df = self.get_dataframe(config) @@ -51,7 +58,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]): sql_database, tables, table_descriptions = self.db_connection(config) analysis_retriever = MyNLSQLRetriever( sql_database=sql_database, - text_to_sql_prompt=DEFAULT_TEXT_TO_SQL_TMPL, + text_to_sql_prompt=nl2sql_prompt, tables=tables, context_query_kwargs=table_descriptions, sql_only=False, @@ -59,6 +66,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]): llm=llm, ) logger.info("DataAnalysis NL2SQLRetriever used") + # logger.info(f"nl2sql prompt: {nl2sql_prompt}") else: raise ValueError( @@ -76,7 +84,6 @@ def _create_new_instance(self, new_params: Dict[str, Any]): ) def db_connection(self, config): - # get rds_db config dialect = config.get("dialect", "sqlite") user = config.get("user", "") password = config.get("password", "") @@ -87,6 +94,37 @@ def db_connection(self, config): desired_tables = config.get("tables", []) table_descriptions = config.get("descriptions", {}) + return self.inspect_db_connection( + dialect=dialect, + user=user, + password=password, + host=host, + port=port, + path=path, + dbname=dbname, + desired_tables=tuple(desired_tables) if desired_tables else None, + table_descriptions=tuple(table_descriptions.items()) + if table_descriptions + else None, + ) + + @functools.cache + def inspect_db_connection( + self, + dialect, + user, + password, + host, + port, + path, + dbname, + desired_tables, + table_descriptions, + ): + desired_tables = list(desired_tables) if desired_tables else None + table_descriptions = dict(table_descriptions) if table_descriptions else None + + # get rds_db config logger.info(f"desired_tables from ui input: {desired_tables}") logger.info(f"table_descriptions from ui input: {table_descriptions}")