Skip to content

Commit

Permalink
[async] Include prediction id upload request (#1788)
Browse files Browse the repository at this point in the history
* Cast TraceContext into Mapping[str, str] to fix linter

* Include prediction id upload request

Based on #1667

This PR introduces two small changes to the file upload interface.

1. We now allow downstream services to include the destination of the
asset in a `Location` header, rather than assuming that it's the same as
the final upload url (either the one passed via `--upload-url` or the
result of a 307 redirect response.

2. We now include the `X-Prediction-Id` header in upload request, this
allows the downstream client to potentially do configuration/routing
based on the prediction ID. This ID should be considered unsafe and
needs to be validated by the downstream service.

* Extract ChunkFileReader into top-level class

---------

Co-authored-by: Mattt Zmuda <mattt@replicate.com>
  • Loading branch information
aron and mattt committed Jul 18, 2024
1 parent 1e7d482 commit 7e56834
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 32 deletions.
89 changes: 60 additions & 29 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import io
import mimetypes
import os
from typing import Any, AsyncIterator, Awaitable, Callable, Collection, Dict, Optional
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Dict,
Mapping,
Optional,
cast,
)
from urllib.parse import urlparse

import httpx
Expand Down Expand Up @@ -62,7 +72,7 @@ def webhook_headers() -> "dict[str, str]":

async def on_request_trace_context_hook(request: httpx.Request) -> None:
ctx = current_trace_context() or {}
request.headers.update(ctx)
request.headers.update(cast(Mapping[str, str], ctx))


def httpx_webhook_client() -> httpx.AsyncClient:
Expand Down Expand Up @@ -101,20 +111,32 @@ def httpx_file_client() -> httpx.AsyncClient:
# requests has no write timeout, keep that
# httpx default for pool is 5, use that
timeout = httpx.Timeout(connect=10, read=15, write=None, pool=5)
headers = {key: str(value) for key, value in (current_trace_context() or {})}
headers["User-Agent"] = _user_agent

return httpx.AsyncClient(
event_hooks={"request": [on_request_trace_context_hook]},
headers=common_headers(),
transport=transport,
follow_redirects=True,
timeout=timeout,
http2=True,
headers=headers,
)


class ChunkFileReader:
def __init__(self, fh: io.IOBase) -> None:
self.fh = fh

async def __aiter__(self) -> AsyncIterator[bytes]:
self.fh.seek(0)
while True:
chunk = self.fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
if not chunk:
log.info("finished reading file")
break
yield chunk


# there's a case for splitting this apart or inlining parts of it
# I'm somewhat sympathetic to separating webhooks and files, but they both have
# the same semantics of holding a client for the lifetime of runner
Expand Down Expand Up @@ -171,10 +193,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None:

# files

async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
async def upload_file(
self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
) -> str:
"""put file to signed endpoint"""
log.debug("upload_file")
fh.seek(0)
# try to guess the filename of the given object
name = getattr(fh, "name", "file")
filename = os.path.basename(name) or "file"
Expand All @@ -192,17 +215,12 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str:
# ensure trailing slash
url_with_trailing_slash = url if url.endswith("/") else url + "/"

async def chunk_file_reader() -> AsyncIterator[bytes]:
while 1:
chunk = fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
if not chunk:
log.info("finished reading file")
break
yield chunk

url = url_with_trailing_slash + filename

headers = {"Content-Type": content_type}
if prediction_id is not None:
headers["X-Prediction-ID"] = prediction_id

# this is a somewhat unfortunate hack, but it works
# and is critical for upload training/quantization outputs
# if we get multipart uploads working or a separate API route
Expand All @@ -212,29 +230,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]:
resp1 = await self.file_client.put(
url,
content=b"",
headers={"Content-Type": content_type},
headers=headers,
follow_redirects=False,
)
if resp1.status_code == 307 and resp1.headers["Location"]:
log.info("got file upload redirect from api")
url = resp1.headers["Location"]

log.info("doing real upload to %s", url)
resp = await self.file_client.put(
url,
content=chunk_file_reader(),
headers={"Content-Type": content_type},
content=ChunkFileReader(fh),
headers=headers,
)
# TODO: if file size is >1MB, show upload throughput
resp.raise_for_status()

# strip any signing gubbins from the URL
final_url = urlparse(str(resp.url))._replace(query="").geturl()
# Try to extract the final asset URL from the `Location` header
# otherwise fallback to the URL of the final request.
final_url = str(resp.url)
if "location" in resp.headers:
final_url = resp.headers.get("location")

return final_url
# strip any signing gubbins from the URL
return urlparse(final_url)._replace(query="").geturl()

# this previously lived in json.upload_files, but it's clearer here
# this is a great pattern that should be adopted for input files
async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
async def upload_files(
self, obj: Any, *, url: Optional[str], prediction_id: Optional[str]
) -> Any:
"""
Iterates through an object from make_encodeable and uploads any files.
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
Expand All @@ -249,15 +274,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
# TODO: upload concurrently
if isinstance(obj, dict):
return {
key: await self.upload_files(value, url) for key, value in obj.items()
key: await self.upload_files(
value, url=url, prediction_id=prediction_id
)
for key, value in obj.items()
}
if isinstance(obj, list):
return [await self.upload_files(value, url) for value in obj]
return [
await self.upload_files(value, url=url, prediction_id=prediction_id)
for value in obj
]
if isinstance(obj, Path):
with obj.open("rb") as f:
return await self.upload_file(f, url)
return await self.upload_file(f, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
return await self.upload_file(obj, url)
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
return obj

# inputs
Expand Down
4 changes: 3 additions & 1 deletion python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None:
async def _upload_files(self, output: Any) -> Any:
try:
# TODO: clean up output files
return await self._client_manager.upload_files(output, self._upload_url)
return await self._client_manager.upload_files(
output, url=self._upload_url, prediction_id=self.p.id
)
except Exception as error:
# If something goes wrong uploading a file, it's irrecoverable.
# The re-raised exception will be caught and cause the prediction
Expand Down
97 changes: 95 additions & 2 deletions python/tests/server/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import httpx
import os
import responses
import tempfile

import cog
Expand All @@ -7,12 +9,103 @@


@pytest.mark.asyncio
async def test_upload_files():
async def test_upload_files_without_url():
client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")
obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(obj, None)
result = await client_manager.upload_files(obj, url=None, prediction_id=None)
assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"}


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_url(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(201)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_prediction_id(respx_mock):
uploader = respx_mock.put(
"/bucket/my_file.txt", headers={"x-prediction-id": "p123"}
).mock(return_value=httpx.Response(201))

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id="p123"
)
assert result == {"path": "https://example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_location_header(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(
201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"}
)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)
assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}

assert uploader.call_count == 1


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
async def test_upload_files_with_retry(respx_mock):
uploader = respx_mock.put("/bucket/my_file.txt").mock(
return_value=httpx.Response(502)
)

client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
temp_path = os.path.join(temp_dir, "my_file.txt")
with open(temp_path, "w") as fh:
fh.write("file content")

obj = {"path": cog.Path(temp_path)}
with pytest.raises(httpx.HTTPStatusError):
result = await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)

assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"}
assert uploader.call_count == 3

0 comments on commit 7e56834

Please sign in to comment.