Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): minor updates #833

Merged
merged 23 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions wren-ai-service/demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,18 @@
st.session_state["sql_user_corrections_by_step"] = []
if "sql_regeneration_results" not in st.session_state:
st.session_state["sql_regeneration_results"] = None
if "language" not in st.session_state:
st.session_state["language"] = "English"


def onchange_demo_dataset():
    """Callback for the demo-dataset selectbox: mirror the selection into session state."""
    selected = st.session_state["choose_demo_dataset"]
    st.session_state["chosen_dataset"] = selected


def onchange_language():
    """Callback for the language selectbox: mirror the selection into session state."""
    chosen = st.session_state["language_selectbox"]
    st.session_state["language"] = chosen


with st.sidebar:
st.markdown("## Deploy MDL Model")
uploaded_file = st.file_uploader(
Expand All @@ -80,6 +86,25 @@ def onchange_demo_dataset():
on_change=onchange_demo_dataset,
)

st.selectbox(
"Language",
key="language_selectbox",
options=[
"English",
"Spanish",
"French",
"TraditionalChinese",
"SimplifiedChinese",
"German",
"Portuguese",
"Russian",
"Japanese",
"Korean",
],
index=0,
on_change=onchange_language,
)

if uploaded_file is not None:
match = re.match(
r".+_(" + "|".join(DATA_SOURCES) + r")_mdl\.json$",
Expand Down
6 changes: 6 additions & 0 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,9 @@ def ask(query: str, query_history: Optional[dict] = None):
"query": query,
"id": st.session_state["deployment_id"],
"history": query_history,
"configurations": {
"language": st.session_state["language"],
},
},
)

Expand Down Expand Up @@ -703,6 +706,9 @@ def get_sql_answer(
"query": query,
"sql": sql,
"sql_summary": sql_summary,
"configurations": {
"language": st.session_state["language"],
},
},
)

Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def exception_handler(request, exc: Exception):

@app.exception_handler(RequestValidationError)
async def request_exception_handler(request, exc: Exception):
print(str(exc))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could remove printing the exception here; FastAPI already handles the exception and also leaves a log entry in the console.

return ORJSONResponse(
status_code=400,
content={"detail": str(exc)},
Expand Down
28 changes: 24 additions & 4 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import logging
from datetime import datetime
from pprint import pformat
from typing import Any, Dict, List, Optional

import aiohttp
import orjson
import pytz
from haystack import component

from src.core.engine import (
Expand Down Expand Up @@ -111,13 +113,21 @@ def __init__(self, engine: Engine):
)
async def run(
self,
replies: List[str],
replies: List[str] | List[List[str]],
project_id: str | None = None,
) -> dict:
try:
cleaned_generation_result = orjson.loads(
clean_generation_result(replies[0])
)["results"]
if isinstance(replies[0], dict):
cleaned_generation_result = [
orjson.loads(clean_generation_result(reply["replies"][0]))[
"results"
][0]
for reply in replies
]
else:
cleaned_generation_result = orjson.loads(
clean_generation_result(replies[0])
)["results"]
Comment on lines +116 to +130
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering which scenario we encountered that required this change.


if isinstance(cleaned_generation_result, dict):
cleaned_generation_result = [cleaned_generation_result]
Expand Down Expand Up @@ -361,3 +371,13 @@ def construct_instructions(configurations: AskConfigurations | None):
instructions += f"- For calendar year related computation, it should be started from {configurations.fiscal_year.start} to {configurations.fiscal_year.end}"

return instructions


def show_current_time(timezone: AskConfigurations.Timezone):
    """Return the current date and weekday name in the given timezone.

    Args:
        timezone: configuration object whose ``name`` attribute holds an IANA
            timezone string (e.g. ``"Asia/Taipei"``) — assumed from usage here;
            TODO confirm against ``AskConfigurations.Timezone``'s definition.

    Returns:
        A string formatted as ``YYYY-MM-DD weekday_name``, e.g.
        ``2024-10-23 Wednesday``.
    """
    # pytz resolves the IANA name to a tzinfo; an unknown name raises
    # pytz.UnknownTimeZoneError, which is left to propagate to the caller.
    tz = pytz.timezone(timezone.name)
    # strftime already returns a plain str — no f-string wrapper needed.
    return datetime.now(tz).strftime("%Y-%m-%d %A")
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List

Expand All @@ -18,6 +17,7 @@
TEXT_TO_SQL_RULES,
SQLGenPostProcessor,
construct_instructions,
show_current_time,
sql_generation_system_prompt,
)
from src.utils import async_timer, timer
Expand Down Expand Up @@ -136,8 +136,8 @@ def prompt(
documents: List[str],
history: AskHistory,
alert: str,
configurations: AskConfigurations,
prompt_builder: PromptBuilder,
configurations: AskConfigurations | None = None,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"documents: {documents}")
Expand All @@ -149,7 +149,7 @@ def prompt(
history=history,
alert=alert,
instructions=construct_instructions(configurations),
current_time=datetime.now(),
current_time=show_current_time(configurations.timezone),
)


