Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): refine sql generation #784

Merged
merged 10 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List

Expand Down Expand Up @@ -117,6 +118,7 @@
Previous SQL Summary: {{ history.summary }}
Previous SQL Query: {{ history.sql }}
User's Follow-up Question: {{ query }}
Current Time: {{ current_time }}

{% if instructions %}
Instructions: {{ instructions }}
Expand Down Expand Up @@ -147,6 +149,7 @@ def prompt(
history=history,
alert=alert,
instructions=construct_instructions(configurations),
current_time=datetime.now(),
)


Expand Down Expand Up @@ -183,7 +186,7 @@ class GenerationResults(BaseModel):
results: list[SQLResult]


FOLLOWUP_GENERATION_MODEL_KWARGS = {
FOLLOWUP_SQL_GENERATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -204,7 +207,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=FOLLOWUP_GENERATION_MODEL_KWARGS,
generation_kwargs=FOLLOWUP_SQL_GENERATION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=text_to_sql_with_followup_user_prompt_template
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class RelationshipResult(BaseModel):
relationships: list[ModelRelationship]


RELATIONSHIP_MODEL_KWARGS = {
RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand Down Expand Up @@ -136,7 +136,8 @@ def __init__(
self._components = {
"prompt_builder": PromptBuilder(template=user_prompt_template),
"generator": llm_provider.get_generator(
system_prompt=system_prompt, generation_kwargs=RELATIONSHIP_MODEL_KWARGS
system_prompt=system_prompt,
generation_kwargs=RELATIONSHIP_RECOMMENDATION_MODEL_KWARGS,
),
"engine": engine,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SemanticResult(BaseModel):
models: list[SemanticModel]


SEMANTIC_MODEL_KWARGS = {
SEMANTICS_DESCRIPTION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand Down Expand Up @@ -182,7 +182,8 @@ def __init__(self, llm_provider: LLMProvider, **_):
self._components = {
"prompt_builder": PromptBuilder(template=user_prompt_template),
"generator": llm_provider.get_generator(
system_prompt=system_prompt, generation_kwargs=SEMANTIC_MODEL_KWARGS
system_prompt=system_prompt,
generation_kwargs=SEMANTICS_DESCRIPTION_MODEL_KWARGS,
),
}
self._final = "normalize"
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class AnswerResults(BaseModel):
answer: str


ANSWER_MODEL_KWARGS = {
SQL_ANSWER_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -190,7 +190,7 @@ def __init__(
),
"generator": llm_provider.get_generator(
system_prompt=sql_to_answer_system_prompt,
generation_kwargs=ANSWER_MODEL_KWARGS,
generation_kwargs=SQL_ANSWER_MODEL_KWARGS,
),
"post_processor": SQLAnswerGenerationPostProcessor(),
}
Expand Down
30 changes: 19 additions & 11 deletions wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
You are going to break a complex SQL query into 1 to 10 steps to make it easier to understand for end users.
Each step should have a SQL query part, a summary explaining the purpose of that query, and a CTE name to link the queries.
Also, you need to give a short description describing the purpose of the original SQL query.
Description and summary in each step MUST BE in the same language as the user's question.
Description and summary in each step MUST BE in the same language as the user specified.

### SQL QUERY BREAKDOWN INSTRUCTIONS ###
- YOU MUST BREAK DOWN any SQL query into small steps if there are JOIN operations or sub-queries.
Expand All @@ -39,7 +39,7 @@
- MUST USE alias from the original SQL query.

### SUMMARY AND DESCRIPTION INSTRUCTIONS ###
- SUMMARY AND DESCRIPTION MUST USE the same language as the user's question
- SUMMARY AND DESCRIPTION MUST BE in the same language as the user specified.
- SUMMARY AND DESCRIPTION MUST BE human-readable and easy to understand.
- SUMMARY AND DESCRIPTION MUST BE concise and to the point.

Expand Down Expand Up @@ -85,16 +85,16 @@
The final answer must be a valid JSON format as following:

{
"description": <SHORT_SQL_QUERY_DESCRIPTION_USING_SAME_LANGUAGE_USER_QUESTION_USING>,
"description": <SHORT_SQL_QUERY_DESCRIPTION_STRING>,
"steps: [
{
"sql": <SQL_QUERY_STRING_1>,
"summary": <SUMMARY_STRING_USING_SAME_LANGUAGE_USER_QUESTION_USING_1>,
"summary": <SUMMARY_STRING_1>,
"cte_name": <CTE_NAME_STRING_1>
},
{
"sql": <SQL_QUERY_STRING_2>,
"summary": <SUMMARY_STRING_USING_SAME_LANGUAGE_USER_QUESTION_USING_2>,
"summary": <SUMMARY_STRING_2>,
"cte_name": <CTE_NAME_STRING_2>
},
...
Expand All @@ -106,6 +106,7 @@
### INPUT ###
User's Question: {{ query }}
SQL query: {{ sql }}
Language: {{ language }}

Let's think step by step.
"""
Expand All @@ -114,10 +115,11 @@
## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(query: str, sql: str, prompt_builder: PromptBuilder) -> dict:
def prompt(query: str, sql: str, language: str, prompt_builder: PromptBuilder) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"sql: {sql}")
return prompt_builder.run(query=query, sql=sql)
logger.debug(f"language: {language}")
return prompt_builder.run(query=query, sql=sql, language=language)


@async_timer
Expand Down Expand Up @@ -154,7 +156,7 @@ class BreakdownResults(BaseModel):
steps: list[StepResult]


BREAKDOWN_MODEL_KWARGS = {
SQL_BREAKDOWN_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -175,7 +177,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_breakdown_system_prompt,
generation_kwargs=BREAKDOWN_MODEL_KWARGS,
generation_kwargs=SQL_BREAKDOWN_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_breakdown_user_prompt_template
Expand All @@ -187,7 +189,9 @@ def __init__(
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
)

def visualize(self, query: str, sql: str, project_id: str | None = None) -> None:
def visualize(
self, query: str, sql: str, language: str, project_id: str | None = None
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Path(destination).mkdir(parents=True, exist_ok=True)
Expand All @@ -199,6 +203,7 @@ def visualize(self, query: str, sql: str, project_id: str | None = None) -> None
"query": query,
"sql": sql,
"project_id": project_id,
"language": language,
**self._components,
},
show_legend=True,
Expand All @@ -207,14 +212,17 @@ def visualize(self, query: str, sql: str, project_id: str | None = None) -> None

@async_timer
@observe(name="SQL Breakdown Generation")
async def run(self, query: str, sql: str, project_id: str | None = None) -> dict:
async def run(
self, query: str, sql: str, language: str, project_id: str | None = None
) -> dict:
logger.info("SQL Breakdown Generation pipeline is running...")
return await self._pipe.execute(
["post_process"],
inputs={
"query": query,
"sql": sql,
"project_id": project_id,
"language": language,
**self._components,
},
)
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class CorrectedResults(BaseModel):
results: list[CorrectedSQLResult]


CORRECTION_MODEL_KWARGS = {
SQL_CORRECTION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -136,7 +136,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=CORRECTION_MODEL_KWARGS,
generation_kwargs=SQL_CORRECTION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_correction_user_prompt_template
Expand Down
10 changes: 7 additions & 3 deletions wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List

Expand Down Expand Up @@ -47,6 +48,7 @@
{% endfor %}

User's input: {{query}}
Current Time: {{ current_time }}
"""


Expand All @@ -62,7 +64,9 @@ def prompt(
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
logger.debug(f"history: {history}")
return prompt_builder.run(query=query, documents=documents, sql=history.sql)
return prompt_builder.run(
query=query, documents=documents, sql=history.sql, current_time=datetime.now()
)


@async_timer
Expand Down Expand Up @@ -98,7 +102,7 @@ class ExpansionResults(BaseModel):
results: list[ExpandedResult]


EXPANSION_MODEL_KWARGS = {
SQL_EXPANSION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -119,7 +123,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_expansion_system_prompt,
generation_kwargs=EXPANSION_MODEL_KWARGS,
generation_kwargs=SQL_EXPANSION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_expansion_user_prompt_template
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ class ExplanationResults(BaseModel):
results: Optional[AggregatedItemsResult]


EXPLANATION_MODEL_KWARGS = {
SQL_EXPLANATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -631,7 +631,7 @@ def __init__(
),
"generator": llm_provider.get_generator(
system_prompt=sql_explanation_system_prompt,
generation_kwargs=EXPLANATION_MODEL_KWARGS,
generation_kwargs=SQL_EXPLANATION_MODEL_KWARGS,
),
"post_processor": SQLExplanationGenerationPostProcessor(),
}
Expand Down
7 changes: 5 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

Expand Down Expand Up @@ -101,6 +102,7 @@

### QUESTION ###
User's Question: {{ query }}
Current Time: {{ current_time }}

Let's think step by step.
"""
Expand Down Expand Up @@ -133,6 +135,7 @@ def prompt(
alert=alert,
instructions=construct_instructions(configurations),
samples=samples,
current_time=datetime.now(),
)


Expand Down Expand Up @@ -165,7 +168,7 @@ class GenerationResults(BaseModel):
results: list[SQLResult]


GENERATION_MODEL_KWARGS = {
SQL_GENERATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -186,7 +189,7 @@ def __init__(
self._components = {
"generator": llm_provider.get_generator(
system_prompt=sql_generation_system_prompt,
generation_kwargs=GENERATION_MODEL_KWARGS,
generation_kwargs=SQL_GENERATION_MODEL_KWARGS,
),
"prompt_builder": PromptBuilder(
template=sql_generation_user_prompt_template
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_regeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class RegenerationResults(BaseModel):
steps: list[StepResult]


REGENERATION_MODEL_KWARGS = {
SQL_REGENERATION_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
Expand All @@ -189,7 +189,7 @@ def __init__(
),
"generator": llm_provider.get_generator(
system_prompt=sql_regeneration_system_prompt,
generation_kwargs=REGENERATION_MODEL_KWARGS,
generation_kwargs=SQL_REGENERATION_MODEL_KWARGS,
),
"post_processor": SQLBreakdownGenPostProcessor(engine=engine),
}
Expand Down
Loading
Loading