Skip to content

Commit

Permalink
✨ support for files bigger than 1MB in sync (#509)
Browse files Browse the repository at this point in the history
* files that are larger than the 1MB limit allowed in the dbfs put API v2.0 will use the streaming api calls to upload the content and avoid the http 400 MAX_BLOCK_SIZE_EXCEEDED error

* merge with upstream changes

* fix test

* add changelog

---------

Co-authored-by: Ivan Trusov <polarpersonal@gmail.com>
Co-authored-by: Matt Hayes <matt.hayes@databricks.com>
  • Loading branch information
3 people authored Mar 21, 2023
1 parent a37e44e commit f55b40a
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 45 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- 📌 switch from using `retry` to `tenacity`

### Added
- ✨ support for files bigger than 1MB in sync

## [0.8.8] - 2022-02-22

# Fixed
Expand Down
53 changes: 44 additions & 9 deletions dbx/sync/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def _api(
more_opts = {"ssl": ssl} if ssl is not None else {}
async with session.post(url=url, json=json_data, headers=headers, **more_opts) as resp:
if resp.status in ok_status:
break
return await resp.json()
if resp.status == 429:
dbx_echo("Rate limited")
await _rate_limit_sleep(resp)
Expand Down Expand Up @@ -197,14 +197,49 @@ async def put(
path = f"{self.base_path}/{sub_path}"
with open(full_source_path, "rb") as f:
contents = base64.b64encode(f.read()).decode("ascii")
await self._api_put(
api_base_path=self.api_base_path,
path=path,
session=session,
api_token=self.api_token,
contents=contents,
ssl=self.ssl,
)

if len(contents) <= 1024 * 1024:
await self._api_put(
api_base_path=self.api_base_path,
path=path,
session=session,
api_token=self.api_token,
contents=contents,
ssl=self.ssl,
)
else:
dbx_echo(f"Streaming {path}")

resp = await self._api(
url=f"{self.api_base_path}/create",
path=path,
session=session,
api_token=self.api_token,
ssl=self.ssl,
overwrite=True,
)
handle = resp.get("handle")
import textwrap

chunks = textwrap.wrap(contents, 1024 * 1024)
for chunk in chunks:
await self._api(
url=f"{self.api_base_path}/add-block",
path=path,
session=session,
api_token=self.api_token,
ssl=self.ssl,
handle=handle,
data=chunk,
)
await self._api(
url=f"{self.api_base_path}/close",
path=path,
session=session,
api_token=self.api_token,
ssl=self.ssl,
handle=handle,
)


class ReposClient(BaseClient):
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/sync/clients/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ def dummy_file_path() -> str:
with open(file_path, "w") as f:
f.write("yo")
yield file_path


@pytest.fixture
def dummy_file_path_2mb() -> str:
with temporary_directory() as tempdir:
file_path = os.path.join(tempdir, "file")
with open(file_path, "w") as f:
f.write("y" * 1024 * 2048)
yield file_path
96 changes: 69 additions & 27 deletions tests/unit/sync/clients/test_dbfs_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import base64
from unittest.mock import AsyncMock, MagicMock, PropertyMock
import textwrap
from tests.unit.sync.utils import create_async_with_result
from unittest.mock import AsyncMock, MagicMock, PropertyMock, call

import pytest

Expand All @@ -22,7 +24,7 @@ def test_init(client):

def test_delete(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -39,7 +41,7 @@ def test_delete_secure(client: DBFSClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False)
client = DBFSClient(base_path="/tmp/foo", config=mock_config)
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -50,21 +52,6 @@ def test_delete_secure(client: DBFSClient):
assert session.post.call_args[1]["ssl"] is True


def test_delete_secure(client: DBFSClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True)
client = DBFSClient(base_path="/tmp/foo", config=mock_config)
session = MagicMock()
resp = MagicMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))

assert session.post.call_count == 1
assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete"
assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"}
assert session.post.call_args[1]["ssl"] is False


def test_delete_backslash(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
Expand All @@ -82,7 +69,7 @@ def test_delete_no_path(client: DBFSClient):

def test_delete_recursive(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True))
Expand All @@ -98,7 +85,7 @@ def test_delete_rate_limited(client: DBFSClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -118,7 +105,7 @@ def test_delete_rate_limited_retry_after(client: DBFSClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down Expand Up @@ -146,7 +133,7 @@ def test_delete_unauthorized(client: DBFSClient):

def test_mkdirs(client: DBFSClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.mkdirs(sub_path="foo/bar", session=session))
Expand Down Expand Up @@ -179,7 +166,7 @@ def test_mkdirs_rate_limited(client: DBFSClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -199,7 +186,7 @@ def test_mkdirs_rate_limited_retry_after(client: DBFSClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down Expand Up @@ -227,7 +214,7 @@ def test_mkdirs_unauthorized(client: DBFSClient):

def test_put(client: DBFSClient, dummy_file_path: str):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)

Expand All @@ -244,6 +231,61 @@ def test_put(client: DBFSClient, dummy_file_path: str):
assert is_dbfs_user_agent(session.post.call_args[1]["headers"]["user-agent"])


def test_put_max_block_size_exceeded(client: DBFSClient, dummy_file_path_2mb: str):
expected_handle = 1234

async def mock_json(*args, **kwargs):
return {"handle": expected_handle}

def mock_post(url, *args, **kwargs):
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
if "/api/2.0/dbfs/put" in url:
contents = kwargs.get("json").get("contents")
if len(contents) > 1024 * 1024: # replicate the api error thrown when contents exceeds max allowed
setattr(type(resp), "status", PropertyMock(return_value=400))
elif "/api/2.0/dbfs/create" in url:
# return a mock response json
resp.json = MagicMock(side_effect=mock_json)

return create_async_with_result(resp)

session = AsyncMock()
post = MagicMock(side_effect=mock_post)
session.post = post

asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path_2mb, session=session))

with open(dummy_file_path_2mb, "r") as f:
expected_contents = f.read()

chunks = textwrap.wrap(base64.b64encode(bytes(expected_contents, encoding="utf8")).decode("ascii"), 1024 * 1024)

assert session.post.call_count == len(chunks) + 2
assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/create"
assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block"
assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block"
assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block"
assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/close"

assert session.post.call_args_list[0][1]["json"] == {
"path": "dbfs:/tmp/foo/foo/bar",
"overwrite": True,
}

for i, chunk in enumerate(chunks):
assert session.post.call_args_list[i + 1][1]["json"] == {
"data": chunk,
"path": "dbfs:/tmp/foo/foo/bar",
"handle": expected_handle,
}, f"invalid json for chunk {i}"

assert session.post.call_args_list[4][1]["json"] == {
"path": "dbfs:/tmp/foo/foo/bar",
"handle": expected_handle,
}


def test_put_backslash(client: DBFSClient, dummy_file_path: str):
session = MagicMock()
resp = MagicMock()
Expand All @@ -267,7 +309,7 @@ def test_put_rate_limited(client: DBFSClient, dummy_file_path: str):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -291,7 +333,7 @@ def test_put_rate_limited_retry_after(client: DBFSClient, dummy_file_path: str):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/sync/clients/test_repos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_init(mock_config):

def test_delete(client: ReposClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -50,7 +50,7 @@ def test_delete_secure(client: ReposClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False)
client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config)
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand All @@ -65,7 +65,7 @@ def test_delete_insecure(client: ReposClient):
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True)
client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config)
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session))
Expand Down Expand Up @@ -93,7 +93,7 @@ def test_delete_no_path(client: ReposClient):

def test_delete_recursive(client: ReposClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True))
Expand All @@ -109,7 +109,7 @@ def test_delete_rate_limited(client: ReposClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -129,7 +129,7 @@ def test_delete_rate_limited_retry_after(client: ReposClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_delete_unauthorized(client: ReposClient):

def test_mkdirs(client: ReposClient):
session = MagicMock()
resp = MagicMock()
resp = AsyncMock()
setattr(type(resp), "status", PropertyMock(return_value=200))
session.post.return_value = create_async_with_result(resp)
asyncio.run(client.mkdirs(sub_path="foo/bar", session=session))
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_mkdirs_rate_limited(client: ReposClient):
rate_limit_resp = MagicMock()
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))

Expand All @@ -210,7 +210,7 @@ def test_mkdirs_rate_limited_retry_after(client: ReposClient):
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))

success_resp = MagicMock()
success_resp = AsyncMock()
setattr(type(success_resp), "status", PropertyMock(return_value=200))

session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
Expand Down

0 comments on commit f55b40a

Please sign in to comment.