diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 8aa3c2ed..85a107f0 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -114,6 +114,8 @@ def _format_rag_response(
filename = doc["metadata"].get("file_name", None)
ref_table = doc["metadata"].get("query_tables", None)
invalid_flag = doc["metadata"].get("invalid_flag", 0)
+ ref_table = doc["metadata"].get("query_tables", None)
+ invalid_flag = doc["metadata"].get("invalid_flag", 0)
file_url = doc["metadata"].get("file_url", None)
media_url = doc.get("metadata", {}).get("image_url", None)
if media_url and doc["text"] == "":
@@ -150,12 +152,10 @@ def _format_rag_response(
run_flag = " ✓ "
ref_sql = doc["metadata"].get("query_code_instruction", None)
formatted_sql_query = f"生成的sql语句为:{ref_sql}"
- # content = f"""{formatted_table_name} \n\n {formatted_sql_query}"""
content = (
- f"""{formatted_table_name} """
- """\n"""
- f"""{formatted_sql_query} \n sql查询是否有效:"""
- f"""{run_flag}"""
+ f"""{formatted_table_name} \n"""
+ f"""{formatted_sql_query} \n"""
+ f"""sql查询是否有效: {run_flag}"""
)
else:
run_flag = " ✗ "
@@ -164,10 +164,9 @@ def _format_rag_response(
)
formatted_sql_query = f"生成的sql语句为:{ref_sql}"
content = (
- f"""{formatted_table_name} """
- """\n"""
- f"""{formatted_sql_query} \n sql查询是否有效:"""
- f"""{run_flag}"""
+ f"""{formatted_table_name} \n"""
+ f"""{formatted_sql_query} \n"""
+ f"""sql查询是否有效: {run_flag}"""
)
else:
content = ""
diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py
index 6c077876..c671b315 100644
--- a/src/pai_rag/app/web/tabs/chat_tab.py
+++ b/src/pai_rag/app/web/tabs/chat_tab.py
@@ -25,10 +25,6 @@ def respond(input_elements: List[Any]):
for element, value in input_elements.items():
update_dict[element.elem_id] = value
- if update_dict["retrieval_mode"] == "data_analysis":
- update_dict["retrieval_mode"] = "hybrid"
- update_dict["synthesizer_type"] = "SimpleSummarize"
-
# empty input.
if not update_dict["question"]:
yield update_dict["chatbot"]
diff --git a/src/pai_rag/app/web/tabs/data_analysis_tab.py b/src/pai_rag/app/web/tabs/data_analysis_tab.py
index b972ab22..825f1ded 100644
--- a/src/pai_rag/app/web/tabs/data_analysis_tab.py
+++ b/src/pai_rag/app/web/tabs/data_analysis_tab.py
@@ -4,6 +4,7 @@
import gradio as gr
import pandas as pd
from pai_rag.app.web.rag_client import rag_client, RagApiError
+from pai_rag.app.web.ui_constants import DA_GENERAL_PROMPTS, DA_SQL_PROMPTS
DEFAULT_IS_INTERACTIVE = os.environ.get("PAIRAG_RAG__SETTING__interactive", "true")
@@ -43,7 +44,7 @@ def upload_file_fn(input_file):
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
-def connect_database(input_db: List[Any]):
+def update_setting(input_db: List[Any]):
try:
update_dict = {"analysis_type": "nl2sql"}
for element, value in input_db.items():
@@ -51,11 +52,21 @@ def connect_database(input_db: List[Any]):
# print("db_config:", update_dict)
rag_client.patch_config(update_dict)
- return f"[{datetime.datetime.now()}] Connect database success!"
+ return f"[{datetime.datetime.now()}] success!"
except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
+# def nl2sql_prompt_update(input_prompt):
+# try:
+# update_dict = {"db_nl2sql_prompt": input_prompt}
+# # print('update_dict:', update_dict)
+# rag_client.patch_config(update_dict)
+# return f"[{datetime.datetime.now()}] success!"
+# except RagApiError as api_error:
+# raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
+
+
def analysis_respond(question, chatbot):
response_gen = rag_client.query_data_analysis(question, stream=True)
content = ""
@@ -84,7 +95,7 @@ def create_data_analysis_tab() -> Dict[str, Any]:
"database",
],
value="datafile",
- label="Please choose data analysis type",
+ label="Please choose the data analysis type",
elem_id="data_analysis_type",
)
@@ -115,38 +126,90 @@ def create_data_analysis_tab() -> Dict[str, Any]:
# database
with gr.Column(visible=(data_analysis_type.value == "database")) as db_col:
- dialect = gr.Textbox(
- label="Dialect", elem_id="db_dialect", value="mysql"
- )
- user = gr.Textbox(label="Username", elem_id="db_username")
- password = gr.Textbox(
- label="Password", elem_id="db_password", type="password"
- )
- host = gr.Textbox(label="Host", elem_id="db_host")
- port = gr.Textbox(label="Port", elem_id="db_port", value=3306)
- dbname = gr.Textbox(label="DBname", elem_id="db_name")
- tables = gr.Textbox(
- label="Tables",
- elem_id="db_tables",
- placeholder="List db tables, separated by commas, e.g. table_A, table_B, ... , using all tables if blank",
- )
+ with gr.Row():
+ dialect = gr.Textbox(
+ label="Dialect", elem_id="db_dialect", value="mysql"
+ )
+ port = gr.Textbox(label="Port", elem_id="db_port", value=3306)
+ host = gr.Textbox(label="Host", elem_id="db_host")
+ with gr.Row():
+ user = gr.Textbox(label="Username", elem_id="db_username")
+ password = gr.Textbox(
+ label="Password", elem_id="db_password", type="password"
+ )
+ with gr.Row():
+ dbname = gr.Textbox(label="DBname", elem_id="db_name")
+ tables = gr.Textbox(
+ label="Tables",
+ elem_id="db_tables",
+ placeholder="List db tables, separated by commas, e.g. table_A, table_B, ... , using all tables if blank",
+ )
descriptions = gr.Textbox(
label="Descriptions",
- lines=5,
+ lines=3,
elem_id="db_descriptions",
placeholder='A dict of table descriptions, e.g. {"table_A": "text_description_A", "table_B": "text_description_B"}',
)
- connect_db_button = gr.Button(
- "Connect Database",
- elem_id="connect_db_button",
- variant="primary",
- ) # 点击功能中更新analysis_type
+ prompt_type = gr.Radio(
+ [
+ "general",
+ "sql",
+ "custom",
+ ],
+ value="general",
+ label="\N{rocket} Please choose the prompt template type",
+ elem_id="nl2sql_prompt_type",
+ )
- connection_info = gr.Textbox(
- label="Connection Info", elem_id="db_connection_info"
+ prompt_template = gr.Textbox(
+ label="prompt template",
+ elem_id="db_nl2sql_prompt",
+ value=DA_GENERAL_PROMPTS,
+ lines=4,
)
+ update_button = gr.Button(
+ "Update Settings",
+ elem_id="update_settings",
+ variant="primary",
+ ) # 点击功能中更新 analysis_type, nl2sql参数以及prompt
+
+ update_info = gr.Textbox(label="Update Info", elem_id="update_info")
+
+ # with gr.Row():
+ # prompt_update_button = gr.Button(
+ # "prompt update",
+ # elem_id="prompt_update",
+ # variant="primary",
+ # ) # 点击功能中更新 nl2sql prompt
+
+ # update_info = gr.Textbox(
+ # label="Update Info", elem_id="prompt_update_info"
+ # )
+
+ def change_prompt_template(prompt_type):
+     """Refresh the prompt textbox for the selected template type.
+
+     The built-in "general" and "sql" templates are shown read-only; any
+     other selection clears the textbox and makes it editable.
+     """
+     builtin_templates = {
+         "general": DA_GENERAL_PROMPTS,
+         "sql": DA_SQL_PROMPTS,
+     }
+     if prompt_type in builtin_templates:
+         textbox_state = gr.update(
+             value=builtin_templates[prompt_type], interactive=False
+         )
+     else:
+         textbox_state = gr.update(value="", interactive=True)
+     return {prompt_template: textbox_state}
+
+ prompt_type.input(
+ fn=change_prompt_template,
+ inputs=prompt_type,
+ outputs=[prompt_template],
+ )
+
inputs_db = {
dialect,
user,
@@ -156,15 +219,23 @@ def create_data_analysis_tab() -> Dict[str, Any]:
dbname,
tables,
descriptions,
+ prompt_template,
}
- connect_db_button.click(
- fn=connect_database,
+ update_button.click(
+ fn=update_setting,
inputs=inputs_db,
- outputs=connection_info,
- api_name="connect_db",
+ outputs=update_info,
+ api_name="update_info_clk",
)
+ # prompt_update_button.click(
+ # fn=nl2sql_prompt_update,
+ # inputs=prompt_template,
+ # outputs=update_info,
+ # api_name="update_nl2sql_prompt",
+ # )
+
def data_analysis_type_change(type_value):
if type_value == "datafile":
return {
@@ -234,4 +305,6 @@ def data_analysis_type_change(type_value):
dbname.elem_id: dbname,
tables.elem_id: tables,
descriptions.elem_id: descriptions,
+ prompt_type.elem_id: prompt_type,
+ prompt_template.elem_id: prompt_template,
}
diff --git a/src/pai_rag/app/web/ui_constants.py b/src/pai_rag/app/web/ui_constants.py
index f582ba6b..65ac7689 100644
--- a/src/pai_rag/app/web/ui_constants.py
+++ b/src/pai_rag/app/web/ui_constants.py
@@ -3,6 +3,9 @@
EXTRACT_URL_PROMPTS = "你是一位智能小助手,请根据下面我所提供的相关知识,对我提出的问题进行回答。回答的内容必须包括其定义、特征、应用领域以及相关网页链接等等内容,同时务必满足下方所提的要求!\n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
ACCURATE_CONTENT_PROMPTS = "你是一位知识小助手,请根据下面我提供的知识库中相关知识,对我提出的若干问题进行回答,同时回答的内容需满足我所提的要求! \n=====\n 知识库相关知识如下:\n{context_str}\n=====\n 请根据上方所提供的知识库内容与要求,回答以下问题:\n {query_str}"
+DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
+DA_SQL_PROMPTS = "给定一个输入问题,其中包含了需要执行的SQL语句,请提取问题中的SQL语句,并使用{schema}进行校验优化,生成符合相应语法{dialect}和schema的SQL语句。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
+
PROMPT_MAP = {
SIMPLE_PROMPTS: "Simple",
GENERAL_PROMPTS: "General",
diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index 67f8d868..085ce61a 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -153,6 +153,7 @@ class ViewModel(BaseModel):
db_name: str = None
db_tables: str = None
db_descriptions: str = None
+ db_nl2sql_prompt: str = None
# postprocessor
reranker_type: str = (
@@ -376,6 +377,8 @@ def from_app_config(config):
else:
view_model.db_descriptions = None
+ view_model.db_nl2sql_prompt = config["data_analysis"].get("nl2sql_prompt", None)
+
reranker_type = config["postprocessor"].get(
"reranker_type", "simple-weighted-reranker"
)
@@ -549,6 +552,7 @@ def to_app_config(self):
config["data_analysis"]["host"] = self.db_host
config["data_analysis"]["port"] = self.db_port
config["data_analysis"]["dbname"] = self.db_name
+ config["data_analysis"]["nl2sql_prompt"] = self.db_nl2sql_prompt
# string to list
if self.db_tables:
# 去掉首位空格和末尾逗号
@@ -815,5 +819,6 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["db_name"] = {"value": self.db_name}
settings["db_tables"] = {"value": self.db_tables}
settings["db_descriptions"] = {"value": self.db_descriptions}
+ settings["db_nl2sql_prompt"] = {"value": self.db_nl2sql_prompt}
return settings
diff --git a/src/pai_rag/config/settings.toml b/src/pai_rag/config/settings.toml
index 3e94e6c4..2088cb79 100644
--- a/src/pai_rag/config/settings.toml
+++ b/src/pai_rag/config/settings.toml
@@ -27,6 +27,7 @@ persist_path = "localdata/storage"
[rag.data_analysis]
analysis_type = "nl2pandas"
+nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
[rag.data_loader]
type = "local"
diff --git a/src/pai_rag/config/settings_multi_modal.toml b/src/pai_rag/config/settings_multi_modal.toml
index 9a44e809..3ef69135 100644
--- a/src/pai_rag/config/settings_multi_modal.toml
+++ b/src/pai_rag/config/settings_multi_modal.toml
@@ -27,6 +27,7 @@ persist_path = "localdata/storage"
[rag.data_analysis]
analysis_type = "nl2pandas"
+nl2sql_prompt = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
[rag.data_loader]
type = "Local" # [Local, Oss]
diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
index 3ae73bbb..52a55db4 100644
--- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
+++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
@@ -126,6 +126,7 @@ def retrieve_with_metadata(
query_bundle.query_str = self._limit_check(query_bundle.query_str)
logger.info(f"Limited SQL query: {query_bundle.query_str}")
+ # set timeout to 10s
# set timeout to 10s
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(10) # start
@@ -133,9 +134,11 @@ def retrieve_with_metadata(
raw_response_str, metadata = self._sql_database.run_sql(
query_bundle.query_str
)
- except TimeoutError:
- logger.info("SQL Query Timed Out (>10s)")
- raw_response_str = "SQL Query Timed Out (>10s)"
+ except (TimeoutError, NotImplementedError) as error:
+ logger.info("Invalid SQL or SQL Query Timed Out (>10s)")
+ raise error
+ # raw_response_str = "Invalid SQL or SQL Query Timed Out (>10s)"
+ # metadata = {"result": {e}, "col_keys": []}
finally:
signal.alarm(0) # cancel
@@ -262,6 +265,7 @@ def __init__(
sql_database, tables, context_query_kwargs, table_retriever
)
self._tables = tables
+ self._tables = tables
self._context_str_prefix = context_str_prefix
self._llm = llm or llm_from_settings_or_context(Settings, service_context)
self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
@@ -363,12 +367,24 @@ def retrieve_with_metadata(
metadata,
) = self._sql_retriever.retrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
+ retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":
+ # new_sql_query_str = self._sql_query_modification(sql_query_str)
+ # (
+ # retrieved_nodes,
+ # metadata,
+ # ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
+ # logger.info(
+ # f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
+ # )
+ # 如果生成的sql语句执行后无结果,待bad case补充
+ # if retrieved_nodes[0].metadata["query_output"] == "":
+
# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
@@ -391,6 +407,10 @@ def retrieve_with_metadata(
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
+ retrieved_nodes[0].metadata["invalid_flag"] = 1
+ retrieved_nodes[0].metadata[
+ "generated_query_code_instruction"
+ ] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
@@ -405,6 +425,10 @@ def retrieve_with_metadata(
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables
+ # add query_tables into metadata
+ query_tables = self._get_table_from_sql(self._tables, sql_query_str)
+ retrieved_nodes[0].metadata["query_tables"] = query_tables
+
return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
async def aretrieve_with_metadata(
@@ -442,6 +466,7 @@ async def aretrieve_with_metadata(
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
+ retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
@@ -463,17 +488,40 @@ async def aretrieve_with_metadata(
logger.info(f"async error info: {e}\n")
new_sql_query_str = self._sql_query_modification(sql_query_str)
- (
- retrieved_nodes,
- metadata,
- ) = await self._sql_retriever.aretrieve_with_metadata(new_sql_query_str)
- retrieved_nodes[0].metadata["invalid_flag"] = 1
- retrieved_nodes[0].metadata[
- "generated_query_code_instruction"
- ] = sql_query_str
- logger.info(
- f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
- )
+
+ # A table was found, so a rewritten sql_query was generated: execute it
+ if new_sql_query_str != sql_query_str:
+ (
+ retrieved_nodes,
+ metadata,
+ ) = await self._sql_retriever.aretrieve_with_metadata(
+ new_sql_query_str
+ )
+ retrieved_nodes[0].metadata["invalid_flag"] = 1
+ retrieved_nodes[0].metadata[
+ "generated_query_code_instruction"
+ ] = sql_query_str
+ logger.info(
+ f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
+ )
+ # No table was found, so the old and new sql_query are identical: skip _sql_retriever and build retrieved_nodes directly
+ else:
+ logger.info(f"[{new_sql_query_str}] is not even a SQL")
+ retrieved_nodes = [
+ NodeWithScore(
+ node=TextNode(
+ text=new_sql_query_str,
+ metadata={
+ "query_code_instruction": new_sql_query_str,
+ "generated_query_code_instruction": sql_query_str,
+ "query_output": "",
+ "invalid_flag": 1,
+ },
+ ),
+ score=1.0,
+ ),
+ ]
+ metadata = {}
# err_node = TextNode(text=f"Error: {e!s}")
# logger.info(f"async error_node info: {err_node}\n")
# retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)]
@@ -484,6 +532,10 @@ async def aretrieve_with_metadata(
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables
+ # add query_tables into metadata
+ query_tables = self._get_table_from_sql(self._tables, sql_query_str)
+ retrieved_nodes[0].metadata["query_tables"] = query_tables
+
return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
@@ -493,6 +545,13 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection.append(table)
return table_collection
+ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
+     """Return the tables from ``table_list`` that ``sql_query`` references.
+
+     Matching is case-insensitive and requires the name to appear as a whole
+     identifier, so ``order`` is not reported for a query that only touches
+     ``order_items``.
+
+     Args:
+         table_list: Candidate table names; ``None`` or empty yields ``[]``.
+         sql_query: SQL text to scan; ``None`` or empty yields ``[]``.
+
+     Returns:
+         The matching names, in their original order and spelling.
+     """
+     if not table_list or not sql_query:
+         return []
+     return [
+         table
+         for table in table_list
+         # Lookarounds rather than \b so that names starting or ending with
+         # a non-word character (e.g. schema-qualified or quoted) still match.
+         if re.search(rf"(?<!\w){re.escape(table)}(?!\w)", sql_query, re.IGNORECASE)
+     ]
+
def _sql_query_modification(self, sql_query_str):
table_pattern = r"FROM\s+(\w+)"
match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL)
@@ -504,6 +563,9 @@ def _sql_query_modification(self, sql_query_str):
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
logger.info("No table is matched")
+ # raise ValueError("No table is matched")
+ new_sql_query_str = sql_query_str
+ logger.info("No table is matched")
return new_sql_query_str
diff --git a/src/pai_rag/modules/dataanalysis/data_analysis.py b/src/pai_rag/modules/dataanalysis/data_analysis.py
index 1af3b669..c7e3142f 100644
--- a/src/pai_rag/modules/dataanalysis/data_analysis.py
+++ b/src/pai_rag/modules/dataanalysis/data_analysis.py
@@ -1,3 +1,4 @@
+import functools
import logging
import os
import glob
@@ -7,6 +8,7 @@
from sqlalchemy.engine import URL
from sqlalchemy.pool import QueuePool
from llama_index.core import SQLDatabase
+from llama_index.core.prompts import PromptTemplate
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
@@ -36,6 +38,11 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
llm = new_params["LlmModule"]
embed_model = new_params["EmbeddingModule"]
data_analysis_type = config.get("analysis_type", "nl2pandas")
+ nl2sql_prompt = config.get("nl2sql_prompt", None)
+ if nl2sql_prompt:
+ nl2sql_prompt = PromptTemplate(nl2sql_prompt)
+ else:
+ nl2sql_prompt = DEFAULT_TEXT_TO_SQL_TMPL
if data_analysis_type == "nl2pandas":
df = self.get_dataframe(config)
@@ -51,7 +58,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
sql_database, tables, table_descriptions = self.db_connection(config)
analysis_retriever = MyNLSQLRetriever(
sql_database=sql_database,
- text_to_sql_prompt=DEFAULT_TEXT_TO_SQL_TMPL,
+ text_to_sql_prompt=nl2sql_prompt,
tables=tables,
context_query_kwargs=table_descriptions,
sql_only=False,
@@ -59,6 +66,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
llm=llm,
)
logger.info("DataAnalysis NL2SQLRetriever used")
+ # logger.info(f"nl2sql prompt: {nl2sql_prompt}")
else:
raise ValueError(
@@ -76,7 +84,6 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
)
def db_connection(self, config):
- # get rds_db config
dialect = config.get("dialect", "sqlite")
user = config.get("user", "")
password = config.get("password", "")
@@ -87,6 +94,37 @@ def db_connection(self, config):
desired_tables = config.get("tables", [])
table_descriptions = config.get("descriptions", {})
+ return self.inspect_db_connection(
+ dialect=dialect,
+ user=user,
+ password=password,
+ host=host,
+ port=port,
+ path=path,
+ dbname=dbname,
+ desired_tables=tuple(desired_tables) if desired_tables else None,
+ table_descriptions=tuple(table_descriptions.items())
+ if table_descriptions
+ else None,
+ )
+
+ @functools.cache
+ def inspect_db_connection(
+ self,
+ dialect,
+ user,
+ password,
+ host,
+ port,
+ path,
+ dbname,
+ desired_tables,
+ table_descriptions,
+ ):
+ desired_tables = list(desired_tables) if desired_tables else None
+ table_descriptions = dict(table_descriptions) if table_descriptions else None
+
+ # get rds_db config
logger.info(f"desired_tables from ui input: {desired_tables}")
logger.info(f"table_descriptions from ui input: {table_descriptions}")