Skip to content

Commit

Permalink
chore(wren-ai-service): refine sql generation (#784)
Browse files Browse the repository at this point in the history
* add language params

* allow producing candidates while generating status

* fix bug

* add language

* add current time in prompt

* refine names

* fix

* update

* fix bug
  • Loading branch information
cyyeh authored Oct 22, 2024
1 parent 0cf9f70 commit 1c83788
Show file tree
Hide file tree
Showing 17 changed files with 159 additions and 87 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List

Expand Down Expand Up @@ -117,6 +118,7 @@
Previous SQL Summary: {{ history.summary }}
Previous SQL Query: {{ history.sql }}
User's Follow-up Question: {{ query }}
Current Time: {{ current_time }}
{% if instructions %}
Instructions: {{ instructions }}
Expand Down Expand Up @@ -147,6 +149,7 @@ def prompt(
history=history,
alert=alert,
instructions=construct_instructions(configurations),
current_time=datetime.now(),
)


Expand Down Expand Up @@ -183,7 +186,7 @@ class GenerationResults(BaseModel):
results: list[SQLResult]


FOLLOWUP_GENERATION_MODEL_KWARGS = {
FOLLOWUP_SQL_GENERATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -204,7 +207,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=FOLLOWUP_GENERATION_MODEL_KWARGS,
generation_kwargs=FOLLOWUP_SQL_GENERATION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=text_to_sql_with_followup_user_prompt_template
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class RelationshipResult(BaseModel):
relationships: list[ModelRelationship]


RELATIONSHIP_MODEL_KWARGS = {
RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand Down Expand Up @@ -136,7 +136,8 @@ def __init__(
self._components = {
"prompt_builder": PromptBuilder(template=user_prompt_template),
"generator": llm_provider.get_generator(
system_prompt=system_prompt, generation_kwargs=RELATIONSHIP_MODEL_KWARGS
system_prompt=system_prompt,
generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS,
),
"engine": engine,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SemanticResult(BaseModel):
models: list[SemanticModel]


SEMANTIC_MODEL_KWARGS = {
SEMANTICS_DESCRIPTION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand Down Expand Up @@ -182,7 +182,8 @@ def __init__(self, llm_provider: LLMProvider, **_):
self._components = {
"prompt_builder": PromptBuilder(template=user_prompt_template),
"generator": llm_provider.get_generator(
system_prompt=system_prompt, generation_kwargs=SEMANTIC_MODEL_KWARGS
system_prompt=system_prompt,
generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS,
),
}
self._final = "normalize"
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class AnswerResults(BaseModel):
answer: str


ANSWER_MODEL_KWARGS = {
SQL_ANSWER_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -190,7 +190,7 @@ def __init__(
),
"generator": llm_provider.get_generator(
system_prompt=sql_to_answer_system_prompt,
generation_kwargs=ANSWER_MODEL_KWARGS,
generation_kwargs=SQL_ANSWER_MODEL_KWARGS,
),
"post_processor": SQLAnswerGenerationPostProcessor(),
}
Expand Down
30 changes: 19 additions & 11 deletions wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
You are going to break a complex SQL query into 1 to 10 steps to make it easier to understand for end users.
Each step should have a SQL query part, a summary explaining the purpose of that query, and a CTE name to link the queries.
Also, you need to give a short description describing the purpose of the original SQL query.
Description and summary in each step MUST BE in the same language as the user's question.
Description and summary in each step MUST BE in the same language as the user specified.
### SQL QUERY BREAKDOWN INSTRUCTIONS ###
- YOU MUST BREAK DOWN any SQL query into small steps if there is JOIN operations or sub-queries.
Expand All @@ -39,7 +39,7 @@
- MUST USE alias from the original SQL query.
### SUMMARY AND DESCRIPTION INSTRUCTIONS ###
- SUMMARY AND DESCRIPTION MUST USE the same language as the user's question
- SUMMARY AND DESCRIPTION MUST BE in the same language as the user specified.
- SUMMARY AND DESCRIPTION MUST BE human-readable and easy to understand.
- SUMMARY AND DESCRIPTION MUST BE concise and to the point.
Expand Down Expand Up @@ -85,16 +85,16 @@
The final answer must be a valid JSON format as following:
{
"description": <SHORT_SQL_QUERY_DESCRIPTION_USING_SAME_LANGUAGE_USER_QUESTION_USING>,
"description": <SHORT_SQL_QUERY_DESCRIPTION_STRING>,
"steps: [
{
"sql": <SQL_QUERY_STRING_1>,
"summary": <SUMMARY_STRING_USING_SAME_LANGUAGE_USER_QUESTION_USING_1>,
"summary": <SUMMARY_STRING_1>,
"cte_name": <CTE_NAME_STRING_1>
},
{
"sql": <SQL_QUERY_STRING_2>,
"summary": <SUMMARY_STRING_USING_SAME_LANGUAGE_USER_QUESTION_USING_2>,
"summary": <SUMMARY_STRING_2>,
"cte_name": <CTE_NAME_STRING_2>
},
...
Expand All @@ -106,6 +106,7 @@
### INPUT ###
User's Question: {{ query }}
SQL query: {{ sql }}
Language: {{ language }}
Let's think step by step.
"""
Expand All @@ -114,10 +115,11 @@
## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(query: str, sql: str, prompt_builder: PromptBuilder) -> dict:
def prompt(query: str, sql: str, language: str, prompt_builder: PromptBuilder) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"sql: {sql}")
return prompt_builder.run(query=query, sql=sql)
logger.debug(f"language: {language}")
return prompt_builder.run(query=query, sql=sql, language=language)


@async_timer
Expand Down Expand Up @@ -154,7 +156,7 @@ class BreakdownResults(BaseModel):
steps: list[StepResult]


BREAKDOWN_MODEL_KWARGS = {
SQL_BREAKDOWN_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -175,7 +177,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_breakdown_system_prompt,
generation_kwargs=BREAKDOWN_MODEL_KWARGS,
generation_kwargs=SQL_BREAKDOWN_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_breakdown_user_prompt_template
Expand All @@ -187,7 +189,9 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(self, query: str, sql: str, project_id: str | None = None) -> None:
def visualize(
self, query: str, sql: str, language: str, project_id: str | None = None
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)
Expand All @@ -199,6 +203,7 @@ def visualize(self, query: str, sql: str, project_id: str | None = None) -> None
"query": query,
"sql": sql,
"project_id": project_id,
"language": language,
**self._components,
},
show_legend=True,
Expand All @@ -207,14 +212,17 @@ def visualize(self, query: str, sql: str, project_id: str | None = None) -> None

@async_timer
@observe(name="SQL Breakdown Generation")
async def run(self, query: str, sql: str, project_id: str | None = None) -> dict:
async def run(
self, query: str, sql: str, language: str, project_id: str | None = None
) -> dict:
logger.info("SQL Breakdown Generation pipeline is running...")
return await self._pipe.execute(
["post_process"],
inputs={
"query": query,
"sql": sql,
"project_id": project_id,
"language": language,
**self._components,
},
)
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class CorrectedResults(BaseModel):
results: list[CorrectedSQLResult]


CORRECTION_MODEL_KWARGS = {
SQL_CORRECTION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -136,7 +136,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=CORRECTION_MODEL_KWARGS,
generation_kwargs=SQL_CORRECTION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_correction_user_prompt_template
Expand Down
10 changes: 7 additions & 3 deletions wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List

Expand Down Expand Up @@ -47,6 +48,7 @@
{% endfor %}
User's input: {{query}}
Current Time: {{ current_time }}
"""


Expand All @@ -62,7 +64,9 @@ def prompt(
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
logger.debug(f"history: {history}")
return prompt_builder.run(query=query, documents=documents, sql=history.sql)
return prompt_builder.run(
query=query, documents=documents, sql=history.sql, current_time=datetime.now()
)


@async_timer
Expand Down Expand Up @@ -98,7 +102,7 @@ class ExpansionResults(BaseModel):
results: list[ExpandedResult]


EXPANSION_MODEL_KWARGS = {
SQL_EXPANSION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -119,7 +123,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_expansion_system_prompt,
generation_kwargs=EXPANSION_MODEL_KWARGS,
generation_kwargs=SQL_EXPANSION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_expansion_user_prompt_template
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ class ExplanationResults(BaseModel):
results: Optional[AggregatedItemsResult]


EXPLANATION_MODEL_KWARGS = {
SQL_EXPLANATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -631,7 +631,7 @@ def __init__(
),
"generator": llm_provider.get_generator(
system_prompt=sql_explanation_system_prompt,
generation_kwargs=EXPLANATION_MODEL_KWARGS,
generation_kwargs=SQL_EXPLANATION_MODEL_KWARGS,
),
"post_processor": SQLExplanationGenerationPostProcessor(),
}
Expand Down
7 changes: 5 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

Expand Down Expand Up @@ -101,6 +102,7 @@
### QUESTION ###
User's Question: {{ query }}
Current Time: {{ current_time }}
Let's think step by step.
"""
Expand Down Expand Up @@ -133,6 +135,7 @@ def prompt(
alert=alert,
instructions=construct_instructions(configurations),
samples=samples,
current_time=datetime.now(),
)


Expand Down Expand Up @@ -165,7 +168,7 @@ class GenerationResults(BaseModel):
results: list[SQLResult]


GENERATION_MODEL_KWARGS = {
SQL_GENERATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -186,7 +189,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=GENERATION_MODEL_KWARGS,
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_generation_user_prompt_template
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_regeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class RegenerationResults(BaseModel):
steps: list[StepResult]


REGENERATION_MODEL_KWARGS = {
SQL_REGENERATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -189,7 +189,7 @@ def __init__(
),
"generator": llm_provider.get_generator(
system_prompt=sql_regeneration_system_prompt,
generation_kwargs=REGENERATION_MODEL_KWARGS,
generation_kwargs=SQL_REGENERATION_MODEL_KWARGS,
),
"post_processor": SQLBreakdownGenPostProcessor(engine=engine),
}
Expand Down
Loading

0 comments on commit 1c83788

Please sign in to comment.