Skip to content

Commit

Permalink
[Feature] Support Models in dbutils.fs operations (#750)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->
- Support file operations in WorkspaceClient.Files for Databricks UC
model artifacts, so that users can use the Databricks SDK to download UC
model artifacts.
- This PR is part of the work to migrate the MLflow client to the
Databricks SDK for model artifact download/upload operations, for better
security.

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->
- Covered by the existing tests in test_dbfs_mixins.py, which now include
`/Models` paths alongside the `/Volumes` cases, similar to how _VolumesPath
was tested
- The following code works:
```
from databricks.sdk import WorkspaceClient
w = WorkspaceClient()
resp = w.files.download("/Models/system/ai/dbrx_instruct/3/MLmodel")
```

- [x] `make test` run locally
- [x] `make fmt` applied
- [x] relevant integration tests applied
  • Loading branch information
shichengzhou-db committed Sep 12, 2024
1 parent b34f502 commit 3162545
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
18 changes: 9 additions & 9 deletions databricks/sdk/mixins/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __repr__(self) -> str:
return f"<_DbfsIO {self._path} {'read' if self.readable() else 'write'}=True>"


class _VolumesIO(BinaryIO):
class _FilesIO(BinaryIO):

def __init__(self, api: files.FilesAPI, path: str, *, read: bool, write: bool, overwrite: bool):
self._buffer = []
Expand Down Expand Up @@ -262,7 +262,7 @@ def __exit__(self, __t, __value, __traceback):
self.close()

def __repr__(self) -> str:
return f"<_VolumesIO {self._path} {'read' if self.readable() else 'write'}=True>"
return f"<_FilesIO {self._path} {'read' if self.readable() else 'write'}=True>"


class _Path(ABC):
Expand Down Expand Up @@ -398,7 +398,7 @@ def __repr__(self) -> str:
return f'<_LocalPath {self._path}>'


class _VolumesPath(_Path):
class _FilesPath(_Path):

def __init__(self, api: files.FilesAPI, src: Union[str, pathlib.Path]):
self._path = pathlib.PurePosixPath(str(src).replace('dbfs:', '').replace('file:', ''))
Expand All @@ -411,7 +411,7 @@ def _is_dbfs(self) -> bool:
return False

def child(self, path: str) -> Self:
return _VolumesPath(self._api, str(self._path / path))
return _FilesPath(self._api, str(self._path / path))

def _is_dir(self) -> bool:
try:
Expand All @@ -431,7 +431,7 @@ def exists(self) -> bool:
return self.is_dir

def open(self, *, read=False, write=False, overwrite=False) -> BinaryIO:
return _VolumesIO(self._api, self.as_string, read=read, write=write, overwrite=overwrite)
return _FilesIO(self._api, self.as_string, read=read, write=write, overwrite=overwrite)

def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]:
if not self.is_dir:
Expand All @@ -458,13 +458,13 @@ def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]:
def delete(self, *, recursive=False):
if self.is_dir:
for entry in self.list(recursive=False):
_VolumesPath(self._api, entry.path).delete(recursive=True)
_FilesPath(self._api, entry.path).delete(recursive=True)
self._api.delete_directory(self.as_string)
else:
self._api.delete(self.as_string)

def __repr__(self) -> str:
return f'<_VolumesPath {self._path}>'
return f'<_FilesPath {self._path}>'


class _DbfsPath(_Path):
Expand Down Expand Up @@ -589,8 +589,8 @@ def _path(self, src):
'UC Volumes paths, not external locations or DBFS mount points.')
if src.scheme == 'file':
return _LocalPath(src.geturl())
if src.path.startswith('/Volumes'):
return _VolumesPath(self._files_api, src.geturl())
if src.path.startswith(('/Volumes', '/Models')):
return _FilesPath(self._files_api, src.geturl())
return _DbfsPath(self._dbfs_api, src.geturl())

def copy(self, src: str, dst: str, *, recursive=False, overwrite=False):
Expand Down
13 changes: 8 additions & 5 deletions tests/test_dbfs_mixins.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

from databricks.sdk.errors import NotFound
from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _LocalPath,
_VolumesPath)
from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _FilesPath,
_LocalPath)


def test_moving_dbfs_file_to_local_dir(config, tmp_path, mocker):
Expand Down Expand Up @@ -55,11 +55,14 @@ def test_moving_local_dir_to_dbfs(config, tmp_path, mocker):


@pytest.mark.parametrize('path,expected_type', [('/path/to/file', _DbfsPath),
('/Volumes/path/to/file', _VolumesPath),
('/Volumes/path/to/file', _FilesPath),
('/Models/path/to/file', _FilesPath),
('dbfs:/path/to/file', _DbfsPath),
('dbfs:/Volumes/path/to/file', _VolumesPath),
('dbfs:/Volumes/path/to/file', _FilesPath),
('dbfs:/Models/path/to/file', _FilesPath),
('file:/path/to/file', _LocalPath),
('file:/Volumes/path/to/file', _LocalPath), ])
('file:/Volumes/path/to/file', _LocalPath),
('file:/Models/path/to/file', _LocalPath), ])
def test_fs_path(config, path, expected_type):
dbfs_ext = DbfsExt(config)
assert isinstance(dbfs_ext._path(path), expected_type)
Expand Down

0 comments on commit 3162545

Please sign in to comment.