Skip to content

Commit

Permalink
Merge pull request #492 from mwakaba2/upgrade-anyio-3
Browse files Browse the repository at this point in the history
Upgrade anyio to v3
  • Loading branch information
kevin-bates authored May 3, 2021
2 parents 5e77b73 + eff9b0a commit 678878f
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 50 deletions.
13 changes: 9 additions & 4 deletions jupyter_server/services/contents/filecheckpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
)
from .fileio import AsyncFileManagerMixin, FileManagerMixin

from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from jupyter_core.utils import ensure_dir_exists
from traitlets import Unicode

Expand Down Expand Up @@ -156,7 +161,7 @@ async def restore_checkpoint(self, contents_mgr, checkpoint_id, path):

async def checkpoint_model(self, checkpoint_id, os_path):
"""construct the info dict for a given checkpoint"""
stats = await run_sync_in_worker_thread(os.stat, os_path)
stats = await run_sync(os.stat, os_path)
last_modified = tz.utcfromtimestamp(stats.st_mtime)
info = dict(
id=checkpoint_id,
Expand All @@ -176,7 +181,7 @@ async def rename_checkpoint(self, checkpoint_id, old_path, new_path):
new_cp_path,
)
with self.perm_to_403():
await run_sync_in_worker_thread(shutil.move, old_cp_path, new_cp_path)
await run_sync(shutil.move, old_cp_path, new_cp_path)

async def delete_checkpoint(self, checkpoint_id, path):
"""delete a file's checkpoint"""
Expand All @@ -187,7 +192,7 @@ async def delete_checkpoint(self, checkpoint_id, path):

self.log.debug("unlinking %s", cp_path)
with self.perm_to_403():
await run_sync_in_worker_thread(os.unlink, cp_path)
await run_sync(os.unlink, cp_path)

async def list_checkpoints(self, path):
"""list the checkpoints for a given file
Expand Down
21 changes: 13 additions & 8 deletions jupyter_server/services/contents/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import os
import shutil

from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from tornado.web import HTTPError

from jupyter_server.utils import (
Expand All @@ -36,7 +41,7 @@ def replace_file(src, dst):
async def async_replace_file(src, dst):
""" replace dst with src asynchronously
"""
await run_sync_in_worker_thread(os.replace, src, dst)
await run_sync(os.replace, src, dst)

def copy2_safe(src, dst, log=None):
"""copy src to dst
Expand All @@ -55,9 +60,9 @@ async def async_copy2_safe(src, dst, log=None):
like shutil.copy2, but log errors in copystat instead of raising
"""
await run_sync_in_worker_thread(shutil.copyfile, src, dst)
await run_sync(shutil.copyfile, src, dst)
try:
await run_sync_in_worker_thread(shutil.copystat, src, dst)
await run_sync(shutil.copystat, src, dst)
except OSError:
if log:
log.debug("copystat on %s failed", dst, exc_info=True)
Expand Down Expand Up @@ -355,7 +360,7 @@ async def _read_notebook(self, os_path, as_version=4):
"""Read a notebook from an os path."""
with self.open(os_path, 'r', encoding='utf-8') as f:
try:
return await run_sync_in_worker_thread(partial(nbformat.read, as_version=as_version), f)
return await run_sync(partial(nbformat.read, as_version=as_version), f)
except Exception as e:
e_orig = e

Expand All @@ -379,7 +384,7 @@ async def _read_notebook(self, os_path, as_version=4):
async def _save_notebook(self, os_path, nb):
"""Save a notebook to an os_path."""
with self.atomic_writing(os_path, encoding='utf-8') as f:
await run_sync_in_worker_thread(partial(nbformat.write, version=nbformat.NO_CONVERT), nb, f)
await run_sync(partial(nbformat.write, version=nbformat.NO_CONVERT), nb, f)

async def _read_file(self, os_path, format):
"""Read a non-notebook file.
Expand All @@ -394,7 +399,7 @@ async def _read_file(self, os_path, format):
raise HTTPError(400, "Cannot read non-file %s" % os_path)

with self.open(os_path, 'rb') as f:
bcontent = await run_sync_in_worker_thread(f.read)
bcontent = await run_sync(f.read)

if format is None or format == 'text':
# Try to interpret as unicode if format is unknown or if unicode
Expand Down Expand Up @@ -429,4 +434,4 @@ async def _save_file(self, os_path, content, format):
) from e

with self.atomic_writing(os_path, text=False) as f:
await run_sync_in_worker_thread(f.write, bcontent)
await run_sync(f.write, bcontent)
25 changes: 15 additions & 10 deletions jupyter_server/services/contents/filemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
import mimetypes
import nbformat

from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from send2trash import send2trash
from tornado import web

Expand Down Expand Up @@ -578,7 +583,7 @@ async def _dir_model(self, path, content=True):
if content:
model['content'] = contents = []
os_dir = self._get_os_path(path)
dir_contents = await run_sync_in_worker_thread(os.listdir, os_dir)
dir_contents = await run_sync(os.listdir, os_dir)
for name in dir_contents:
try:
os_path = os.path.join(os_dir, name)
Expand All @@ -588,7 +593,7 @@ async def _dir_model(self, path, content=True):
continue

