diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a3d341a3..0cc08f54 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,6 +31,7 @@ repos:
           - types-six
           - types-toml
           - types-ujson
+          - types-aiofiles
         args: [--show-error-codes]
   - repo: https://github.com/Quantco/pre-commit-mirrors-prettier
     rev: 2.7.1
diff --git a/environment.yml b/environment.yml
index 4eb1ca7d..42ff7132 100644
--- a/environment.yml
+++ b/environment.yml
@@ -33,6 +33,7 @@ dependencies:
   - tenacity
   - xattr
   - aiofiles
+  - aioshutil
   - pyyaml
   - ujson
   - prometheus_client
diff --git a/quetz/main.py b/quetz/main.py
index 30ea4749..f187ee7b 100644
--- a/quetz/main.py
+++ b/quetz/main.py
@@ -1388,7 +1388,7 @@ async def post_upload(
         dest = os.path.join(condainfo.info["subdir"], filename)
 
         body.seek(0)
-        pkgstore.add_package(body, channel_name, dest)
+        await pkgstore.add_package_async(body, channel_name, dest)
 
         package_name = str(condainfo.info.get("name"))
         package_data = rest_models.Package(
diff --git a/quetz/pkgstores.py b/quetz/pkgstores.py
index 1d8de253..4a921479 100644
--- a/quetz/pkgstores.py
+++ b/quetz/pkgstores.py
@@ -18,6 +18,8 @@
 from threading import Lock
 from typing import IO, List, Tuple, Union
 
+import aiofiles
+import aioshutil
 import fsspec
 from tenacity import retry
 from tenacity.retry import retry_if_exception_type
@@ -80,6 +82,10 @@ def list_files(self, channel: str) -> List[str]:
     def url(self, channel: str, src: str, expires: int = 0) -> str:
         pass
 
+    @abc.abstractmethod
+    async def add_package_async(self, package: File, channel: str, destination: str):
+        """Asynchronously store `package` at `destination` within `channel`."""
+
     @abc.abstractmethod
     def add_package(self, package: File, channel: str, destination: str):
         pass
@@ -175,6 +181,20 @@ def add_package(self, package: File, channel: str, destination: str) -> None:
         with self._atomic_open(channel, destination) as f:
             shutil.copyfileobj(package, f)
 
+    async def add_package_async(
+        self, package: File, channel: str, destination: str
+    ) -> None:
+        full_path = path.join(self.channels_dir, channel, destination)
+        self.fs.makedirs(path.dirname(full_path), exist_ok=True)
+
+        async with aiofiles.open(full_path, 'wb') as f:
+            # stream in 10 Megabyte chunks rather than buffering the whole package
+            while True:
+                chunk = package.read(10 * 1024 * 1024)
+                if not chunk:
+                    break
+                await f.write(chunk)
+
     def add_file(
         self, data: Union[str, bytes], channel: str, destination: StrPath
     ) -> None:
@@ -286,7 +306,9 @@ def __init__(self, config):
         # to the s3fs constructor
         key = config['key'] if config['key'] != '' else None
         secret = config['secret'] if config['secret'] != '' else None
-        self.fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs)
+        self.fs = s3fs.S3FileSystem(
+            key=key, secret=secret, asynchronous=True, client_kwargs=client_kwargs
+        )
 
         self.bucket_prefix = config['bucket_prefix']
         self.bucket_suffix = config['bucket_suffix']
@@ -333,6 +355,15 @@ def add_package(self, package: File, channel: str, destination: str) -> None:
                 # use a chunk size of 10 Megabytes
                 shutil.copyfileobj(package, pkg, 10 * 1024 * 1024)
 
+    async def add_package_async(
+        self, package: File, channel: str, destination: str
+    ) -> None:
+        with self._get_fs() as fs:
+            bucket = self._bucket_map(channel)
+            with fs.open(path.join(bucket, destination), "wb", acl="private") as pkg:
+                # use a chunk size of 10 Megabytes
+                await aioshutil.copyfileobj(package, pkg, 10 * 1024 * 1024)
+
     def add_file(
         self, data: Union[str, bytes], channel: str, destination: StrPath
     ) -> None:
@@ -426,6 +457,7 @@ def __init__(self, config):
             account_name=self.storage_account_name,
             connection_string=self.conn_string,
             account_key=self.access_key,
+            asynchronous=True,
         )
 
         self.container_prefix = config['container_prefix']
@@ -472,6 +504,15 @@ def add_package(self, package: File, channel: str, destination: str) -> None:
                 # use a chunk size of 10 Megabytes
                 shutil.copyfileobj(package, pkg, 10 * 1024 * 1024)
 
+    async def add_package_async(
+        self, package: File, channel: str, destination: str
+    ) -> None:
+        with self._get_fs() as fs:
+            container = self._container_map(channel)
+            with fs.open(path.join(container, destination), "wb") as pkg:
+                # use a chunk size of 10 Megabytes
+                await aioshutil.copyfileobj(package, pkg, 10 * 1024 * 1024)
+
     def add_file(
         self, data: Union[str, bytes], channel: str, destination: StrPath
     ) -> None:
@@ -571,6 +612,7 @@ def __init__(self, config):
             token=self.token if self.token else None,
             cache_timeout=self.cache_timeout,
             default_location=self.region,
