Skip to content

Commit

Permalink
add reference
Browse files Browse the repository at this point in the history
  • Loading branch information
aero-xi committed Sep 11, 2024
1 parent 8deae9b commit 1ad1766
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 29 deletions.
28 changes: 26 additions & 2 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _format_rag_response(
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
ref_table = doc["metadata"].get("query_tables", None)
invalid_flag = doc["metadata"].get("invalid_flag", 0)
file_url = doc["metadata"].get("file_url", None)
media_url = doc.get("metadata", {}).get("image_url", None)
if media_url and doc["text"] == "":
Expand Down Expand Up @@ -143,8 +144,31 @@ def _format_rag_response(
"""
elif ref_table:
ref_table_format = ", ".join([i for i in ref_table])
formatted_table_name = f"数据库中相关表名包括: {ref_table_format}"
content = f"""{formatted_table_name}"""
formatted_table_name = f"查询数据库中相关表名包括: <b>{ref_table_format}</b>"

if invalid_flag == 0:
run_flag = " ✓ "
ref_sql = doc["metadata"].get("query_code_instruction", None)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
# content = f"""{formatted_table_name} \n\n {formatted_sql_query}"""
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> """
"""\n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query} \n sql查询是否有效:</span>"""
f"""<span style="color:green; font-size: 14px;">{run_flag}"""
)
else:
run_flag = " ✗ "
ref_sql = doc["metadata"].get(
"generated_query_code_instruction", None
)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> """
"""\n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query} \n sql查询是否有效:</span>"""
f"""<span style="color:red; font-size: 14px;">{run_flag}"""
)
else:
content = ""
content_list.append(content)
Expand Down
62 changes: 35 additions & 27 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def retrieve_with_metadata(
query_bundle.query_str = self._limit_check(query_bundle.query_str)
logger.info(f"Limited SQL query: {query_bundle.query_str}")

# set timeout to 5s
# set timeout to 10s
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(10) # start
try:
Expand Down Expand Up @@ -362,18 +362,21 @@ def retrieve_with_metadata(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
if retrieved_nodes[0].metadata["query_output"] == []:
new_sql_query_str = self._sql_query_modification(sql_query_str)
(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":

# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
# metadata,
# ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
# logger.info(
# f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
# )
except BaseException as e:
# if handle_sql_errors is True, then return error message
if self._handle_sql_errors:
Expand All @@ -384,6 +387,10 @@ def retrieve_with_metadata(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
Expand All @@ -395,9 +402,7 @@ def retrieve_with_metadata(
# raise

# add query_tables into metadata
query_tables = self._get_table_from_sql(
self._tables, retrieved_nodes[0].metadata["query_code_instruction"]
)
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
Expand Down Expand Up @@ -436,20 +441,21 @@ async def aretrieve_with_metadata(
retrieved_nodes,
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
if retrieved_nodes[0].metadata["query_output"] == []:
new_sql_query_str = self._sql_query_modification(sql_query_str)
(
retrieved_nodes,
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(
new_sql_query_str
)
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# if retrieved_nodes[0].metadata["query_output"] == "":
# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
# metadata,
# ) = await self._sql_retriever.aretrieve_with_metadata(
# new_sql_query_str
# )
# logger.info(
# f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
# )

except BaseException as e:
# if handle_sql_errors is True, then return error message
Expand All @@ -461,6 +467,10 @@ async def aretrieve_with_metadata(
retrieved_nodes,
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
Expand All @@ -471,9 +481,7 @@ async def aretrieve_with_metadata(
# else:
# raise
# add query_tables into metadata
query_tables = self._get_table_from_sql(
self._tables, retrieved_nodes[0].metadata["query_code_instruction"]
)
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}
Expand Down

0 comments on commit 1ad1766

Please sign in to comment.