try:
st = await run_sync_in_worker_thread(os.lstat, os_path)
st = await run_sync(os.lstat, os_path)
except OSError as e:
# skip over broken symlinks in listing
if e.errno == errno.ENOENT:
Expand Down Expand Up @@ -721,7 +726,7 @@ async def _save_directory(self, os_path, model, path=''):
raise web.HTTPError(400, u'Cannot create hidden directory %r' % os_path)
if not os.path.exists(os_path):
with self.perm_to_403():
await run_sync_in_worker_thread(os.mkdir, os_path)
await run_sync(os.mkdir, os_path)
elif not os.path.isdir(os_path):
raise web.HTTPError(400, u'Not a directory: %s' % (os_path))
else:
Expand Down Expand Up @@ -791,16 +796,16 @@ async def _check_trash(os_path):
# It's a bit more nuanced than this, but until we can better
# distinguish errors from send2trash, assume that we can only trash
# files on the same partition as the home directory.
file_dev = (await run_sync_in_worker_thread(os.stat, os_path)).st_dev
home_dev = (await run_sync_in_worker_thread(os.stat, os.path.expanduser('~'))).st_dev
file_dev = (await run_sync(os.stat, os_path)).st_dev
home_dev = (await run_sync(os.stat, os.path.expanduser('~'))).st_dev
return file_dev == home_dev

async def is_non_empty_dir(os_path):
if os.path.isdir(os_path):
# A directory containing only leftover checkpoints is
# considered empty.
cp_dir = getattr(self.checkpoints, 'checkpoint_dir', None)
dir_contents = set(await run_sync_in_worker_thread(os.listdir, os_path))
dir_contents = set(await run_sync(os.listdir, os_path))
if dir_contents - {cp_dir}:
return True

Expand Down Expand Up @@ -828,11 +833,11 @@ async def is_non_empty_dir(os_path):
raise web.HTTPError(400, u'Directory %s not empty' % os_path)
self.log.debug("Removing directory %s", os_path)
with self.perm_to_403():
await run_sync_in_worker_thread(shutil.rmtree, os_path)
await run_sync(shutil.rmtree, os_path)
else:
self.log.debug("Unlinking file %s", os_path)
with self.perm_to_403():
await run_sync_in_worker_thread(rm, os_path)
await run_sync(rm, os_path)

async def rename_file(self, old_path, new_path):
"""Rename a file."""
Expand All @@ -851,7 +856,7 @@ async def rename_file(self, old_path, new_path):
# Move the file
try:
with self.perm_to_403():
await run_sync_in_worker_thread(shutil.move, old_os_path, new_os_path)
await run_sync(shutil.move, old_os_path, new_os_path)
except web.HTTPError:
raise
except Exception as e:
Expand Down
9 changes: 7 additions & 2 deletions jupyter_server/services/contents/largefilemanager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from anyio import run_sync_in_worker_thread
try:
from anyio.to_thread import run_sync
except ImportError:
# fallback on anyio v2 for python version < 3.7
from anyio import run_sync_in_worker_thread as run_sync

from tornado import web
import base64
import os, io
Expand Down Expand Up @@ -135,6 +140,6 @@ async def _save_large_file(self, os_path, content, format):
if os.path.islink(os_path):
os_path = os.path.join(os.path.dirname(os_path), os.readlink(os_path))
with io.open(os_path, 'ab') as f:
await run_sync_in_worker_thread(f.write, bcontent)
await run_sync(f.write, bcontent)


48 changes: 23 additions & 25 deletions jupyter_server/tests/services/contents/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def jp_contents_manager(request, tmp_path):


@pytest.fixture(params=[FileContentsManager, AsyncFileContentsManager])
def file_contents_manager_class(request, tmp_path):
def jp_file_contents_manager_class(request, tmp_path):
return request.param

# -------------- Functions ----------------------------
Expand Down Expand Up @@ -100,46 +100,45 @@ async def check_populated_dir_files(jp_contents_manager, api_path):
# ----------------- Tests ----------------------------------


def test_root_dir(file_contents_manager_class, tmp_path):
fm = file_contents_manager_class(root_dir=str(tmp_path))
def test_root_dir(jp_file_contents_manager_class, tmp_path):
fm = jp_file_contents_manager_class(root_dir=str(tmp_path))
assert fm.root_dir == str(tmp_path)


def test_missing_root_dir(file_contents_manager_class, tmp_path):
def test_missing_root_dir(jp_file_contents_manager_class, tmp_path):
root = tmp_path / 'notebook' / 'dir' / 'is' / 'missing'
with pytest.raises(TraitError):
file_contents_manager_class(root_dir=str(root))
jp_file_contents_manager_class(root_dir=str(root))


def test_invalid_root_dir(file_contents_manager_class, tmp_path):
def test_invalid_root_dir(jp_file_contents_manager_class, tmp_path):
temp_file = tmp_path / 'file.txt'
temp_file.write_text('')
with pytest.raises(TraitError):
file_contents_manager_class(root_dir=str(temp_file))
jp_file_contents_manager_class(root_dir=str(temp_file))


