Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ support for files bigger than 1MB in sync #509

Merged
merged 6 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions dbx/sync/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def _api(
more_opts = {"ssl": ssl} if ssl is not None else {}
async with session.post(url=url, json=json_data, headers=headers, **more_opts) as resp:
if resp.status in ok_status:
break
return await resp.json()
if resp.status == 429:
dbx_echo("Rate limited")
await _rate_limit_sleep(resp)
Expand Down Expand Up @@ -197,14 +197,49 @@ async def put(
path = f"{self.base_path}/{sub_path}"
with open(full_source_path, "rb") as f:
contents = base64.b64encode(f.read()).decode("ascii")
await self._api_put(
api_base_path=self.api_base_path,
path=path,
session=session,
api_token=self.api_token,
contents=contents,
ssl=self.ssl,
)

if len(contents) <= 1024 * 1024:
await self._api_put(
api_base_path=self.api_base_path,
path=path,
session=session,
api_token=self.api_token,
contents=contents,
ssl=self.ssl,
)
else:
dbx_echo(f"Streaming {path}")

resp = await self._api(
url=f"{self.api_base_path}/create",
path=path,
session=session,
api_token=self.api_token,
ssl=self.ssl,
overwrite=True,
)
handle = resp.get("handle")
import textwrap

chunks = textwrap.wrap(contents, 1024 * 1024)
for chunk in chunks:
await self._api(
url=f"{self.api_base_path}/add-block",
path=path,
session=session,
api_token=self.api_token,
ssl=self.ssl,
handle=handle,
data=chunk,
)
await self._api(
url=f"{self.api_base_path}/close",
path=path,
session=session,
api_token=self.api_token,
ssl=self.ssl,
handle=handle,
)


class ReposClient(BaseClient):
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/sync/clients/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ def dummy_file_path() -> str:
with open(file_path, "w") as f:
f.write("yo")
yield file_path


@pytest.fixture
def dummy_file_path_2mb() -> str:
with temporary_directory() as tempdir:
file_path = os.path.join(tempdir, "file")
with open(file_path, "w") as f:
f.write("y" * 1024 * 2048)
yield file_path
96 changes: 69 additions & 27 deletions tests/unit/sync/clients/test_dbfs_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import base64
from unittest.mock import AsyncMock, MagicMock, PropertyMock
import textwrap
from tests.unit.sync.utils import create_async_with_result
from unittest.mock import AsyncMock, MagicMock, PropertyMock, call

import pytest

Expand All @@ -22,7 +24,7 @@ def test_init(client):

def test_delete(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -39,7 +41,7 @@ def test_delete_secure(client: DBFSClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False)
client = DBFSClient(base_path="/tmp/foo", config=mock_config)
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -50,21 +52,6 @@ def test_delete_secure(client: DBFSClient):
assert session.post.call_args[1]["ssl"] is True


def test_delete_secure(client: DBFSClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True)
client = DBFSClient(base_path="/tmp/foo", config=mock_config)
session = MagicMock()
resp = MagicMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))

assert session.post.call_count == 1
assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete"
assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"}
assert session.post.call_args[1]["ssl"] is False


def test_delete_backslash(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
Expand All @@ -82,7 +69,7 @@ def test_delete_no_path(client: DBFSClient):

def test_delete_recursive(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True))
Expand All @@ -98,7 +85,7 @@ def test_delete_rate_limited(client: DBFSClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -118,7 +105,7 @@ def test_delete_rate_limited_retry_after(client: DBFSClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down Expand Up @@ -146,7 +133,7 @@ def test_delete_unauthorized(client: DBFSClient):

def test_mkdirs(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.mkdirs(sub_path="foo/bar", session=session))
Expand Down Expand Up @@ -179,7 +166,7 @@ def test_mkdirs_rate_limited(client: DBFSClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -199,7 +186,7 @@ def test_mkdirs_rate_limited_retry_after(client: DBFSClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down Expand Up @@ -227,7 +214,7 @@ def test_mkdirs_unauthorized(client: DBFSClient):

def test_put(client: DBFSClient, dummy_file_path: str):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)

Expand All @@ -244,6 +231,61 @@ def test_put(client: DBFSClient, dummy_file_path: str):
assert is_dbfs_user_agent(session.post.call_args[1]["headers"]["user-agent"])


def test_put_max_block_size_exceeded(client: DBFSClient, dummy_file_path_2mb: str):
expected_handle = 1234

async def mock_json(*args, **kwargs):
return {"handle": expected_handle}

def mock_post(url, *args, **kwargs):
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
if "/base/api/2.0/dbfs/put" in url:
contents = kwargs.get("json").get("contents")
if len(contents) > 1024 * 1024: # replicate the api error thrown when contents exceeds max allowed
setattr(type(resp), "status", PropertyMock(return_value=400))
elif "/base/api/2.0/dbfs/create" in url:
# return a mock response json
resp.json = MagicMock(side_effect=mock_json)

return create_async_with_result(resp)

session = AsyncMock()
post = MagicMock(side_effect=mock_post)
session.post = post

asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path_2mb, session=session))

with open(dummy_file_path_2mb, "r") as f:
expected_contents = f.read()

chunks = textwrap.wrap(base64.b64encode(bytes(expected_contents, encoding="utf8")).decode("ascii"), 1024 * 1024)

assert session.post.call_count == len(chunks) + 2
assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/create"
assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block"
assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block"
assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/add-block"
assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/base/api/2.0/dbfs/close"

assert session.post.call_args_list[0][1]["json"] == {
"path": "dbfs:/tmp/foo/foo/bar",
"overwrite": True,
}

for i, chunk in enumerate(chunks):
assert session.post.call_args_list[i + 1][1]["json"] == {
"data": chunk,
"path": "dbfs:/tmp/foo/foo/bar",
"handle": expected_handle,
}, f"invalid json for chunk {i}"

assert session.post.call_args_list[4][1]["json"] == {
"path": "dbfs:/tmp/foo/foo/bar",
"handle": expected_handle,
}


def test_put_backslash(client: DBFSClient, dummy_file_path: str):
session = MagicMock()
resp = MagicMock()
Expand All @@ -267,7 +309,7 @@ def test_put_rate_limited(client: DBFSClient, dummy_file_path: str):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -291,7 +333,7 @@ def test_put_rate_limited_retry_after(client: DBFSClient, dummy_file_path: str):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/sync/clients/test_repos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_init(mock_config):

def test_delete(client: ReposClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -50,7 +50,7 @@ def test_delete_secure(client: ReposClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False)
client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config)
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -65,7 +65,7 @@ def test_delete_insecure(client: ReposClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True)
client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config)
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_delete_no_path(client: ReposClient):

def test_delete_recursive(client: ReposClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True))
Expand All @@ -109,7 +109,7 @@ def test_delete_rate_limited(client: ReposClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -129,7 +129,7 @@ def test_delete_rate_limited_retry_after(client: ReposClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_delete_unauthorized(client: ReposClient):

def test_mkdirs(client: ReposClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.mkdirs(sub_path="foo/bar", session=session))
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_mkdirs_rate_limited(client: ReposClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -210,7 +210,7 @@ def test_mkdirs_rate_limited_retry_after(client: ReposClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down