Include prediction ID in upload request
Based on #1667

This PR introduces two small changes to the file upload interface.

1. We now allow downstream services to include the destination of the
asset in a `Location` header, rather than assuming that it is the same as
the final upload URL (either the one passed via `--upload-url` or the
result of a 307 redirect response).

2. We now include the `X-Prediction-ID` header in the upload request, which
allows the downstream service to do configuration or routing based on the
prediction ID. This ID should be considered untrusted and must be
validated by the downstream service. A minimal sketch of such a service
follows below.
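
For illustration only, here is a minimal sketch of a downstream upload service that exercises both changes: it treats the `X-Prediction-ID` header as untrusted input and returns the asset's final URL in a `Location` header. The handler, upload directory, CDN base URL, and validation rule are made-up assumptions, not part of this PR, and a production service would also need to handle the chunked request bodies that cog's uploader streams (plain `http.server` does not).

```python
# Hypothetical downstream upload endpoint; names and URLs are illustrative.
import http.server
import os

UPLOAD_DIR = "/tmp/uploads"           # assumption: local scratch directory
CDN_BASE = "https://cdn.example.com"  # assumption: public base URL for assets


class UploadHandler(http.server.BaseHTTPRequestHandler):
    def do_PUT(self) -> None:
        # X-Prediction-ID is untrusted input: validate before using it.
        prediction_id = self.headers.get("X-Prediction-ID", "")
        if prediction_id and not prediction_id.isalnum():
            self.send_error(400, "invalid prediction id")
            return

        filename = os.path.basename(self.path) or "file"
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length)

        os.makedirs(UPLOAD_DIR, exist_ok=True)
        with open(os.path.join(UPLOAD_DIR, filename), "wb") as f:
            f.write(body)

        # Tell cog where the asset actually lives; clients.py now prefers
        # this header over the URL the file was uploaded to.
        prefix = f"/{prediction_id}" if prediction_id else ""
        self.send_response(201)
        self.send_header("Location", f"{CDN_BASE}{prefix}/{filename}")
        self.end_headers()


if __name__ == "__main__":
    http.server.HTTPServer(("", 8080), UploadHandler).serve_forever()
```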
aron committed Jul 16, 2024
1 parent 31c8610 commit 97ce06c
Showing 3 changed files with 144 additions and 26 deletions.
65 changes: 43 additions & 22 deletions python/cog/server/clients.py
@@ -177,10 +177,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None:

# files

async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
async def upload_file(
self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
) -> str:
"""put file to signed endpoint"""
log.debug("upload_file")
fh.seek(0)
# try to guess the filename of the given object
name = getattr(fh, "name", "file")
filename = os.path.basename(name) or "file"
@@ -198,17 +199,24 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
# ensure trailing slash
url_with_trailing_slash = url if url.endswith("/") else url + "/"

async def chunk_file_reader() -> AsyncIterator[bytes]:
while 1:
chunk = fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
if not chunk:
log.info("finished reading file")
break
yield chunk
class ChunkFileReader:
async def __aiter__(self) -> AsyncIterator[bytes]:
fh.seek(0)
while 1:
chunk = fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
if not chunk:
log.info("finished reading file")
break
yield chunk

url = url_with_trailing_slash + filename

headers = {"Content-Type": content_type}
if prediction_id is not None:
headers["X-Prediction-ID"] = prediction_id

# this is a somewhat unfortunate hack, but it works
# and is critical for upload training/quantization outputs
# if we get multipart uploads working or a separate API route
@@ -218,29 +226,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
resp1 = await self.file_client.put(
url,
content=b"",
headers={"Content-Type": content_type},
headers=headers,
follow_redirects=False,
)
if resp1.status_code == 307 and resp1.headers["Location"]:
log.info("got file upload redirect from api")
url = resp1.headers["Location"]

log.info("doing real upload to %s", url)
resp = await self.file_client.put(
url,
content=chunk_file_reader(),
headers={"Content-Type": content_type},
content=ChunkFileReader(),
headers=headers,
)
# TODO: if file size is >1MB, show upload throughput
resp.raise_for_status()

# strip any signing gubbins from the URL
final_url = urlparse(str(resp.url))._replace(query="").geturl()
# Try to extract the final asset URL from the `Location` header
# otherwise fallback to the URL of the final request.
final_url = str(resp.url)
if "location" in resp.headers:
final_url = resp.headers.get("location")

return final_url
# strip any signing gubbins from the URL
return urlparse(final_url)._replace(query="").geturl()

# this previously lived in json.upload_files, but it's clearer here
# this is a great pattern that should be adopted for input files
async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
async def upload_files(
self, obj: Any, *, url: Optional[str], prediction_id: Optional[str]
) -> Any:
"""
Iterates through an object from make_encodeable and uploads any files.
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
Expand All @@ -255,15 +270,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
# TODO: upload concurrently
if isinstance(obj, dict):
return {
key: await self.upload_files(value, url) for key, value in obj.items()
key: await self.upload_files(
value, url=url, prediction_id=prediction_id
)
for key, value in obj.items()
}
if isinstance(obj, list):
return [await self.upload_files(value, url) for value in obj]
return [
await self.upload_files(value, url=url, prediction_id=prediction_id)
for value in obj
]
if isinstance(obj, Path):
with obj.open("rb") as f:
return await self.upload_file(f, url)
return await self.upload_file(f, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
return await self.upload_file(obj, url)
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
return obj

# inputs
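For reference (not part of this diff), a rough sketch of how a caller uses the updated, keyword-only interface. The import path follows the module changed above; the temp file setup mirrors the tests below, and the bucket URL and prediction ID are made-up values.

```python
import asyncio
import os
import tempfile

from cog.server.clients import ClientManager


async def main() -> None:
    client_manager = ClientManager()

    # Write a small file to upload, as in the tests below.
    temp_path = os.path.join(tempfile.mkdtemp(), "my_file.txt")
    with open(temp_path, "w") as fh:
        fh.write("file content")

    with open(temp_path, "rb") as fh:
        # url and prediction_id are now keyword-only; prediction_id is
        # forwarded as the X-Prediction-ID header on the PUT request.
        final_url = await client_manager.upload_file(
            fh, url="https://example.com/bucket", prediction_id="p123"
        )

    # final_url is the Location header returned by the service, or the URL
    # the file was ultimately PUT to, with any query string stripped.
    print(final_url)


asyncio.run(main())
```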
8 changes: 6 additions & 2 deletions python/cog/server/runner.py
@@ -122,7 +122,9 @@ def __init__(
self._shutdown_event = shutdown_event # __main__ waits for this event

self._upload_url = upload_url
self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {}
self._predictions: dict[
str, tuple[schema.PredictionResponse, PredictionTask]
] = {}
self._predictions_in_flight: set[str] = set()
# it would be lovely to merge these but it's not fully clear how best to handle it
# since idempotent requests can kinda come whenever?
@@ -536,7 +538,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None:
async def _upload_files(self, output: Any) -> Any:
try:
# TODO: clean up output files
return await self._client_manager.upload_files(output, self._upload_url)
return await self._client_manager.upload_files(
output, url=self._upload_url, prediction_id=self.p.id
)
except Exception as error:
# If something goes wrong uploading a file, it's irrecoverable.
# The re-raised exception will be caught and cause the prediction
97 changes: 95 additions & 2 deletions python/tests/server/test_clients.py
@@ -1,4 +1,6 @@
import httpx
import os
import responses
import tempfile

import cog
@@ -7,12 +9,103 @@


@pytest.mark.asyncio
async def test_upload_files():
async def test_upload_files_without_url():
client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")
obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(obj, None)
result = await client_manager.upload_files(obj, url=None, prediction_id=None)
assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"}


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_url(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(201)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_prediction_id(respx_mock):
uploader = respx_mock.put(
"/bucket/my_file.txt", headers={"x-prediction-id": "p123"}
).mock(return_value=httpx.Response(201))

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id="p123"
)
assert result == {"path": "https://example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_location_header(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(
201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"}
)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_retry(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(502)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
with pytest.raises(httpx.HTTPStatusError):
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)

assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
assert uploader.call_count == 3
