Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Support URLFile in the upload_file function #1987

Merged
merged 6 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
import base64
import io
import mimetypes
Expand All @@ -21,7 +22,7 @@

from .. import types
from ..schema import PredictionResponse, Status, WebhookEvent
from ..types import Path
from ..types import Path, URLFile
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
from .retry_transport import RetryTransport
Expand Down Expand Up @@ -126,11 +127,11 @@ def __init__(self, fh: io.IOBase) -> None:
self.fh = fh

async def __aiter__(self) -> AsyncIterator[bytes]:
self.fh.seek(0)
if self.fh.seekable():
self.fh.seek(0)

while True:
chunk = self.fh.read(1024 * 1024)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
aron marked this conversation as resolved.
Show resolved Hide resolved
if not chunk:
log.info("finished reading file")
break
Expand Down Expand Up @@ -288,7 +289,10 @@ async def upload_files(
with obj.open("rb") as f:
return await self.upload_file(f, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
try:
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
finally:
obj.close()
aron marked this conversation as resolved.
Show resolved Hide resolved
return obj

# inputs
Expand Down
26 changes: 13 additions & 13 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, TypeVar, Union

import httpx
import requests
from pydantic import Field, SecretStr

FILENAME_ILLEGAL_CHARS = set("\u0000/")
Expand Down Expand Up @@ -195,22 +194,17 @@ def unlink(self, missing_ok: bool = False) -> None:
raise


# we would prefer URLFile to stay lazy
# except... that doesn't really work with httpx?


class URLFile(io.IOBase):
"""
URLFile is a proxy object for a :class:`urllib3.response.HTTPResponse`
object that is created lazily. It's a file-like object constructed from a
URL that can survive pickling/unpickling.

This is the only place Cog uses requests
"""

__slots__ = ("__target__", "__url__")

def __init__(self, url: str) -> None:
object.__setattr__(self, "name", os.path.basename(url))
aron marked this conversation as resolved.
Show resolved Hide resolved
object.__setattr__(self, "__url__", url)

# We provide __getstate__ and __setstate__ explicitly to ensure that the
Expand Down Expand Up @@ -242,19 +236,25 @@ def __delattr__(self, name: str) -> None:

# Luckily the only dunder method on HTTPResponse is __iter__
def __iter__(self) -> Iterator[bytes]:
return iter(self.__wrapped__)
response = self.__wrapped__
return iter(response)

@property
def __wrapped__(self) -> Any:
try:
return object.__getattribute__(self, "__target__")
except AttributeError:
url = object.__getattribute__(self, "__url__")
resp = requests.get(url, stream=True)
resp.raise_for_status()
resp.raw.decode_content = True
object.__setattr__(self, "__target__", resp.raw)
return resp.raw

# We create a streaming response here, much like the `requests`
# version in the main 0.9.x branch. The only concerning bit here
aron marked this conversation as resolved.
Show resolved Hide resolved
# is that the bookkeeping for closing the response needs to be
# handled elsewhere. There's probably a better design for this
# in the long term.
res = urllib.request.urlopen(url)
object.__setattr__(self, "__target__", res)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels so much simpler than the requests version. What am I missing, why would we use requests here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, actually, thinking about this some more -- does this actually work? This doesn't spin off a background process to download the file; it blocks until the file is downloaded and stores the whole thing in memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we chat about this sync? Is requests threaded by default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed this in person and this interface will do what we want. It looks like we can actually get something similar using httpx as well:

with httpx.stream("GET", "https://example.com/") as response:
    raw_stream = response.extensions["network_stream"]
    print(raw_stream.read(2))
# b'\x1f\x8b'

encode/httpx#2296 (comment)

But for the moment we'll go with this simpler implementation and refactor later.


return res

def __repr__(self) -> str:
try:
Expand Down
69 changes: 63 additions & 6 deletions python/tests/server/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import httpx
from email.message import Message
import io
import os
import responses
import tempfile
from urllib.response import addinfourl
from unittest import mock

import cog
import httpx
import pytest
from cog.server.clients import ClientManager

pytest.mark.asyncio

aron marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.asyncio
async def test_upload_files_without_url():
client_manager = ClientManager()
temp_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -103,9 +107,62 @@ async def test_upload_files_with_retry(respx_mock):

obj = {"path": cog.Path(temp_path)}
with pytest.raises(httpx.HTTPStatusError):
result = await client_manager.upload_files(
await client_manager.upload_files(
obj, url="https://example.com/bucket", prediction_id=None
)

assert uploader.call_count == 3


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@mock.patch("urllib.request.urlopen")
async def test_upload_files_with_url_file(urlopen_mock, respx_mock):
    """Check that a ``URLFile`` value is fetched via ``urlopen`` and uploaded once.

    ``urllib.request.urlopen`` is patched to hand back an in-memory response,
    and the upload endpoint is mocked to answer 201 with a ``Location`` header.
    The test expects ``upload_files`` to replace the ``URLFile`` with that
    ``Location`` value.
    """
    source_url = "https://example.com/cdn/my_file.txt"
    uploaded_url = "https://cdn.example.com/bucket/my_file.txt"

    # Stand-in for the object the patched urlopen returns.
    urlopen_mock.return_value = addinfourl(
        fp=io.BytesIO(b"hello world"), headers=Message(), url=source_url
    )

    created_response = httpx.Response(201, headers={"Location": uploaded_url})
    upload_route = respx_mock.put("/bucket/my_file.txt").mock(
        return_value=created_response
    )

    manager = ClientManager()
    payload = {"path": cog.types.URLFile(source_url)}
    outcome = await manager.upload_files(
        payload, url="https://example.com/bucket", prediction_id=None
    )

    assert outcome == {"path": uploaded_url}

    # One upload request, and one fetch of the source URL.
    assert upload_route.call_count == 1
    assert urlopen_mock.call_count == 1
    assert urlopen_mock.call_args[0][0] == source_url


@pytest.mark.asyncio
@pytest.mark.respx(base_url="https://example.com")
@mock.patch("urllib.request.urlopen")
async def test_upload_files_with_url_file_with_retry(urlopen_mock, respx_mock):
    """Check retry behaviour when uploading a ``URLFile`` keeps failing.

    ``urllib.request.urlopen`` is patched to return an in-memory response and
    the upload endpoint is mocked to always answer 502. The test expects
    ``upload_files`` to raise ``httpx.HTTPStatusError`` after three upload
    attempts, while the source URL is opened only once.
    """
    fp = io.BytesIO(b"hello world")
    urlopen_mock.return_value = addinfourl(
        fp=fp, headers=Message(), url="https://example.com/cdn/my_file.txt"
    )

    uploader = respx_mock.put("/bucket/my_file.txt").mock(
        return_value=httpx.Response(502)
    )

    client_manager = ClientManager()

    obj = {"path": cog.types.URLFile("https://example.com/cdn/my_file.txt")}
    with pytest.raises(httpx.HTTPStatusError):
        await client_manager.upload_files(
            obj, url="https://example.com/bucket", prediction_id=None
        )

    # The call above raises, so there is no return value to check; only the
    # request counts are asserted.
    assert uploader.call_count == 3
    assert urlopen_mock.call_count == 1
    assert urlopen_mock.call_args[0][0] == "https://example.com/cdn/my_file.txt"
Loading