Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tiny sync #219

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 31 additions & 58 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def retrieve_with_metadata(
query_bundle.query_str = self._limit_check(query_bundle.query_str)
logger.info(f"Limited SQL query: {query_bundle.query_str}")

# set timeout to 10s
# set timeout to 10s
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(10) # start
Expand Down Expand Up @@ -367,22 +366,9 @@ def retrieve_with_metadata(
metadata,
) = self._sql_retriever.retrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":

# new_sql_query_str = self._sql_query_modification(sql_query_str)
# (
# retrieved_nodes,
# metadata,
# ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
# logger.info(
# f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
# )
# 如果生成的sql语句执行后无结果,待bad case补充
# if retrieved_nodes[0].metadata["query_output"] == "":

# new_sql_query_str = self._sql_query_modification(sql_query_str)
Expand All @@ -399,31 +385,38 @@ def retrieve_with_metadata(
logger.info(f"async error info: {e}\n")

new_sql_query_str = self._sql_query_modification(sql_query_str)
(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# err_node = TextNode(text=f"Error: {e!s}")
# logger.info(f"async error_node info: {err_node}\n")
# retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)]
# metadata = {}
# else:
# raise

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables
# 如果找到table,生成新的sql_query
if new_sql_query_str != sql_query_str:
(
retrieved_nodes,
metadata,
) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 1
retrieved_nodes[0].metadata[
"generated_query_code_instruction"
] = sql_query_str
logger.info(
f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
# 没有找到table,新旧sql_query一样,不再通过_sql_retriever执行,直接retrieved_nodes
else:
logger.info(f"[{new_sql_query_str}] is not even a SQL")
retrieved_nodes = [
NodeWithScore(
node=TextNode(
text=new_sql_query_str,
metadata={
"query_code_instruction": new_sql_query_str,
"generated_query_code_instruction": sql_query_str,
"query_output": "",
"invalid_flag": 1,
},
),
score=1.0,
),
]
metadata = {}

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
Expand Down Expand Up @@ -466,7 +459,6 @@ async def aretrieve_with_metadata(
metadata,
) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
retrieved_nodes[0].metadata["invalid_flag"] = 0
retrieved_nodes[0].metadata["invalid_flag"] = 0
logger.info(
f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n"
)
Expand Down Expand Up @@ -522,15 +514,6 @@ async def aretrieve_with_metadata(
),
]
metadata = {}
# err_node = TextNode(text=f"Error: {e!s}")
# logger.info(f"async error_node info: {err_node}\n")
# retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)]
# metadata = {}
# else:
# raise
# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
retrieved_nodes[0].metadata["query_tables"] = query_tables

# add query_tables into metadata
query_tables = self._get_table_from_sql(self._tables, sql_query_str)
Expand All @@ -545,13 +528,6 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection.append(table)
return table_collection

def _get_table_from_sql(self, table_list: list, sql_query: str) -> list:
table_collection = list()
for table in table_list:
if table.lower() in sql_query.lower():
table_collection.append(table)
return table_collection

def _sql_query_modification(self, sql_query_str):
table_pattern = r"FROM\s+(\w+)"
match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL)
Expand All @@ -563,9 +539,6 @@ def _sql_query_modification(self, sql_query_str):
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
logger.info("No table is matched")
# raise ValueError("No table is matched")
new_sql_query_str = sql_query_str
logger.info("No table is matched")

return new_sql_query_str

Expand Down
Loading