chore(wren-ai-service): minor updates (#833)
cyyeh authored Oct 31, 2024
1 parent 3a0148a commit feff81b
Showing 27 changed files with 261 additions and 136 deletions.
25 changes: 25 additions & 0 deletions wren-ai-service/demo/app.py
@@ -56,12 +56,18 @@
st.session_state["sql_user_corrections_by_step"] = []
if "sql_regeneration_results" not in st.session_state:
st.session_state["sql_regeneration_results"] = None
if "language" not in st.session_state:
st.session_state["language"] = "English"


def onchange_demo_dataset():
st.session_state["chosen_dataset"] = st.session_state["choose_demo_dataset"]


def onchange_language():
st.session_state["language"] = st.session_state["language_selectbox"]


with st.sidebar:
st.markdown("## Deploy MDL Model")
uploaded_file = st.file_uploader(
@@ -80,6 +86,25 @@ def onchange_demo_dataset():
on_change=onchange_demo_dataset,
)

st.selectbox(
"Language",
key="language_selectbox",
options=[
"English",
"Spanish",
"French",
"TraditionalChinese",
"SimplifiedChinese",
"German",
"Portuguese",
"Russian",
"Japanese",
"Korean",
],
index=0,
on_change=onchange_language,
)

if uploaded_file is not None:
match = re.match(
r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$",
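For reference, the sketch below isolates the session-state pattern the demo uses here: the selectbox writes to its own widget key, and the on_change callback copies that value into the app-level "language" key the rest of the demo reads. It is a minimal standalone Streamlit script, with the option list shortened for brevity.

```python
import streamlit as st

# Mirror the widget's value into a stable session-state key, as in the diff above.
if "language" not in st.session_state:
    st.session_state["language"] = "English"


def onchange_language():
    # Copy the widget value into the app-level key other code reads from.
    st.session_state["language"] = st.session_state["language_selectbox"]


st.selectbox(
    "Language",
    key="language_selectbox",
    options=["English", "Spanish", "French"],  # shortened for this sketch
    index=0,
    on_change=onchange_language,
)

st.write(f"Selected language: {st.session_state['language']}")
```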
6 changes: 6 additions & 0 deletions wren-ai-service/demo/utils.py
@@ -663,6 +663,9 @@ def ask(query: str, query_history: Optional[dict] = None):
"query": query,
"id": st.session_state["deployment_id"],
"history": query_history,
"configurations": {
"language": st.session_state["language"],
},
},
)

@@ -703,6 +706,9 @@ def get_sql_answer(
"query": query,
"sql": sql,
"sql_summary": sql_summary,
"configurations": {
"language": st.session_state["language"],
},
},
)

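The utils.py changes forward the selected language to the service inside a configurations object in the request body. Below is a rough, self-contained sketch of such a request; the base URL and endpoint path are assumptions for illustration only, and just the configurations.language field mirrors the diff.

```python
import requests

# Assumed base URL for a local wren-ai-service deployment; adjust as needed.
WREN_AI_SERVICE_URL = "http://localhost:5556"


def ask(query: str, deployment_id: str, language: str = "English") -> dict:
    response = requests.post(
        f"{WREN_AI_SERVICE_URL}/v1/asks",  # hypothetical path for this sketch
        json={
            "query": query,
            "id": deployment_id,
            "configurations": {
                # Language selected in the sidebar, forwarded so answers can be localized.
                "language": language,
            },
        },
    )
    response.raise_for_status()
    return response.json()
```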
28 changes: 24 additions & 4 deletions wren-ai-service/src/pipelines/common.py
@@ -1,10 +1,12 @@
import asyncio
import logging
from datetime import datetime
from pprint import pformat
from typing import Any, Dict, List, Optional

import aiohttp
import orjson
import pytz
from haystack import component

from src.core.engine import (
@@ -111,13 +113,21 @@ def __init__(self, engine: Engine):
)
async def run(
self,
replies: List[str],
replies: List[str] | List[List[str]],
project_id: str | None = None,
) -> dict:
try:
cleaned_generation_result = orjson.loads(
clean_generation_result(replies[0])
)["results"]
if isinstance(replies[0], dict):
cleaned_generation_result = [
orjson.loads(clean_generation_result(reply["replies"][0]))[
"results"
][0]
for reply in replies
]
else:
cleaned_generation_result = orjson.loads(
clean_generation_result(replies[0])
)["results"]

if isinstance(cleaned_generation_result, dict):
cleaned_generation_result = [cleaned_generation_result]
@@ -361,3 +371,13 @@ def construct_instructions(configurations: AskConfigurations | None):
instructions += f"- For calendar year related computation, it should be started from {configurations.fiscal_year.start} to {configurations.fiscal_year.end}"

return instructions


def show_current_time(timezone: AskConfigurations.Timezone):
# Get the current time in the specified timezone
tz = pytz.timezone(
timezone.name
) # Assuming timezone.name contains the timezone string
current_time = datetime.now(tz)

return f'{current_time.strftime("%Y-%m-%d %A")}' # YYYY-MM-DD weekday_name, ex: 2024-10-23 Wednesday
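The new show_current_time helper resolves the configured timezone with pytz and returns today's date plus the weekday name. Below is a standalone version of the same idea, using a stand-in dataclass for AskConfigurations.Timezone (the real model lives on AskConfigurations).

```python
from dataclasses import dataclass
from datetime import datetime

import pytz


@dataclass
class Timezone:
    # Stand-in for AskConfigurations.Timezone; `name` holds an IANA timezone string.
    name: str = "UTC"


def show_current_time(timezone: Timezone) -> str:
    tz = pytz.timezone(timezone.name)
    current_time = datetime.now(tz)
    return current_time.strftime("%Y-%m-%d %A")  # e.g. "2024-10-23 Wednesday"


print(show_current_time(Timezone(name="Asia/Taipei")))
```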
@@ -1,6 +1,5 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List

@@ -18,6 +17,7 @@
TEXT_TO_SQL_RULES,
SQLGenPostProcessor,
construct_instructions,
show_current_time,
sql_generation_system_prompt,
)
from src.utils import async_timer, timer
@@ -136,8 +136,8 @@ def prompt(
documents: List[str],
history: AskHistory,
alert: str,
configurations: AskConfigurations,
prompt_builder: PromptBuilder,
configurations: AskConfigurations | None = None,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
@@ -149,7 +149,7 @@
history=history,
alert=alert,
instructions=construct_instructions(configurations),
current_time=datetime.now(),
current_time=show_current_time(configurations.timezone),
)


@@ -228,8 +228,8 @@ def visualize(
query: str,
contexts: List[str],
history: AskHistory,
configurations: AskConfigurations,
project_id: str | None = None,
configurations: AskConfigurations | None = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
@@ -258,8 +258,8 @@ async def run(
query: str,
contexts: List[str],
history: AskHistory,
configurations: AskConfigurations,
project_id: str | None = None,
configurations: AskConfigurations | None = None,
):
logger.info("Follow-Up SQL Generation pipeline is running...")
return await self._pipe.execute(
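In this pipeline, the configurations parameter moves behind project_id and becomes optional, defaulting to None, and the prompt's current_time now comes from show_current_time(configurations.timezone). The sketch below only illustrates that calling convention with simplified stand-in models; the real AskConfigurations fields, defaults, and handling of a missing configurations object may differ.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Timezone:
    name: str = "UTC"


@dataclass
class Configurations:
    # Simplified stand-in for AskConfigurations.
    language: str = "English"
    timezone: Timezone = field(default_factory=Timezone)


def build_prompt_inputs(
    query: str,
    project_id: Optional[str] = None,
    configurations: Optional[Configurations] = None,
) -> dict:
    # Fall back to defaults when no configurations object is supplied (sketch only).
    configurations = configurations or Configurations()
    return {
        "query": query,
        "project_id": project_id,
        "language": configurations.language,
        "current_timezone": configurations.timezone.name,
    }


print(build_prompt_inputs("total sales by region"))
print(
    build_prompt_inputs(
        "total sales by region",
        configurations=Configurations(language="Japanese", timezone=Timezone("Asia/Tokyo")),
    )
)
```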
34 changes: 27 additions & 7 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
@@ -32,6 +32,7 @@
3. Read the sql and understand the data.
4. Generate a consice and clear answer in string format and a reasoning process in string format to the user's question based on the data, sql and sql summary.
5. If answer is in list format, only list top few examples, and tell users there are more results omitted.
6. Answer must be in the same language user specified.
### OUTPUT FORMAT
@@ -49,7 +50,7 @@
SQL: {{ sql }}
SQL summary: {{ sql_summary }}
Data: {{ sql_data }}
Language: {{ language }}
Please think step by step and answer the user's question.
"""

@@ -127,15 +128,20 @@ def prompt(
sql: str,
sql_summary: str,
execute_sql: dict,
language: str,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"sql: {sql}")
logger.debug(f"sql_summary: {sql_summary}")
logger.debug(f"sql data: {execute_sql}")

logger.debug(f"language: {language}")
return prompt_builder.run(
query=query, sql=sql, sql_summary=sql_summary, sql_data=execute_sql["results"]
query=query,
sql=sql,
sql_summary=sql_summary,
sql_data=execute_sql["results"],
language=language,
)


@@ -202,7 +208,12 @@ def __init__(
)

def visualize(
self, query: str, sql: str, sql_summary: str, project_id: str | None = None
self,
query: str,
sql: str,
sql_summary: str,
language: str,
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
@@ -215,6 +226,7 @@ def visualize(
"query": query,
"sql": sql,
"sql_summary": sql_summary,
"language": language,
"project_id": project_id,
**self._components,
},
@@ -225,7 +237,12 @@
@async_timer
@observe(name="SQL Answer Generation")
async def run(
self, query: str, sql: str, sql_summary: str, project_id: str | None = None
self,
query: str,
sql: str,
sql_summary: str,
language: str,
project_id: str | None = None,
) -> dict:
logger.info("Sql_Answer Generation pipeline is running...")
return await self._pipe.execute(
@@ -234,6 +251,7 @@ async def run(
"query": query,
"sql": sql,
"sql_summary": sql_summary,
"language": language,
"project_id": project_id,
**self._components,
},
@@ -257,9 +275,11 @@ async def run(
engine=engine,
)

pipeline.visualize("query", "SELECT * FROM table_name", "sql summary")
pipeline.visualize("query", "SELECT * FROM table_name", "sql summary", "English")
async_validate(
lambda: pipeline.run("query", "SELECT * FROM table_name", "sql summary")
lambda: pipeline.run(
"query", "SELECT * FROM table_name", "sql summary", "English"
)
)

langfuse_context.flush()
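sql_answer.py now threads a language value through run/visualize and into the prompt, which gains a `Language: {{ language }}` line. The snippet below renders a condensed version of that template shape with Jinja2; the wording is illustrative, not the exact production prompt.

```python
from jinja2 import Template

# Condensed user-prompt shape; only the "Language" line is new relative to before.
user_prompt = Template(
    "### INPUT ###\n"
    "User's question: {{ query }}\n"
    "SQL: {{ sql }}\n"
    "SQL summary: {{ sql_summary }}\n"
    "Data: {{ sql_data }}\n"
    "Language: {{ language }}\n"
)

print(
    user_prompt.render(
        query="How many orders were placed last month?",
        sql="SELECT COUNT(*) FROM orders",
        sql_summary="Count of orders placed last month",
        sql_data=[{"count": 42}],
        language="English",
    )
)
```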
57 changes: 32 additions & 25 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
@@ -1,3 +1,4 @@
import asyncio
import logging
import sys
from pathlib import Path
@@ -28,8 +29,8 @@
You are a Trino SQL expert with exceptional logical thinking skills and debugging skills.
### TASK ###
Now you are given a list of syntactically incorrect Trino SQL queries and related error messages.
With given database schema, please think step by step to correct these wrong Trino SQL quries.
Now you are given syntactically incorrect Trino SQL query and related error message.
With given database schema, please think step by step to correct the wrong Trino SQL query.
### DATABASE SCHEMA ###
{% for document in documents %}
@@ -41,19 +42,15 @@
{
"results": [
{"sql": <CORRECTED_SQL_QUERY_STRING_1>, "summary": <ORIGINAL_SUMMARY_STRING_1>},
{"sql": <CORRECTED_SQL_QUERY_STRING_2>, "summary": <ORIGINAL_SUMMARY_STRING_2>}
{"sql": <CORRECTED_SQL_QUERY_STRING>, "summary": <ORIGINAL_SUMMARY_STRING>},
]
}
{{ alert }}
### QUESTION ###
{% for invalid_generation_result in invalid_generation_results %}
sql: {{ invalid_generation_result.sql }}
summary: {{ invalid_generation_result.summary }}
error: {{ invalid_generation_result.error }}
{% endfor %}
SQL: {{ invalid_generation_result.sql }}
Error Message: {{ invalid_generation_result.error }}
Let's think step by step.
"""
@@ -62,46 +59,56 @@
## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(
def prompts(
documents: List[Document],
invalid_generation_results: List[Dict],
alert: str,
prompt_builder: PromptBuilder,
) -> dict:
) -> list[dict]:
logger.debug(
f"documents: {orjson.dumps(documents, option=orjson.OPT_INDENT_2).decode()}"
)
logger.debug(
f"invalid_generation_results: {orjson.dumps(invalid_generation_results, option=orjson.OPT_INDENT_2).decode()}"
)
return prompt_builder.run(
documents=documents,
invalid_generation_results=invalid_generation_results,
alert=alert,
)
return [
prompt_builder.run(
documents=documents,
invalid_generation_result=invalid_generation_result,
alert=alert,
)
for invalid_generation_result in invalid_generation_results
]


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_correction(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))
async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[dict]:
logger.debug(
f"prompts: {orjson.dumps(prompts, option=orjson.OPT_INDENT_2).decode()}"
)

tasks = []
for prompt in prompts:
task = asyncio.ensure_future(generator.run(prompt=prompt.get("prompt")))
tasks.append(task)

return await asyncio.gather(*tasks)


@async_timer
@observe(capture_input=False)
async def post_process(
generate_sql_correction: dict,
generate_sql_corrections: list[dict],
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
) -> list[dict]:
logger.debug(
f"generate_sql_correction: {orjson.dumps(generate_sql_correction, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_correction.get("replies"), project_id=project_id
f"generate_sql_corrections: {orjson.dumps(generate_sql_corrections, option=orjson.OPT_INDENT_2).decode()}"
)

return await post_processor.run(generate_sql_corrections, project_id=project_id)


## End of Pipeline

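sql_correction.py switches from one batched prompt to one prompt per invalid query, fanning the LLM calls out concurrently with asyncio.gather. Below is a self-contained sketch of that fan-out pattern, with a fake coroutine standing in for the real generator client.

```python
import asyncio
from typing import Any


async def fake_generate(prompt: str) -> dict[str, Any]:
    # Stand-in for the real generator.run(prompt=...) LLM call.
    await asyncio.sleep(0.1)
    return {"replies": [f"corrected SQL for: {prompt}"]}


async def generate_sql_corrections(prompts: list[str]) -> list[dict[str, Any]]:
    # One task per invalid query, awaited together, as in the updated pipeline.
    tasks = [asyncio.ensure_future(fake_generate(p)) for p in prompts]
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    results = asyncio.run(
        generate_sql_corrections(["SELECT * FROM a WHERE", "SELECT b FROM"])
    )
    print(results)
```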