Skip to content

Commit

Permalink
Add unit tests for put_file_to_signed_endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Aron Carroll <aron@replicate.com>
  • Loading branch information
aron committed May 20, 2024
1 parent dc3f931 commit 5a9df78
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 9 deletions.
19 changes: 10 additions & 9 deletions python/cog/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,27 +51,28 @@ def put_file_to_signed_endpoint(
connect_timeout = 10
read_timeout = 15

headers = {
"Content-Type": content_type,
}
if prediction_id is not None:
headers["X-Prediction-ID"] = prediction_id

resp = client.put(
ensure_trailing_slash(endpoint) + filename,
fh, # type: ignore
headers={
"Content-Type": content_type,
"X-Prediction-ID": prediction_id,
},
headers=headers,
timeout=(connect_timeout, read_timeout),
)
resp.raise_for_status()

# Try to extract the final asset URL from the `Location` header
# otherwise fallback to the URL of the final request.
final_url = resp.url
if url := resp.headers.get("location"):
final_url = url
if "location" in resp.headers:
final_url = resp.headers.get("location")

# strip any signing gubbins from the URL
final_url = urlparse(resp.url)._replace(query="").geturl()

return final_url
return str(urlparse(final_url)._replace(query="").geturl())


def ensure_trailing_slash(url: str) -> str:
Expand Down
93 changes: 93 additions & 0 deletions python/tests/cog/test_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import requests
import io
import responses
from cog.files import put_file_to_signed_endpoint
from unittest.mock import Mock


def test_put_file_to_signed_endpoint():
mock_fh = io.BytesIO()
mock_client = Mock()

mock_response = Mock(spec=requests.Response)
mock_response.status_code = 201
mock_response.text = ""
mock_response.headers = {}
mock_response.url = "http://example.com/upload/file?some-gubbins"
mock_response.ok = True

mock_client.put.return_value = mock_response

final_url = put_file_to_signed_endpoint(
mock_fh, "http://example.com/upload", mock_client, prediction_id=None
)

assert final_url == "http://example.com/upload/file"
mock_client.put.assert_called_with(
"http://example.com/upload/file",
mock_fh,
headers={
"Content-Type": None,
},
timeout=(10, 15),
)


def test_put_file_to_signed_endpoint_with_prediction_id():
mock_fh = io.BytesIO()
mock_client = Mock()

mock_response = Mock(spec=requests.Response)
mock_response.status_code = 201
mock_response.text = ""
mock_response.headers = {}
mock_response.url = "http://example.com/upload/file?some-gubbins"
mock_response.ok = True

mock_client.put.return_value = mock_response

final_url = put_file_to_signed_endpoint(
mock_fh, "http://example.com/upload", mock_client, prediction_id="abc123"
)

assert final_url == "http://example.com/upload/file"
mock_client.put.assert_called_with(
"http://example.com/upload/file",
mock_fh,
headers={
"Content-Type": None,
"X-Prediction-ID": "abc123",
},
timeout=(10, 15),
)


def test_put_file_to_signed_endpoint_with_location():
mock_fh = io.BytesIO()
mock_client = Mock()

mock_response = Mock(spec=requests.Response)
mock_response.status_code = 201
mock_response.text = ""
mock_response.headers = {
"location": "http://cdn.example.com/bucket/file?some-gubbins"
}
mock_response.url = "http://example.com/upload/file?some-gubbins"
mock_response.ok = True

mock_client.put.return_value = mock_response

final_url = put_file_to_signed_endpoint(
mock_fh, "http://example.com/upload", mock_client, prediction_id="abc123"
)

assert final_url == "http://cdn.example.com/bucket/file"
mock_client.put.assert_called_with(
"http://example.com/upload/file",
mock_fh,
headers={
"Content-Type": None,
"X-Prediction-ID": "abc123",
},
timeout=(10, 15),
)

0 comments on commit 5a9df78

Please sign in to comment.