diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py index 52a55db4..883058bb 100644 --- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py +++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py @@ -126,7 +126,6 @@ def retrieve_with_metadata( query_bundle.query_str = self._limit_check(query_bundle.query_str) logger.info(f"Limited SQL query: {query_bundle.query_str}") - # set timeout to 10s # set timeout to 10s signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(10) # start @@ -367,22 +366,9 @@ def retrieve_with_metadata( metadata, ) = self._sql_retriever.retrieve_with_metadata(sql_query_str) retrieved_nodes[0].metadata["invalid_flag"] = 0 - retrieved_nodes[0].metadata["invalid_flag"] = 0 logger.info( f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" ) - # 如果生成的sql语句执行后无结果,待bad case补充 - # if retrieved_nodes[0].metadata["query_output"] == "": - - # new_sql_query_str = self._sql_query_modification(sql_query_str) - # ( - # retrieved_nodes, - # metadata, - # ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str) - # logger.info( - # f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" - # ) - # 如果生成的sql语句执行后无结果,待bad case补充 # if retrieved_nodes[0].metadata["query_output"] == "": # new_sql_query_str = self._sql_query_modification(sql_query_str) @@ -399,31 +385,38 @@ def retrieve_with_metadata( logger.info(f"async error info: {e}\n") new_sql_query_str = self._sql_query_modification(sql_query_str) - ( - retrieved_nodes, - metadata, - ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str) - retrieved_nodes[0].metadata["invalid_flag"] = 1 - retrieved_nodes[0].metadata[ - "generated_query_code_instruction" - ] = sql_query_str - retrieved_nodes[0].metadata["invalid_flag"] = 1 - retrieved_nodes[0].metadata[ - "generated_query_code_instruction" - ] = sql_query_str - logger.info( - f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" - ) - # err_node = TextNode(text=f"Error: {e!s}") - # logger.info(f"async error_node info: {err_node}\n") - # retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)] - # metadata = {} - # else: - # raise - # add query_tables into metadata - query_tables = self._get_table_from_sql(self._tables, sql_query_str) - retrieved_nodes[0].metadata["query_tables"] = query_tables + # 如果找到table,生成新的sql_query + if new_sql_query_str != sql_query_str: + ( + retrieved_nodes, + metadata, + ) = self._sql_retriever.retrieve_with_metadata(new_sql_query_str) + retrieved_nodes[0].metadata["invalid_flag"] = 1 + retrieved_nodes[0].metadata[ + "generated_query_code_instruction" + ] = sql_query_str + logger.info( + f"> Whole SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" + ) + # 没有找到table,新旧sql_query一样,不再通过_sql_retriever执行,直接retrieved_nodes + else: + logger.info(f"[{new_sql_query_str}] is not even a SQL") + retrieved_nodes = [ + NodeWithScore( + node=TextNode( + text=new_sql_query_str, + metadata={ + "query_code_instruction": new_sql_query_str, + "generated_query_code_instruction": sql_query_str, + "query_output": "", + "invalid_flag": 1, + }, + ), + score=1.0, + ), + ] + metadata = {} # add query_tables into metadata query_tables = self._get_table_from_sql(self._tables, sql_query_str) @@ -466,7 +459,6 @@ async def aretrieve_with_metadata( metadata, ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str) retrieved_nodes[0].metadata["invalid_flag"] = 0 - retrieved_nodes[0].metadata["invalid_flag"] = 0 logger.info( f"> SQL query result: {retrieved_nodes[0].metadata['query_output']}\n" ) @@ -522,15 +514,6 @@ async def aretrieve_with_metadata( ), ] metadata = {} - # err_node = TextNode(text=f"Error: {e!s}") - # logger.info(f"async error_node info: {err_node}\n") - # retrieved_nodes = [NodeWithScore(node=err_node, score=1.0)] - # metadata = {} - # else: - # raise - # add query_tables into metadata - query_tables = self._get_table_from_sql(self._tables, sql_query_str) - retrieved_nodes[0].metadata["query_tables"] = query_tables # add query_tables into metadata query_tables = self._get_table_from_sql(self._tables, sql_query_str) @@ -545,13 +528,6 @@ def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: table_collection.append(table) return table_collection - def _get_table_from_sql(self, table_list: list, sql_query: str) -> list: - table_collection = list() - for table in table_list: - if table.lower() in sql_query.lower(): - table_collection.append(table) - return table_collection - def _sql_query_modification(self, sql_query_str): table_pattern = r"FROM\s+(\w+)" match = re.search(table_pattern, sql_query_str, re.IGNORECASE | re.DOTALL) @@ -563,9 +539,6 @@ def _sql_query_modification(self, sql_query_str): # raise ValueError("No table is matched") new_sql_query_str = sql_query_str logger.info("No table is matched") - # raise ValueError("No table is matched") - new_sql_query_str = sql_query_str - logger.info("No table is matched") return new_sql_query_str