Skip to content

Commit

Permalink
validate prediction response to raise errors, but return the unvalida…
Browse files Browse the repository at this point in the history
…ted output to avoid converting urls to File/Path

Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Feb 29, 2024
1 parent 51e8a45 commit 644d1cd
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
if resp1.status_code == 307 and resp1.headers["Location"]:
log.info("got file upload redirect from api")
url = resp1.headers["Location"]
log.info("doing real upload to", url)
log.info("doing real upload to %s", url)
resp = await self.file_client.put(
url,
content=chunk_file_reader(),
Expand Down
9 changes: 8 additions & 1 deletion python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,13 @@ async def shared_predict(
if respond_async:
return JSONResponse(jsonable_encoder(initial_response), status_code=202)

# by now, output Path and File are already converted to str
# so when we validate the schema, those urls get cast back to Path and File
# in the previous implementation those would then get encoded as strings
# however the changes to Path and File break this and return the filename instead
try:
prediction = await async_result
# we're only doing this to catch validation errors
response = PredictionResponse(**prediction.dict())
except ValidationError as e:
_log_invalid_output(e)
Expand All @@ -329,7 +334,9 @@ async def shared_predict(
# )
# dict_resp["output"] = output
# encoded_response = jsonable_encoder(dict_resp)
encoded_response = jsonable_encoder(response.dict())

# return *prediction* and not *response* to preserve urls
encoded_response = jsonable_encoder(prediction.dict())
return JSONResponse(content=encoded_response)

@app.post("/predictions/{prediction_id}/cancel")
Expand Down
2 changes: 2 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def test_default_int_input(client, match):
assert resp.json() == match({"output": 9, "status": "succeeded"})


# the data uri BytesIO gets consumed by jsonable_encoder
# doesn't really matter that much for our purposes
@uses_predictor("input_file")
def test_file_input_data_url(client, match):
resp = client.post(
Expand Down

0 comments on commit 644d1cd

Please sign in to comment.