def test_get_os_path(file_contents_manager_class, tmp_path):
fm = file_contents_manager_class(root_dir=str(tmp_path))
def test_get_os_path(jp_file_contents_manager_class, tmp_path):
fm = jp_file_contents_manager_class(root_dir=str(tmp_path))
path = fm._get_os_path('/path/to/notebook/test.ipynb')
rel_path_list = '/path/to/notebook/test.ipynb'.split('/')
fs_path = os.path.join(fm.root_dir, *rel_path_list)
assert path == fs_path

fm = file_contents_manager_class(root_dir=str(tmp_path))
fm = jp_file_contents_manager_class(root_dir=str(tmp_path))
path = fm._get_os_path('test.ipynb')
fs_path = os.path.join(fm.root_dir, 'test.ipynb')
assert path == fs_path

fm = file_contents_manager_class(root_dir=str(tmp_path))
fm = jp_file_contents_manager_class(root_dir=str(tmp_path))
path = fm._get_os_path('////test.ipynb')
fs_path = os.path.join(fm.root_dir, 'test.ipynb')
assert path == fs_path


def test_checkpoint_subdir(file_contents_manager_class, tmp_path):
def test_checkpoint_subdir(jp_file_contents_manager_class, tmp_path):
subd = 'sub ∂ir'
cp_name = 'test-cp.ipynb'
fm = file_contents_manager_class(root_dir=str(tmp_path))
fm = jp_file_contents_manager_class(root_dir=str(tmp_path))
tmp_path.joinpath(subd).mkdir()
cpm = fm.checkpoints
cp_dir = cpm.checkpoint_path('cp', 'test.ipynb')
Expand All @@ -148,10 +147,10 @@ def test_checkpoint_subdir(file_contents_manager_class, tmp_path):
assert cp_dir == os.path.join(str(tmp_path), cpm.checkpoint_dir, cp_name)


async def test_bad_symlink(file_contents_manager_class, tmp_path):
async def test_bad_symlink(jp_file_contents_manager_class, tmp_path):
td = str(tmp_path)

cm = file_contents_manager_class(root_dir=td)
cm = jp_file_contents_manager_class(root_dir=td)
path = 'test bad symlink'
_make_dir(cm, path)

Expand All @@ -173,10 +172,10 @@ async def test_bad_symlink(file_contents_manager_class, tmp_path):
sys.platform.startswith('win'),
reason="Windows doesn't detect symlink loops"
)
async def test_recursive_symlink(file_contents_manager_class, tmp_path):
async def test_recursive_symlink(jp_file_contents_manager_class, tmp_path):
td = str(tmp_path)

cm = file_contents_manager_class(root_dir=td)
cm = jp_file_contents_manager_class(root_dir=td)
path = 'test recursive symlink'
_make_dir(cm, path)

Expand All @@ -195,9 +194,9 @@ async def test_recursive_symlink(file_contents_manager_class, tmp_path):
assert 'recursive' not in contents


async def test_good_symlink(file_contents_manager_class, tmp_path):
async def test_good_symlink(jp_file_contents_manager_class, tmp_path):
td = str(tmp_path)
cm = file_contents_manager_class(root_dir=td)
cm = jp_file_contents_manager_class(root_dir=td)
parent = 'test good symlink'
name = 'good symlink'
path = '{0}/{1}'.format(parent, name)
Expand All @@ -216,13 +215,13 @@ async def test_good_symlink(file_contents_manager_class, tmp_path):
sys.platform.startswith('win'),
reason="Can't test permissions on Windows"
)
async def test_403(file_contents_manager_class, tmp_path):
async def test_403(jp_file_contents_manager_class, tmp_path):
if hasattr(os, 'getuid'):
if os.getuid() == 0:
raise pytest.skip("Can't test permissions as root")

td = str(tmp_path)
cm = file_contents_manager_class(root_dir=td)
cm = jp_file_contents_manager_class(root_dir=td)
model = await ensure_async(cm.new_untitled(type='file'))
os_path = cm._get_os_path(model['path'])

Expand All @@ -233,10 +232,9 @@ async def test_403(file_contents_manager_class, tmp_path):
except HTTPError as e:
assert e.status_code == 403


async def test_escape_root(file_contents_manager_class, tmp_path):
async def test_escape_root(jp_file_contents_manager_class, tmp_path):
td = str(tmp_path)
cm = file_contents_manager_class(root_dir=td)
cm = jp_file_contents_manager_class(root_dir=td)
# make foo, bar next to root
with open(os.path.join(cm.root_dir, '..', 'foo'), 'w') as f:
f.write('foo')
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ install_requires =
terminado>=0.8.3
prometheus_client
pywin32>=1.0 ; sys_platform == 'win32'
anyio>=2.0.2,<3
anyio>=2.0.2,<3 ; python_version < '3.7'
anyio>=3.0.1,<4 ; python_version >= '3.7'

[options.extras_require]
test = coverage; pytest; pytest-cov; pytest-mock; requests; pytest-tornasync; pytest-console-scripts; ipykernel
Expand Down

0 comments on commit 678878f

Please sign in to comment.