Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add RetrySqlQueryCreatorTool for handling failed SQL query generation #15

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""Toolkit for interacting with a SQL database."""
from typing import List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseToolkit

from langchain_community.utilities.sql_database import SQLDatabase

from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool
from langchain_core.pydantic_v1 import Field

from langchain_community.tools import BaseTool
from langchain_community.tools.sql_coder.tool import (
QuerySparkSQLDataBaseTool,
SqlQueryCreatorTool,
RetrySqlQueryCreatorTool
)

class SQLCoderToolkit(BaseToolkit):

Check failure on line 19 in libs/community/langchain_community/agent_toolkits/sqlcoder/toolkit.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (I001)

langchain_community/agent_toolkits/sqlcoder/toolkit.py:2:1: I001 Import block is un-sorted or un-formatted
"""Toolkit for interacting with SQL databases."""

db: SQLDatabase = Field(exclude=True)
Expand Down Expand Up @@ -54,6 +55,7 @@
db=self.db, description=query_sql_database_tool_description
),
QuerySQLCheckerTool(db=self.db, llm=self.llm),
RetrySqlQueryCreatorTool(sqlcreatorllm=self.sqlcreatorllm),
SqlQueryCreatorTool(
sqlcreatorllm=self.sqlcreatorllm ,
db=self.db,
Expand Down
141 changes: 104 additions & 37 deletions libs/community/langchain_community/tools/sql_coder/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from langchain_core.tools import StateTool
import re

ERROR = ""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider removing the unused ERROR variable.

The variable ERROR is defined but never used in the code. If it's not needed, it would be better to remove it to keep the code clean.

Suggested change
ERROR = ""
# Consider removing the unused ERROR variable.
# The variable `ERROR` is defined but never used in the code.
# If it's not needed, it would be better to remove it to keep the code clean.

class BaseSQLDatabaseTool(BaseModel):
"""Base tool for interacting with a SQL database."""

Expand Down Expand Up @@ -43,7 +44,7 @@ class Config(StateTool.Config):
description: str = """
Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned.
If an error is returned, re-run the sql_db_query_creator tool to get the correct query.
If an error is returned, re-run the retry_sql_db_query_creator tool to get the correct query.
"""

def __init__(__pydantic_self__, **data: Any) -> None:
Expand All @@ -65,6 +66,7 @@ def _run(
)
executable_query = executable_query.strip('\"')
executable_query = re.sub('\\n```', '',executable_query)
self.db.run_no_throw(executable_query)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Duplicate call to self.db.run_no_throw(executable_query).

The method self.db.run_no_throw(executable_query) is called twice consecutively. This seems redundant and could be removed.

return self.db.run_no_throw(executable_query)

async def _arun(
Expand All @@ -75,14 +77,98 @@ async def _arun(
raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async")

def _extract_sql_query(self):
for value in self.state:
for value in reversed(self.state):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (bug_risk): Reversing the state list might have unintended consequences.

Reversing the state list could lead to unexpected behavior if the order of states is important. Ensure that this change is intentional and won't cause issues.

for key, input_string in value.items():
if "sql_db_query_creator" in key:
if "tool='retry_sql_db_query_creator'" in key:
return input_string
elif "tool='sql_db_query_creator'" in key:
Comment on lines +84 to +86
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): We've found these issues:

Suggested change
if "tool='retry_sql_db_query_creator'" in key:
return input_string
elif "tool='sql_db_query_creator'" in key:
if (
"tool='retry_sql_db_query_creator'" in key
or "tool='sql_db_query_creator'" in key
):

return input_string
return None



class RetrySqlQueryCreatorTool(StateTool):
"""Tool for re-creating SQL query.Use this to retry creation of sql query."""

name = "retry_sql_db_query_creator"
description = """
This is a tool used to re-create sql query for user input based on the incorrect query generated and error returned from sql_db_query tool.
Input to this tool is user prompt, incorrect sql query and error message
Output is a sql query
After running this tool, you can run sql_db_query tool to get the result
"""
sqlcreatorllm: BaseLanguageModel = Field(exclude=True)


class Config(StateTool.Config):
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
extra = Extra.allow

def __init__(__pydantic_self__, **data: Any) -> None:
"""Initialize the tool by forwarding every keyword argument to the parent initializer."""
# NOTE(review): the receiver is named `__pydantic_self__` instead of `self`,
# presumably so that a key called `self` may appear in `data` — confirm.
# No additional setup is performed here; `data` is passed through unchanged.
super().__init__(**data)

def _run(
    self,
    user_input: str,
    run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
    """Produce a replacement SQL query for the one that previously failed.

    Args:
        user_input: Text handed unchanged to ``_create_sql_query``.
        run_manager: Accepted as part of the tool signature; not used here.

    Returns:
        The value returned by ``_create_sql_query`` for ``user_input``.
    """
    regenerated_query = self._create_sql_query(user_input)
    return regenerated_query

async def _arun(
    self,
    table_name: str,
    run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
    """Reject asynchronous execution; this tool only runs synchronously.

    Args:
        table_name: Unused; present only as part of the async signature.
        run_manager: Unused.

    Raises:
        NotImplementedError: Always, because no async implementation exists.
    """
    # The message names this tool's own class so that the failure is
    # attributed to the retry tool rather than to SqlQueryCreatorTool.
    raise NotImplementedError("RetrySqlQueryCreatorTool does not support async")

def _create_sql_query(self,user_input):

sql_query = self._extract_sql_query()
error_message = self._extract_error_message()
if sql_query is None:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Consider logging when sql_query is None.

It might be useful to log a message when sql_query is None to help with debugging and understanding why the tool is not meant to be run directly.

Suggested change
if sql_query is None:
if sql_query is None:
logging.warning("SQL query is None. This tool is not meant to be run directly.")
return "This tool is not meant to be run directly. Start with a SQLQueryCreatorTool"

return "This tool is not meant to be run directly. Start with a SQLQueryCreatorTool"

prompt_input = PromptTemplate(
input_variables=["user_input","sql_query", "error_message"],
template=SQL_QUERY_CREATOR_RETRY
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"sql_query": sql_query,
"error_message": error_message,
"user_input": user_input
}
)
)
sql_query = sql_query.replace("```","")
sql_query = sql_query.replace("sql","")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Removing 'sql' from the query might cause issues.

The line sql_query = sql_query.replace("sql","") removes all occurrences of 'sql' from the query. This might lead to incorrect SQL queries if 'sql' is part of a table or column name.


return sql_query

def _extract_sql_query(self):
for value in reversed(self.state):
for key, input_string in value.items():
if "tool='retry_sql_db_query_creator'" in key:
return input_string
elif "tool='sql_db_query_creator'" in key:
Comment on lines +160 to +162
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): We've found these issues:

Suggested change
if "tool='retry_sql_db_query_creator'" in key:
return input_string
elif "tool='sql_db_query_creator'" in key:
if (
"tool='retry_sql_db_query_creator'" in key
or "tool='sql_db_query_creator'" in key
):

return input_string
return None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Consider raising an exception instead of returning None.

Returning None might lead to silent failures. Consider raising an exception to make it clear that an error has occurred.

Suggested change
return None
raise ValueError("No valid key found in input string")


def _extract_error_message(self):
for value in reversed(self.state):
for key, input_string in value.items():
if "tool='sql_db_query'" in key:
if "Error" in input_string:
return input_string
Comment on lines +169 to +171
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (code-quality): Merge nested if conditions (merge-nested-ifs)

Suggested change
if "tool='sql_db_query'" in key:
if "Error" in input_string:
return input_string
if "tool='sql_db_query'" in key and "Error" in input_string:
return input_string


Explanation: Too much nesting can make code difficult to understand, and this is especially
true in Python, where there are no brackets to help out with the delineation of
different nesting levels.

Reading deeply nested code is confusing, since you have to keep track of which
conditions relate to which levels. We therefore strive to reduce nesting where
possible, and the situation where two if conditions can be combined using
and is an easy win.

return None

class SqlQueryCreatorTool(StateTool):
"""Tool for creating SQL query.Use this to create sql query."""

Expand Down Expand Up @@ -147,43 +233,24 @@ def _parse_data_model_context(self):
def _create_sql_query(self,user_input):

few_shot_examples = self._parse_few_shot_examples()
sql_query = self._extract_sql_query()
db_schema = self._parse_db_schema()
data_model_context = self._parse_data_model_context()
if sql_query is None:
prompt_input = PromptTemplate(
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"],
template=self.SQL_QUERY_CREATOR_TEMPLATE,
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"db_schema": db_schema,
"user_input": user_input,
"few_shot_examples": few_shot_examples,
"data_model_context": data_model_context
}
)
)
else:
prompt_input = PromptTemplate(
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"],
template=SQL_QUERY_CREATOR_RETRY
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"db_schema": db_schema,
"user_input": user_input,
"few_shot_examples": few_shot_examples,
"data_model_context": data_model_context
}
)
prompt_input = PromptTemplate(
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"],
template=self.SQL_QUERY_CREATOR_TEMPLATE,
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"db_schema": db_schema,
"user_input": user_input,
"few_shot_examples": few_shot_examples,
"data_model_context": data_model_context
}
)
)
sql_query = sql_query.replace("```","")
sql_query = sql_query.replace("sql","")

Expand Down
16 changes: 14 additions & 2 deletions libs/langchain/langchain/tools/sqlcoder/prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@


SQL_QUERY_CREATOR_RETRY = """
You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query.
"""
Your task is convert an incorrect query resulting from user question to a correct query which is databricks sql compatible.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (typo): Typo in the prompt template.

The sentence should be 'Your task is to convert an incorrect query resulting from a user question to a correct query which is Databricks SQL compatible.'

Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Task:
Generate a correct SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION].
The query you will correct is: {sql_query}
The error message is: {error_message}

### Response:
Based on your instructions, here is the SQL query I have generated
[SQL]"""

SQL_QUERY_CREATOR_7b = """
### Instructions:
Expand Down
Loading