Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): add function list to prompt rule #842

Merged
merged 5 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion wren-ai-service/Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ demo:
poetry run streamlit run demo/app.py

test test_args='': up && down
poetry run pytest -s {{test_args}}
poetry run pytest -s {{test_args}} --ignore tests/pytest/test_usecases.py

test-usecases:
poetry run python -m tests.pytest.test_usecases

load-test:
poetry run python -m tests.locust.locust_script
Expand Down
6 changes: 3 additions & 3 deletions wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def add_quotes(sql: str) -> Tuple[str, bool]:
return sql, False


def get_connection_info(data_source: str):
def _get_connection_info(data_source: str):
if data_source == "bigquery":
return {
"project_id": os.getenv("bigquery.project-id"),
Expand Down Expand Up @@ -77,7 +77,7 @@ def rerun_wren_engine(mdl_json: Dict, dataset_type: str, dataset: str):
_replace_wren_engine_env_variables("wren_engine", {"manifest": MANIFEST})
else:
WREN_IBIS_CONNECTION_INFO = base64.b64encode(
orjson.dumps(get_connection_info(dataset_type))
orjson.dumps(_get_connection_info(dataset_type))
).decode()

_replace_wren_engine_env_variables(
Expand Down Expand Up @@ -145,7 +145,7 @@ def get_data_from_wren_engine(
json={
"sql": quoted_sql,
"manifestStr": base64.b64encode(orjson.dumps(manifest)).decode(),
"connectionInfo": get_connection_info(dataset_type),
"connectionInfo": _get_connection_info(dataset_type),
"limit": 100,
},
)
Expand Down
74 changes: 68 additions & 6 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,16 @@ async def run(
) -> dict:
try:
if isinstance(replies[0], dict):
cleaned_generation_result = [
orjson.loads(clean_generation_result(reply["replies"][0]))[
"results"
][0]
for reply in replies
]
cleaned_generation_result = []
for reply in replies:
try:
cleaned_generation_result.append(
orjson.loads(clean_generation_result(reply["replies"][0]))[
"results"
][0]
)
except Exception as e:
logger.exception(f"Error in SQLGenPostProcessor: {e}")
else:
cleaned_generation_result = orjson.loads(
clean_generation_result(replies[0])
Expand Down Expand Up @@ -203,6 +207,64 @@ async def _task(result: Dict[str, str]):
- ONLY USE the tables and columns mentioned in the database schema.
- ONLY USE "*" if the user query asks for all the columns of a table.
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
- ONLY USE the following SQL functions when generating answers:
- Aggregation functions:
- AVG
- COUNT
- MAX
- MIN
- SUM
- ARRAY_AGG
- BOOL_OR
- Math functions:
- ABS
- CBRT
- CEIL
- EXP
- FLOOR
- LN
- ROUND
- SIGN
- GREATEST
- LEAST
- MOD
- POWER
- String functions:
- LENGTH
- REVERSE
- CHR
- CONCAT
- FORMAT
- LOWER
- LPAD
- LTRIM
- POSITION
- REPLACE
- RPAD
- RTRIM
- STRPOS
- SUBSTR
- SUBSTRING
- TRANSLATE
- TRIM
- UPPER
- Date and Time functions:
- CURRENT_DATE
- DATE_TRUNC
- EXTRACT
- operators:
- `+`
- `-`
- `*`
- `/`
- `||`
- `<`
- `>`
- `>=`
- `<=`
- `=`
- `<>`
- `!=`
- YOU MUST USE "JOIN" if you choose columns from multiple tables!
- YOU MUST USE "lower(<column_name>) = lower(<value>)" function for case-insensitive comparison!
- DON'T USE "DATE_ADD" or "DATE_SUB" functions for date operations, instead use syntax like this "current_date - INTERVAL '7' DAY"!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def visualize(
query: str,
contexts: List[str],
history: AskHistory,
configurations: AskConfigurations,
configurations: AskConfigurations = AskConfigurations(),
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/generation"
Expand Down Expand Up @@ -258,7 +258,7 @@ async def run(
query: str,
contexts: List[str],
history: AskHistory,
configurations: AskConfigurations,
configurations: AskConfigurations = AskConfigurations(),
project_id: str | None = None,
):
logger.info("Follow-Up SQL Generation pipeline is running...")
Expand Down
12 changes: 10 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,11 @@ def __init__(
)

def visualize(
self, query: str, sql: str, language: str, project_id: str | None = None
self,
query: str,
sql: str,
language: str = "English",
project_id: str | None = None,
) -> None:
destination = "outputs/pipelines/generation"
if not Path(destination).exists():
Expand All @@ -213,7 +217,11 @@ def visualize(
@async_timer
@observe(name="SQL Breakdown Generation")
async def run(
self, query: str, sql: str, language: str, project_id: str | None = None
self,
query: str,
sql: str,
language: str = "English",
project_id: str | None = None,
) -> dict:
logger.info("SQL Breakdown Generation pipeline is running...")
return await self._pipe.execute(
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async def run(
query: str,
contexts: List[str],
history: AskHistory,
timezone: AskConfigurations.Timezone,
timezone: AskConfigurations.Timezone = AskConfigurations().timezone,
project_id: str | None = None,
):
logger.info("Sql Expansion Generation pipeline is running...")
Expand Down
4 changes: 2 additions & 2 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def visualize(
query: str,
contexts: List[str],
exclude: List[Dict],
configurations: AskConfigurations,
configurations: AskConfigurations = AskConfigurations(),
samples: List[Dict] | None = None,
project_id: str | None = None,
) -> None:
Expand Down Expand Up @@ -242,7 +242,7 @@ async def run(
query: str,
contexts: List[str],
exclude: List[Dict],
configurations: AskConfigurations,
configurations: AskConfigurations = AskConfigurations(),
samples: List[Dict] | None = None,
project_id: str | None = None,
):
Expand Down
135 changes: 135 additions & 0 deletions wren-ai-service/tests/pytest/test_usecases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import asyncio
import base64
import json
import time
import uuid

import aiohttp
import orjson
import requests

from demo.utils import _get_connection_info, _replace_wren_engine_env_variables


def is_ai_service_ready(url: str):
try:
response = requests.get(f"{url}/health")
return response.status_code == 200
except requests.exceptions.ConnectionError:
return False


def test_load_mdl_and_questions():
try:
with open("tests/data/hubspot/mdl.json", "r") as f:
mdl_str = orjson.dumps(json.load(f)).decode("utf-8")

with open("tests/data/hubspot/questions.json", "r") as f:
questions = json.load(f)["questions"]
except FileNotFoundError:
raise Exception(
"tests/data/hubspot/mdl.json or tests/data/hubspot/questions.json not found"
)

return mdl_str, questions


def setup_datasource(mdl_str: str):
dataset_type = "bigquery"
connection_info = _get_connection_info(dataset_type)
_replace_wren_engine_env_variables(
"wren_ibis",
{
"manifest": base64.b64encode(mdl_str.encode("utf-8")).decode("utf-8"),
"source": dataset_type,
"connection_info": base64.b64encode(orjson.dumps(connection_info)).decode(),
},
)
ready = False
while not ready:
ready = is_ai_service_ready(url)
time.sleep(1)


def deploy_mdl(mdl_str: str, url: str):
semantics_preperation_id = str(uuid.uuid4())
response = requests.post(
f"{url}/v1/semantics-preparations",
json={"mdl": mdl_str, "id": semantics_preperation_id},
)
assert response.status_code == 200

status = "indexing"
while status == "indexing":
response = requests.get(
f"{url}/v1/semantics-preparations/{semantics_preperation_id}/status"
)

assert response.status_code == 200
status = response.json()["status"]

assert status == "finished"

return semantics_preperation_id


async def ask_question(question: str, url: str, semantics_preperation_id: str):
print(f"preparing to ask question: {question}")
async with aiohttp.ClientSession() as session:
response = await session.post(
f"{url}/v1/asks", json={"query": question, "id": semantics_preperation_id}
)
assert response.status == 200

query_id = (await response.json())["query_id"]

response = await session.get(f"{url}/v1/asks/{query_id}/result")
while (await response.json())["status"] != "finished" and (
await response.json()
)["status"] != "failed":
response = await session.get(f"{url}/v1/asks/{query_id}/result")

assert response.status == 200

print(f"got the result of question: {question}")
return await response.json()


async def ask_questions(questions: list[str], url: str, semantics_preperation_id: str):
tasks = []
for question in questions:
task = asyncio.ensure_future(
ask_question(question, url, semantics_preperation_id)
)
tasks.append(task)
await asyncio.sleep(10)

return await asyncio.gather(*tasks)


if __name__ == "__main__":
url = "http://localhost:5556"

assert is_ai_service_ready(
url
), "WrenAI AI service is not running, please start it first via 'just up && just start'"

mdl_str, questions = test_load_mdl_and_questions()

setup_datasource(mdl_str)

semantics_preperation_id = deploy_mdl(mdl_str, url)

# ask questions
results = asyncio.run(ask_questions(questions, url, semantics_preperation_id))
assert len(results) == len(questions)

# count the number of results that are failed
for question, result in zip(questions, results):
print(f"question: {question}")
print(json.dumps(result, indent=2))

failed_count = sum(1 for result in results if result["status"] == "failed")
assert (
failed_count == 0
), f"got {failed_count} failed results in {len(results)} questions"
Loading