diff --git a/python/cog/server/clients.py b/python/cog/server/clients.py index 7551d39c6..4547ee5ed 100644 --- a/python/cog/server/clients.py +++ b/python/cog/server/clients.py @@ -2,7 +2,17 @@ import io import mimetypes import os -from typing import Any, AsyncIterator, Awaitable, Callable, Collection, Dict, Optional +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Collection, + Dict, + Mapping, + Optional, + cast, +) from urllib.parse import urlparse import httpx @@ -62,7 +72,7 @@ def webhook_headers() -> "dict[str, str]": async def on_request_trace_context_hook(request: httpx.Request) -> None: ctx = current_trace_context() or {} - request.headers.update(ctx) + request.headers.update(cast(Mapping[str, str], ctx)) def httpx_webhook_client() -> httpx.AsyncClient: @@ -111,6 +121,22 @@ def httpx_file_client() -> httpx.AsyncClient: ) +class ChunkFileReader: + def __init__(self, fh: io.IOBase) -> None: + self.fh = fh + + async def __aiter__(self) -> AsyncIterator[bytes]: + self.fh.seek(0) + while True: + chunk = self.fh.read(1024 * 1024) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + if not chunk: + log.info("finished reading file") + break + yield chunk + + # there's a case for splitting this apart or inlining parts of it # I'm somewhat sympathetic to separating webhooks and files, but they both have # the same semantics of holding a client for the lifetime of runner @@ -167,10 +193,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None: # files - async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str: + async def upload_file( + self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str] + ) -> str: """put file to signed endpoint""" log.debug("upload_file") - fh.seek(0) # try to guess the filename of the given object name = getattr(fh, "name", "file") filename = os.path.basename(name) or "file" @@ -188,17 +215,12 @@ async def upload_file(self, fh: io.IOBase, url: 
Optional[str]) -> str: # ensure trailing slash url_with_trailing_slash = url if url.endswith("/") else url + "/" - async def chunk_file_reader() -> AsyncIterator[bytes]: - while 1: - chunk = fh.read(1024 * 1024) - if isinstance(chunk, str): - chunk = chunk.encode("utf-8") - if not chunk: - log.info("finished reading file") - break - yield chunk - url = url_with_trailing_slash + filename + + headers = {"Content-Type": content_type} + if prediction_id is not None: + headers["X-Prediction-ID"] = prediction_id + # this is a somewhat unfortunate hack, but it works # and is critical for upload training/quantization outputs # if we get multipart uploads working or a separate API route @@ -208,29 +230,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]: resp1 = await self.file_client.put( url, content=b"", - headers={"Content-Type": content_type}, + headers=headers, follow_redirects=False, ) if resp1.status_code == 307 and resp1.headers["Location"]: log.info("got file upload redirect from api") url = resp1.headers["Location"] + log.info("doing real upload to %s", url) resp = await self.file_client.put( url, - content=chunk_file_reader(), - headers={"Content-Type": content_type}, + content=ChunkFileReader(fh), + headers=headers, ) # TODO: if file size is >1MB, show upload throughput resp.raise_for_status() - # strip any signing gubbins from the URL - final_url = urlparse(str(resp.url))._replace(query="").geturl() + # Try to extract the final asset URL from the `Location` header + # otherwise fallback to the URL of the final request. 
+ final_url = str(resp.url) + if "location" in resp.headers: + final_url = resp.headers.get("location") - return final_url + # strip any signing gubbins from the URL + return urlparse(final_url)._replace(query="").geturl() # this previously lived in json.upload_files, but it's clearer here # this is a great pattern that should be adopted for input files - async def upload_files(self, obj: Any, url: Optional[str]) -> Any: + async def upload_files( + self, obj: Any, *, url: Optional[str], prediction_id: Optional[str] + ) -> Any: """ Iterates through an object from make_encodeable and uploads any files. When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files. @@ -245,15 +274,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any: # TODO: upload concurrently if isinstance(obj, dict): return { - key: await self.upload_files(value, url) for key, value in obj.items() + key: await self.upload_files( + value, url=url, prediction_id=prediction_id + ) + for key, value in obj.items() } if isinstance(obj, list): - return [await self.upload_files(value, url) for value in obj] + return [ + await self.upload_files(value, url=url, prediction_id=prediction_id) + for value in obj + ] if isinstance(obj, Path): with obj.open("rb") as f: - return await self.upload_file(f, url) + return await self.upload_file(f, url=url, prediction_id=prediction_id) if isinstance(obj, io.IOBase): - return await self.upload_file(obj, url) + return await self.upload_file(obj, url=url, prediction_id=prediction_id) return obj # inputs diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 6f3849e9a..02306a63f 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -122,7 +122,9 @@ def __init__( self._shutdown_event = shutdown_event # __main__ waits for this event self._upload_url = upload_url - self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {} + 
self._predictions: dict[ + str, tuple[schema.PredictionResponse, PredictionTask] + ] = {} self._predictions_in_flight: set[str] = set() # it would be lovely to merge these but it's not fully clear how best to handle it # since idempotent requests can kinda come whenever? @@ -536,7 +538,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None: async def _upload_files(self, output: Any) -> Any: try: # TODO: clean up output files - return await self._client_manager.upload_files(output, self._upload_url) + return await self._client_manager.upload_files( + output, url=self._upload_url, prediction_id=self.p.id + ) except Exception as error: # If something goes wrong uploading a file, it's irrecoverable. # The re-raised exception will be caught and cause the prediction diff --git a/python/tests/server/test_clients.py index 0b287f387..f4e9afccb 100644 --- a/python/tests/server/test_clients.py +++ b/python/tests/server/test_clients.py @@ -1,4 +1,5 @@ +import httpx import os import tempfile import cog @@ -7,12 +8,103 @@ @pytest.mark.asyncio -async def test_upload_files(): +async def test_upload_files_without_url(): client_manager = ClientManager() temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "my_file.txt") with open(temp_path, "w") as fh: fh.write("file content") obj = {"path": cog.Path(temp_path)} - result = await client_manager.upload_files(obj, None) + result = await client_manager.upload_files(obj, url=None, prediction_id=None) assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"} + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_url(respx_mock): + uploader = respx_mock.put("/bucket/my_file.txt").mock( + return_value=httpx.Response(201) + ) + + client_manager = ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + 
fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + result = await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id=None + ) + assert result == {"path": "https://example.com/bucket/my_file.txt"} + + assert uploader.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_prediction_id(respx_mock): + uploader = respx_mock.put( + "/bucket/my_file.txt", headers={"x-prediction-id": "p123"} + ).mock(return_value=httpx.Response(201)) + + client_manager = ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + result = await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id="p123" + ) + assert result == {"path": "https://example.com/bucket/my_file.txt"} + + assert uploader.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_location_header(respx_mock): + uploader = respx_mock.put("/bucket/my_file.txt").mock( + return_value=httpx.Response( + 201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"} + ) + ) + + client_manager = ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + result = await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id=None + ) + assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"} + + assert uploader.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_retry(respx_mock): + uploader = respx_mock.put("/bucket/my_file.txt").mock( + return_value=httpx.Response(502) + ) + + client_manager = 
ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + with pytest.raises(httpx.HTTPStatusError): + await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id=None + ) + + # the upload raised, so there is no result URL to assert on + assert uploader.call_count == 3