Skip to content

Commit

Permalink
update timezone and add engine timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Oct 28, 2024
1 parent f7d0c74 commit a5c5ded
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 51 deletions.
116 changes: 67 additions & 49 deletions wren-ai-service/src/providers/engine/wren.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import base64
import logging
import os
Expand Down Expand Up @@ -28,6 +29,7 @@ async def execute_sql(
session: aiohttp.ClientSession,
project_id: str | None = None,
dry_run: bool = True,
timeout: float = 10.0,
**kwargs,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
data = {
Expand All @@ -40,21 +42,25 @@ async def execute_sql(
else:
data["limit"] = 500

async with session.post(
f"{self._endpoint}/api/graphql",
json={
"query": "mutation PreviewSql($data: PreviewSQLDataInput) { previewSql(data: $data) }",
"variables": {"data": data},
},
) as response:
res = await response.json()
if data := res.get("data"):
return True, data, None
return (
False,
None,
res.get("errors", [{}])[0].get("message", "Unknown error"),
)
try:
async with session.post(
f"{self._endpoint}/api/graphql",
json={
"query": "mutation PreviewSql($data: PreviewSQLDataInput) { previewSql(data: $data) }",
"variables": {"data": data},
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
res = await response.json()
if data := res.get("data"):
return True, data, None
return (
False,
None,
res.get("errors", [{}])[0].get("message", "Unknown error"),
)
except asyncio.TimeoutError:
return False, None, f"Request timed out: {timeout} seconds"


@provider("wren_ibis")
Expand All @@ -80,6 +86,7 @@ async def execute_sql(
sql: str,
session: aiohttp.ClientSession,
dry_run: bool = True,
timeout: float = 10.0,
**kwargs,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
api_endpoint = f"{self._endpoint}/v2/connector/{self._source}/query"
Expand All @@ -88,25 +95,29 @@ async def execute_sql(
else:
api_endpoint += "?limit=500"

async with session.post(
api_endpoint,
json={
"sql": remove_limit_statement(sql),
"manifestStr": self._manifest,
"connectionInfo": self._connection_info,
},
) as response:
if dry_run:
res = await response.text()
else:
res = await response.json()

if response.status == 204:
return True, None, None
if response.status == 200:
return True, res, None

return False, None, res
try:
async with session.post(
api_endpoint,
json={
"sql": remove_limit_statement(sql),
"manifestStr": self._manifest,
"connectionInfo": self._connection_info,
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
if dry_run:
res = await response.text()
else:
res = await response.json()

if response.status == 204:
return True, None, None
if response.status == 200:
return True, res, None

return False, None, res
except asyncio.TimeoutError:
return False, None, f"Request timed out: {timeout} seconds"


@provider("wren_engine")
Expand All @@ -127,6 +138,7 @@ async def execute_sql(
"manifest": os.getenv("WREN_ENGINE_MANIFEST"),
},
dry_run: bool = True,
timeout: float = 10.0,
**kwargs,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
api_endpoint = (
Expand All @@ -135,18 +147,24 @@ async def execute_sql(
else f"{self._endpoint}/v1/mdl/preview"
)

async with session.get(
api_endpoint,
json={
"manifest": orjson.loads(base64.b64decode(properties.get("manifest")))
if properties.get("manifest")
else {},
"sql": remove_limit_statement(sql),
"limit": 1 if dry_run else 500,
},
) as response:
res = await response.json()
if response.status == 200:
return True, res, None

return False, None, res
try:
async with session.get(
api_endpoint,
json={
"manifest": orjson.loads(
base64.b64decode(properties.get("manifest"))
)
if properties.get("manifest")
else {},
"sql": remove_limit_statement(sql),
"limit": 1 if dry_run else 500,
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
res = await response.json()
if response.status == 200:
return True, res, None

return False, None, res
except asyncio.TimeoutError:
return False, None, f"Request timed out: {timeout} seconds"
2 changes: 1 addition & 1 deletion wren-ai-service/src/web/v1/services/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class FiscalYear(BaseModel):

fiscal_year: Optional[FiscalYear] = None
language: str = "English"
timezone: str = "" # time-zone:utf-offset-in-hours
timezone: str = "Asia/Taipei"


# POST /v1/asks
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/web/v1/services/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# POST /v1/sql-expansions
class SqlExpansionConfigurations(BaseModel):
language: str = "English"
timezone: str = "" # time-zone:utf-offset-in-hours
timezone: str = "Asia/Taipei"


class SqlExpansionRequest(BaseModel):
Expand Down

0 comments on commit a5c5ded

Please sign in to comment.