Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

db分析增加reference #212

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 29 additions & 41 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def _format_rag_response(
self.session_id = session_id
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
ref_table = doc["metadata"].get("query_tables", None)
invalid_flag = doc["metadata"].get("invalid_flag", 0)
file_url = doc["metadata"].get("file_url", None)
media_url = doc.get("metadata", {}).get("image_url", None)
if media_url and doc["text"] == "":
Expand Down Expand Up @@ -140,6 +142,33 @@ def _format_rag_response(
</span>
<br>
"""
elif ref_table:
ref_table_format = ", ".join([i for i in ref_table])
formatted_table_name = f"查询数据库中相关表名包括: <b>{ref_table_format}</b>"

if invalid_flag == 0:
run_flag = " ✓ "
ref_sql = doc["metadata"].get("query_code_instruction", None)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
# content = f"""{formatted_table_name} \n\n {formatted_sql_query}"""
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> """
"""\n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query} \n sql查询是否有效:</span>"""
f"""<span style="color:green; font-size: 14px;">{run_flag}"""
)
else:
run_flag = " ✗ "
ref_sql = doc["metadata"].get(
"generated_query_code_instruction", None
)
formatted_sql_query = f"生成的sql语句为:<b>{ref_sql}</b>"
content = (
f"""<span style="color:grey; font-size: 14px;">{formatted_table_name}</span> """
"""\n"""
f"""<span style="color:grey; font-size: 14px;">{formatted_sql_query} \n sql查询是否有效:</span>"""
f"""<span style="color:red; font-size: 14px;">{run_flag}"""
)
else:
content = ""
content_list.append(content)
Expand Down Expand Up @@ -441,46 +470,5 @@ def evaluate_for_response_stage(self):
raise RagApiError(code=r.status_code, msg=response.message)
print("evaluate_for_response_stage response", response)

def _format_data_analysis_rag_response(
self, question, response, session_id: str = None, stream: bool = False
):
if stream:
text = response["delta"]
else:
text = response["answer"]

docs = response.get("docs", []) or []
is_finished = response.get("is_finished", True)

referenced_docs = ""
if is_finished and len(docs) == 0 and not text:
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
return response
elif is_finished:
seen_filenames = set()
file_idx = 1
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
if filename and filename not in seen_filenames:
seen_filenames.add(filename)
formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
title = doc["metadata"].get("title")
if not title:
referenced_docs += f'[{file_idx}]: {formatted_file_name} Score:{doc["score"]} \n'
else:
referenced_docs += f'[{file_idx}]: [{title}]({formatted_file_name}) Score:{doc["score"]} \n'

file_idx += 1
formatted_answer = ""
if session_id:
new_query = response["new_query"]
formatted_answer += f"**Query Transformation**: {new_query} \n\n"
formatted_answer += f"**Answer**: {text} \n\n"
if referenced_docs:
formatted_answer += f"**Reference**:\n {referenced_docs}"

response["result"] = formatted_answer
return response


rag_client = RagWebClient()
74 changes: 52 additions & 22 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def retrieve_with_metadata(
query_bundle.query_str = self._limit_check(query_bundle.query_str)
logger.info(f"Limited SQL query: {query_bundle.query_str}")

# set timeout to 5s
# set timeout to 10s
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(10) # start
try:
Expand Down Expand Up @@ -261,6 +261,7 @@ def __init__(
self._get_tables = self._load_get_tables_fn(
sql_database, tables, context_query_kwargs, table_retriever
)
self._tables = tables
self._context_str_prefix = context_str_prefix
self._llm = llm or llm_from_settings_or_context(Settings, service_context)
self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
Expand Down Expand Up @@ -361,18 +362,21 @@ def retrieve_with_metadata(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
if retrieved_nodes[0].metadata["query_output"] == []:
new_sql_query_str = self._sql_query_modification(sql_query_str)
(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":

# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
# metadata,
# ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
# logger.info(
# f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
# )
except BaseException as e:
# if handle_sql_errors is True, then return error message
if self._handle_sql_errors:
Expand All @@ -383,6 +387,10 @@ def retrieve_with_metadata(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
Expand All @@ -393,6 +401,10 @@ def retrieve_with_metadata(
# else:
# raise

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

async def aretrieve_with_metadata(
Expand Down Expand Up @@ -429,20 +441,21 @@ async def aretrieve_with_metadata(
retrieved_nodes,
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
if retrieved_nodes[0].metadata["query_output"] == []:
new_sql_query_str = self._sql_query_modification(sql_query_str)
(
retrieved_nodes,
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(
new_sql_query_str
)
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# if retrieved_nodes[0].metadata["query_output"] == "":
# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
# metadata,
# ) = await self._sql_retriever.aretrieve_with_metadata(
# new_sql_query_str
# )
# logger.info(
# f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
# )

except BaseException as e:
# if handle_sql_errors is True, then return error message
Expand All @@ -454,6 +467,10 @@ async def aretrieve_with_metadata(
retrieved_nodes,
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
Expand All @@ -463,8 +480,19 @@ async def aretrieve_with_metadata(
# metadata = {}
# else:
# raise
# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection = list()
for table in table_list:
if table.lower() in sql_query.lower():
table_collection.append(table)
return table_collection

def _sql_query_modification(self, sql_query_str):
table_pattern = r"FROM\s+(\w+)"
match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL)
Expand All @@ -473,7 +501,9 @@ def _sql_query_modification(self, sql_query_str):
new_sql_query_str = f"SELECT * FROM {first_table}"
logger.info(f"use the whole table {first_table} instead if possible")
else:
raise ValueError("No table is matched")
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
logger.info("No table is matched")

return new_sql_query_str

Expand Down
Loading