From 2da3de0d85f07342428437994e3f52dc961be35d Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Thu, 4 Jul 2024 21:09:06 +0100 Subject: [PATCH] Include prediction id upload request Based on #1667 This PR introduces two small changes to the file upload interface. 1. We now allow downstream services to include the destination of the asset in a `Location` header, rather than assuming that it's the same as the final upload url (either the one passed via `--upload-url` or the result of a 307 redirect response. 2. We now include the `X-Prediction-Id` header in upload request, this allows the downstream client to potentially do configuration/routing based on the prediction ID. This ID should be considered unsafe and needs to be validated by the downstream service. --- python/cog/server/clients.py | 46 +++++++++++++----- python/cog/server/runner.py | 8 +++- python/tests/server/test_clients.py | 74 ++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 17 deletions(-) diff --git a/python/cog/server/clients.py b/python/cog/server/clients.py index b385660ce..9be3a5d87 100644 --- a/python/cog/server/clients.py +++ b/python/cog/server/clients.py @@ -75,7 +75,7 @@ def httpx_retry_client() -> httpx.AsyncClient: def httpx_file_client() -> httpx.AsyncClient: # verify: Union[str, bool, ssl.SSLContext] = True transport = RetryTransport( - max_attempts=3, + max_attempts=1, backoff_factor=0.1, retry_status_codes=[408, 429, 500, 502, 503, 504], retryable_methods=["PUT"], @@ -150,10 +150,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None: # files - async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str: + async def upload_file( + self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str] + ) -> str: """put file to signed endpoint""" log.debug("upload_file") - fh.seek(0) # try to guess the filename of the given object name = getattr(fh, "name", "file") filename = os.path.basename(name) or "file" @@ -172,6 +173,7 @@ async def upload_file(self, fh: io.IOBase, url: Optional[str]) -> str: url_with_trailing_slash = url if url.endswith("/") else url + "/" async def chunk_file_reader() -> AsyncIterator[bytes]: + fh.seek(0) while 1: chunk = fh.read(1024 * 1024) if isinstance(chunk, str): @@ -182,6 +184,11 @@ async def chunk_file_reader() -> AsyncIterator[bytes]: yield chunk url = url_with_trailing_slash + filename + + headers = {"Content-Type": content_type} + if prediction_id is not None: + headers["X-Prediction-ID"] = prediction_id + # this is a somewhat unfortunate hack, but it works # and is critical for upload training/quantization outputs # if we get multipart uploads working or a separate API route @@ -191,29 +198,36 @@ async def chunk_file_reader() -> AsyncIterator[bytes]: resp1 = await self.file_client.put( url, content=b"", - headers={"Content-Type": content_type}, + headers=headers, follow_redirects=False, ) if resp1.status_code == 307 and resp1.headers["Location"]: log.info("got file upload redirect from api") url = resp1.headers["Location"] + log.info("doing real upload to %s", url) resp = await self.file_client.put( url, content=chunk_file_reader(), - headers={"Content-Type": content_type}, + headers=headers, ) # TODO: if file size is >1MB, show upload throughput resp.raise_for_status() - # strip any signing gubbins from the URL - final_url = urlparse(str(resp.url))._replace(query="").geturl() + # Try to extract the final asset URL from the `Location` header + # otherwise fallback to the URL of the final request. + final_url = str(resp.url) + if "location" in resp.headers: + final_url = resp.headers.get("location") - return final_url + # strip any signing gubbins from the URL + return urlparse(final_url)._replace(query="").geturl() # this previously lived in json.upload_files, but it's clearer here # this is a great pattern that should be adopted for input files - async def upload_files(self, obj: Any, url: Optional[str]) -> Any: + async def upload_files( + self, obj: Any, *, url: Optional[str], prediction_id: Optional[str] + ) -> Any: """ Iterates through an object from make_encodeable and uploads any files. When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files. @@ -228,15 +242,21 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any: # TODO: upload concurrently if isinstance(obj, dict): return { - key: await self.upload_files(value, url) for key, value in obj.items() + key: await self.upload_files( + value, url=url, prediction_id=prediction_id + ) + for key, value in obj.items() } if isinstance(obj, list): - return [await self.upload_files(value, url) for value in obj] + return [ + await self.upload_files(value, url=url, prediction_id=prediction_id) + for value in obj + ] if isinstance(obj, Path): with obj.open("rb") as f: - return await self.upload_file(f, url) + return await self.upload_file(f, url=url, prediction_id=prediction_id) if isinstance(obj, io.IOBase): - return await self.upload_file(obj, url) + return await self.upload_file(obj, url=url, prediction_id=prediction_id) return obj # inputs diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 6f3849e9a..02306a63f 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -122,7 +122,9 @@ def __init__( self._shutdown_event = shutdown_event # __main__ waits for this event self._upload_url = upload_url - self._predictions: dict[str, tuple[schema.PredictionResponse, PredictionTask]] = {} + self._predictions: dict[ + str, tuple[schema.PredictionResponse, PredictionTask] + ] = {} self._predictions_in_flight: set[str] = set() # it would be lovely to merge these but it's not fully clear how best to handle it # since idempotent requests can kinda come whenever? @@ -536,7 +538,9 @@ async def _send_webhook(self, event: schema.WebhookEvent) -> None: async def _upload_files(self, output: Any) -> Any: try: # TODO: clean up output files - return await self._client_manager.upload_files(output, self._upload_url) + return await self._client_manager.upload_files( + output, url=self._upload_url, prediction_id=self.p.id + ) except Exception as error: # If something goes wrong uploading a file, it's irrecoverable. # The re-raised exception will be caught and cause the prediction diff --git a/python/tests/server/test_clients.py b/python/tests/server/test_clients.py index 0b287f387..f2fe81360 100644 --- a/python/tests/server/test_clients.py +++ b/python/tests/server/test_clients.py @@ -1,4 +1,6 @@ +import httpx import os +import responses import tempfile import cog @@ -7,12 +9,80 @@ @pytest.mark.asyncio -async def test_upload_files(): +async def test_upload_files_without_url(): client_manager = ClientManager() temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "my_file.txt") with open(temp_path, "w") as fh: fh.write("file content") obj = {"path": cog.Path(temp_path)} - result = await client_manager.upload_files(obj, None) + result = await client_manager.upload_files(obj, url=None, prediction_id=None) assert result == {"path": "data:text/plain;base64,ZmlsZSBjb250ZW50"} + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_url(respx_mock): + uploader = respx_mock.put("/bucket/my_file.txt").mock( + return_value=httpx.Response(201) + ) + + client_manager = ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + result = await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id=None + ) + assert result == {"path": "https://example.com/bucket/my_file.txt"} + + assert uploader.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_prediction_id(respx_mock): + uploader = respx_mock.put( + "/bucket/my_file.txt", headers={"x-prediction-id": "p123"} + ).mock(return_value=httpx.Response(201)) + + client_manager = ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + result = await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id="p123" + ) + assert result == {"path": "https://example.com/bucket/my_file.txt"} + + assert uploader.call_count == 1 + + +@pytest.mark.asyncio +@pytest.mark.respx(base_url="https://example.com") +async def test_upload_files_with_location_header(respx_mock): + uploader = respx_mock.put("/bucket/my_file.txt").mock( + return_value=httpx.Response( + 201, headers={"Location": "https://cdn.example.com/bucket/my_file.txt"} + ) + ) + + client_manager = ClientManager() + temp_dir = tempfile.mkdtemp() + temp_path = os.path.join(temp_dir, "my_file.txt") + with open(temp_path, "w") as fh: + fh.write("file content") + + obj = {"path": cog.Path(temp_path)} + result = await client_manager.upload_files( + obj, url="https://example.com/bucket", prediction_id=None + ) + assert result == {"path": "https://cdn.example.com/bucket/my_file.txt"} + + assert uploader.call_count == 1