Expand Down Expand Up @@ -228,8 +228,8 @@ def visualize(
query: str,
contexts: List[str],
history: AskHistory,
configurations: AskConfigurations,
project_id: str | None = None,
configurations: AskConfigurations | None = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Expand Down Expand Up @@ -258,8 +258,8 @@ async def run(
query: str,
contexts: List[str],
history: AskHistory,
configurations: AskConfigurations,
project_id: str | None = None,
configurations: AskConfigurations | None = None,
):
logger.info("Follow-Up SQL Generation pipeline is running...")
return await self._pipe.execute(
Expand Down
34 changes: 27 additions & 7 deletions wren-ai-service/src/pipelines/generation/sql_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
3. Read the sql and understand the data.
4. Generate a consice and clear answer in string format and a reasoning process in string format to the user's question based on the data, sql and sql summary.
5. If answer is in list format, only list top few examples, and tell users there are more results omitted.
6. Answer must be in the same language user specified.

### OUTPUT FORMAT

Expand All @@ -49,7 +50,7 @@
SQL: {{ sql }}
SQL summary: {{ sql_summary }}
Data: {{ sql_data }}

Language: {{ language }}
Please think step by step and answer the user's question.
"""

Expand Down Expand Up @@ -127,15 +128,20 @@ def prompt(
sql: str,
sql_summary: str,
execute_sql: dict,
language: str,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"query: {query}")
logger.debug(f"sql: {sql}")
logger.debug(f"sql_summary: {sql_summary}")
logger.debug(f"sql data: {execute_sql}")

logger.debug(f"language: {language}")
return prompt_builder.run(
query=query, sql=sql, sql_summary=sql_summary, sql_data=execute_sql["results"]
query=query,
sql=sql,
sql_summary=sql_summary,
sql_data=execute_sql["results"],
language=language,
)


Expand Down Expand Up @@ -202,7 +208,12 @@ def __init__(
)

def visualize(
self, query: str, sql: str, sql_summary: str, project_id: str | None = None
self,
query: str,
sql: str,
sql_summary: str,
language: str,
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Expand All @@ -215,6 +226,7 @@ def visualize(
"query": query,
"sql": sql,
"sql_summary": sql_summary,
"language": language,
"project_id": project_id,
**self._components,
},
Expand All @@ -225,7 +237,12 @@ def visualize(
@async_timer
@observe(name="SQL Answer Generation")
async def run(
self, query: str, sql: str, sql_summary: str, project_id: str | None = None
self,
query: str,
sql: str,
sql_summary: str,
language: str,
project_id: str | None = None,
) -> dict:
logger.info("Sql_Answer Generation pipeline is running...")
return await self._pipe.execute(
Expand All @@ -234,6 +251,7 @@ async def run(
"query": query,
"sql": sql,
"sql_summary": sql_summary,
"language": language,
"project_id": project_id,
**self._components,
},
Expand All @@ -257,9 +275,11 @@ async def run(
engine=engine,
)

pipeline.visualize("query", "SELECT * FROM table_name", "sql summary")
pipeline.visualize("query", "SELECT * FROM table_name", "sql summary", "English")
async_validate(
lambda: pipeline.run("query", "SELECT * FROM table_name", "sql summary")
lambda: pipeline.run(
"query", "SELECT * FROM table_name", "sql summary", "English"
)
)

langfuse_context.flush()
57 changes: 32 additions & 25 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import sys
from pathlib import Path
Expand Down Expand Up @@ -28,8 +29,8 @@
You are a Trino SQL expert with exceptional logical thinking skills and debugging skills.

### TASK ###
Now you are given a list of syntactically incorrect Trino SQL queries and related error messages.
With given database schema, please think step by step to correct these wrong Trino SQL quries.
Now you are given syntactically incorrect Trino SQL query and related error message.
With given database schema, please think step by step to correct the wrong Trino SQL query.

### DATABASE SCHEMA ###
{% for document in documents %}
Expand All @@ -41,19 +42,15 @@

{
"results": [
{"sql": <CORRECTED_SQL_QUERY_STRING_1>, "summary": <ORIGINAL_SUMMARY_STRING_1>},
{"sql": <CORRECTED_SQL_QUERY_STRING_2>, "summary": <ORIGINAL_SUMMARY_STRING_2>}
{"sql": <CORRECTED_SQL_QUERY_STRING>, "summary": <ORIGINAL_SUMMARY_STRING>},
]
}

{{ alert }}

### QUESTION ###
{% for invalid_generation_result in invalid_generation_results %}
sql: {{ invalid_generation_result.sql }}
summary: {{ invalid_generation_result.summary }}
error: {{ invalid_generation_result.error }}
{% endfor %}
SQL: {{ invalid_generation_result.sql }}
Error Message: {{ invalid_generation_result.error }}

Let's think step by step.
"""
Expand All @@ -62,46 +59,56 @@
## Start of Pipeline
@timer
@observe(capture_input=False)
def prompt(
def prompts(
documents: List[Document],
invalid_generation_results: List[Dict],
alert: str,
prompt_builder: PromptBuilder,
) -> dict:
) -> list[dict]:
logger.debug(
f"documents: {orjson.dumps(documents, option=orjson.OPT_INDENT_2).decode()}"
)
logger.debug(
f"invalid_generation_results: {orjson.dumps(invalid_generation_results, option=orjson.OPT_INDENT_2).decode()}"
)
return prompt_builder.run(
documents=documents,
invalid_generation_results=invalid_generation_results,
alert=alert,
)
return [
prompt_builder.run(
documents=documents,
invalid_generation_result=invalid_generation_result,
alert=alert,
)
for invalid_generation_result in invalid_generation_results
]


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_correction(prompt: dict, generator: Any) -> dict:
logger.debug(f"prompt: {orjson.dumps(prompt, option=orjson.OPT_INDENT_2).decode()}")
return await generator.run(prompt=prompt.get("prompt"))
async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[dict]:
logger.debug(
f"prompts: {orjson.dumps(prompts, option=orjson.OPT_INDENT_2).decode()}"
)

tasks = []
for prompt in prompts:
task = asyncio.ensure_future(generator.run(prompt=prompt.get("prompt")))
tasks.append(task)

return await asyncio.gather(*tasks)


@async_timer
@observe(capture_input=False)
async def post_process(
generate_sql_correction: dict,
generate_sql_corrections: list[dict],
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
) -> list[dict]:
logger.debug(
f"generate_sql_correction: {orjson.dumps(generate_sql_correction, option=orjson.OPT_INDENT_2).decode()}"
)
return await post_processor.run(
generate_sql_correction.get("replies"), project_id=project_id
f"generate_sql_corrections: {orjson.dumps(generate_sql_corrections, option=orjson.OPT_INDENT_2).decode()}"
)

return await post_processor.run(generate_sql_corrections, project_id=project_id)


## End of Pipeline

Expand Down
Loading
Loading