Skip to content

Commit

Permalink
chore(wren-ai-service): minor-updates (#856)
Browse files Browse the repository at this point in the history
Co-authored-by: Aster Sun <imastr114@gmail.com>
  • Loading branch information
cyyeh and imAsterSun authored Nov 4, 2024
1 parent 89e4b22 commit b445eca
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 18 deletions.
10 changes: 6 additions & 4 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ async def _check_if_sql_executable(
)

if not status:
logger.exception(f"SQL is not executable: {addition["error_message"]}")
logger.exception(
f"SQL is not executable: {addition.get('error_message', '')}"
)

return status

Expand Down Expand Up @@ -173,16 +175,16 @@ async def _task(result: Dict[str, str]):
valid_generation_results.append(
{
"sql": quoted_sql,
"correlation_id": addition["correlation_id"],
"correlation_id": addition.get("correlation_id", ""),
}
)
else:
invalid_generation_results.append(
{
"sql": quoted_sql,
"type": "DRY_RUN",
"error": addition["error_message"],
"correlation_id": addition["correlation_id"],
"error": addition.get("error_message", ""),
"correlation_id": addition.get("correlation_id", ""),
}
)
else:
Expand Down
64 changes: 50 additions & 14 deletions wren-ai-service/src/providers/engine/wren.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,24 @@ async def execute_sql(
) as response:
res = await response.json()
if data := res.get("data"):
return True, data, {
"correlation_id": res.get("correlationId"),
}
return (
True,
data,
{
"correlation_id": res.get("correlationId"),
},
)
return (
False,
None,
{
"error_message": res.get("errors", [{}])[0].get("message", "Unknown error"),
"correlation_id": res.get("extensions", {}).get("other", {}).get("correlationId"),
}
"error_message": res.get("errors", [{}])[0].get(
"message", "Unknown error"
),
"correlation_id": res.get("extensions", {})
.get("other", {})
.get("correlationId"),
},
)
except asyncio.TimeoutError:
return False, None, f"Request timed out: {timeout} seconds"
Expand Down Expand Up @@ -115,12 +123,23 @@ async def execute_sql(
else:
res = await response.json()

if response.status == 204:
return True, None, None
if response.status == 200:
return True, res, None
if response.status == 200 or response.status == 204:
return (
True,
res,
{
"correlation_id": "",
},
)

return False, None, res
return (
False,
None,
{
"error_message": res,
"correlation_id": "",
},
)
except asyncio.TimeoutError:
return False, None, f"Request timed out: {timeout} seconds"

Expand Down Expand Up @@ -166,10 +185,27 @@ async def execute_sql(
},
timeout=aiohttp.ClientTimeout(total=timeout),
) as response:
res = await response.json()
if dry_run:
res = await response.text()
else:
res = await response.json()

if response.status == 200:
return True, res, None
return (
True,
res,
{
"correlation_id": "",
},
)

return False, None, res
return (
False,
None,
{
"error_message": res,
"correlation_id": "",
},
)
except asyncio.TimeoutError:
return False, None, f"Request timed out: {timeout} seconds"

0 comments on commit b445eca

Please sign in to comment.