+            asynchronous=True,
         )
 
         self.bucket_prefix = config['bucket_prefix']
@@ -621,6 +663,15 @@ def add_package(self, package: File, channel: str, destination: str) -> None:
                 # use a chunk size of 10 Megabytes
                 shutil.copyfileobj(package, pkg, 10 * 1024 * 1024)
 
+    async def add_package_async(
+        self, package: File, channel: str, destination: str
+    ) -> None:
+        with self._get_fs() as fs:
+            bucket = self._bucket_map(channel)
+            with fs.open(path.join(bucket, destination), "wb") as pkg:
+                # use a chunk size of 10 Megabytes
+                await aioshutil.copyfileobj(package, pkg, 10 * 1024 * 1024)
+
     def add_file(
         self, data: Union[str, bytes], channel: str, destination: StrPath
     ) -> None:
diff --git a/quetz/tests/test_pkg_stores.py b/quetz/tests/test_pkg_stores.py
index 6c5499b1..52f09784 100644
--- a/quetz/tests/test_pkg_stores.py
+++ b/quetz/tests/test_pkg_stores.py
@@ -3,6 +3,9 @@
 import shutil
 import time
 import uuid
+from collections.abc import Collection
+from io import BytesIO
+from pathlib import Path
 
 import pytest
 
@@ -10,6 +13,7 @@
     AzureBlobStore,
     GoogleCloudStorageStore,
     LocalStore,
+    PackageStore,
     S3Store,
     has_xattr,
 )
@@ -180,30 +184,34 @@ def channel(any_store, channel_name):
     any_store.remove_channel(channel_name)
 
 
-def test_store_add_list_files(any_store, channel, channel_name):
-    def assert_files(expected_files, n_retries=3):
-        n_retries = 3
-
-        files = []
-        for i in range(n_retries):
-            files = pkg_store.list_files(channel_name)
-            try:
-                assert files == expected_files
-            except AssertionError:
-                continue
-            break
-        assert files == expected_files
+def assert_files(
+    pkg_store: PackageStore,
+    channel_name: str,
+    expected_files: Collection[str],
+    n_retries: int = 3,
+):
+    """
+    Asserts that the files in the package store match the expected files with retry.
+    """
+    files: Collection[str] = []
+    for _ in range(n_retries):
+        files = pkg_store.list_files(channel_name)
+        if set(files) == set(expected_files):
+            break
+    assert set(files) == set(expected_files)
+
 
+def test_store_add_list_files(any_store, channel, channel_name):
     pkg_store = any_store
 
     pkg_store.add_file("content", channel_name, "test.txt")
     pkg_store.add_file("content", channel_name, "test_2.txt")
 
-    assert_files(["test.txt", "test_2.txt"])
+    assert_files(pkg_store, channel_name, ["test.txt", "test_2.txt"])
 
     pkg_store.delete_file(channel_name, "test.txt")
 
-    assert_files(["test_2.txt"])
+    assert_files(pkg_store, channel_name, ["test_2.txt"])
 
     metadata = pkg_store.get_filemetadata(channel_name, "test_2.txt")
     assert metadata[0] > 0
@@ -211,26 +219,36 @@ def assert_files(expected_files, n_retries=3):
     assert type(metadata[2]) is str
 
 
-def test_move_file(any_store, channel, channel_name):
-    def assert_files(expected_files, n_retries=3):
-        n_retries = 3
-
-        files = []
-        for i in range(n_retries):
-            files = pkg_store.list_files(channel_name)
-            try:
-                assert files == expected_files
-            except AssertionError:
-                continue
-            break
-        assert files == expected_files
+@pytest.mark.asyncio
+async def test_add_package_async(any_store, channel, channel_name):
+    pkg_store = any_store
+
+    data = (Path(__file__).parent / "data" / "test-package-0.1-0.tar.bz2").read_bytes()
+
+    await pkg_store.add_package_async(
+        BytesIO(data), channel_name, "test-package-0.1-0.tar.gz"
+    )
+
+    assert_files(pkg_store, channel_name, ["test-package-0.1-0.tar.gz"])
+
 
+def test_add_package(any_store, channel, channel_name):
+    pkg_store = any_store
+
+    data = (Path(__file__).parent / "data" / "test-package-0.1-0.tar.bz2").read_bytes()
+
+    pkg_store.add_package(BytesIO(data), channel_name, "test-package-0.1-0.tar.gz")
+
+    assert_files(pkg_store, channel_name, ["test-package-0.1-0.tar.gz"])
+
+
+def test_move_file(any_store, channel, channel_name):
     pkg_store = any_store
 
     pkg_store.add_file("content", channel_name, "test.txt")
     pkg_store.move_file(channel_name, "test.txt", "test_2.txt")
 
-    assert_files(['test_2.txt'])
+    assert_files(pkg_store, channel_name, ['test_2.txt'])
 
 
 @pytest.mark.parametrize("redirect_enabled", [False, True])
diff --git a/setup.cfg b/setup.cfg
index e1a19e5c..91215c78 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,6 +45,7 @@ install_requires =
     ujson
     uvicorn
     zstandard
+    aioshutil
 
 [options.entry_points]
 console_scripts =