From 74298df2ca8aae126f1d0fb399e513f7085535a8 Mon Sep 17 00:00:00 2001 From: Denis Rykov Date: Mon, 29 Jul 2024 01:25:47 +0200 Subject: [PATCH] Add non-streaming download option --- CHANGELOG.md | 2 ++ src/stac_asset/_cli.py | 12 +++++++ src/stac_asset/_functions.py | 35 +++++++++++++++------ src/stac_asset/client.py | 14 +++++++-- src/stac_asset/filesystem_client.py | 11 +++++-- src/stac_asset/http_client.py | 10 ++++-- src/stac_asset/planetary_computer_client.py | 4 ++- src/stac_asset/s3_client.py | 11 +++++-- 8 files changed, 81 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5be6e8..cfed93b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased] +- Support for non-streaming downloads and the corresponding `--no-stream` flag to the CLI ([#208](https://github.com/stac-utils/stac-asset/pull/208)) + ## [0.4.2] - 2024-07-28 ### Added diff --git a/src/stac_asset/_cli.py b/src/stac_asset/_cli.py index 82de402..e82ea18 100644 --- a/src/stac_asset/_cli.py +++ b/src/stac_asset/_cli.py @@ -146,6 +146,13 @@ def cli() -> None: help="The maximum number of downloads that can be active at one time", default=_functions.DEFAULT_MAX_CONCURRENT_DOWNLOADS, ) +@click.option( + "--no-stream", + help="Disable chunked reading of assets", + default=False, + is_flag=True, + show_default=True, +) # TODO add option to disable content type checking def download( href: Optional[str], @@ -165,6 +172,7 @@ def download( fail_fast: bool, overwrite: bool, max_concurrent_downloads: int, + no_stream: bool, ) -> None: """Download STAC assets from an item or item collection. @@ -204,6 +212,7 @@ def download( fail_fast=fail_fast, overwrite=overwrite, max_concurrent_downloads=max_concurrent_downloads, + no_stream=no_stream, ) ) @@ -226,6 +235,7 @@ async def download_async( fail_fast: bool, overwrite: bool, max_concurrent_downloads: int, + no_stream: bool, ) -> None: config = Config( alternate_assets=alternate_assets, @@ -272,6 +282,7 @@ async def download() -> Union[Item, ItemCollection]: config=config, messages=messages, max_concurrent_downloads=max_concurrent_downloads, + stream=not no_stream, ) elif type_ == "FeatureCollection": @@ -286,6 +297,7 @@ async def download() -> Union[Item, ItemCollection]: config=config, messages=messages, max_concurrent_downloads=max_concurrent_downloads, + stream=not no_stream, ) else: diff --git a/src/stac_asset/_functions.py b/src/stac_asset/_functions.py index 5559d06..e42158f 100644 --- a/src/stac_asset/_functions.py +++ b/src/stac_asset/_functions.py @@ -41,8 +41,7 @@ class Download: config: Config async def download( - self, - messages: Optional[MessageQueue], + self, messages: Optional[MessageQueue], stream: bool = True ) -> Union[Download, WrappedError]: if not os.path.exists(self.path) or self.config.overwrite: try: @@ -53,6 +52,7 @@ async def download( config=self.config, messages=messages, clients=self.clients, + stream=stream, ) except Exception as error: if self.config.fail_fast: @@ -134,10 +134,14 @@ async def add( else: stac_object.assets = assets - async def download(self, messages: Optional[MessageQueue]) -> None: + async def download( + self, messages: Optional[MessageQueue], stream: bool = True + ) -> None: tasks: Set[Task[Union[Download, WrappedError]]] = set() for download in self.downloads: - task = asyncio.create_task(self.download_with_lock(download, messages)) + task = asyncio.create_task( + self.download_with_lock(download, messages, stream) + ) tasks.add(task) task.add_done_callback(tasks.discard) @@ -169,11 +173,11 @@ async def download(self, messages: Optional[MessageQueue]) -> None: raise DownloadError(exceptions) async def download_with_lock( - self, download: Download, messages: Optional[MessageQueue] + self, download: Download, messages: Optional[MessageQueue], stream: bool = True ) -> Union[Download, WrappedError]: await self.semaphore.acquire() try: - return await download.download(messages=messages) + return await download.download(messages=messages, stream=stream) finally: self.semaphore.release() @@ -208,6 +212,7 @@ async def download_item( clients: Optional[List[Client]] = None, keep_non_downloaded: bool = False, max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS, + stream: bool = True, ) -> Item: """Downloads an item to the local filesystem. @@ -225,6 +230,8 @@ async def download_item( downloaded. max_concurrent_downloads: The maximum number of downloads that can be active at one time. + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Returns: Item: The `~pystac.Item`, with the updated asset hrefs and self href. @@ -241,7 +248,7 @@ async def download_item( max_concurrent_downloads=max_concurrent_downloads, ) as downloads: await downloads.add(item, Path(directory), file_name, keep_non_downloaded) - await downloads.download(messages) + await downloads.download(messages, stream) self_href = item.get_self_href() if self_href: @@ -263,6 +270,7 @@ async def download_collection( clients: Optional[List[Client]] = None, keep_non_downloaded: bool = False, max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS, + stream: bool = True, ) -> Collection: """Downloads a collection to the local filesystem. @@ -281,6 +289,8 @@ async def download_collection( downloaded. max_concurrent_downloads: The maximum number of downloads that can be active at one time. + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Returns: Collection: The collection, with updated asset hrefs @@ -294,7 +304,7 @@ async def download_collection( max_concurrent_downloads=max_concurrent_downloads, ) as downloads: await downloads.add(collection, Path(directory), file_name, keep_non_downloaded) - await downloads.download(messages) + await downloads.download(messages, stream) self_href = collection.get_self_href() if self_href: @@ -316,6 +326,7 @@ async def download_item_collection( clients: Optional[List[Client]] = None, keep_non_downloaded: bool = False, max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS, + stream: bool = True, ) -> ItemCollection: """Downloads an item collection to the local filesystem. @@ -333,6 +344,8 @@ async def download_item_collection( downloaded. max_concurrent_downloads: The maximum number of downloads that can be active at one time. + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Returns: ItemCollection: The item collection, with updated asset hrefs @@ -352,7 +365,7 @@ async def download_item_collection( item.set_self_href(None) root = Path(directory) / layout_template.substitute(item) await downloads.add(item, root, None, keep_non_downloaded) - await downloads.download(messages) + await downloads.download(messages, stream) if file_name: dest_href = Path(directory) / file_name for item in item_collection.items: @@ -372,6 +385,7 @@ async def download_asset( config: Config, messages: Optional[MessageQueue] = None, clients: Optional[Clients] = None, + stream: bool = True, ) -> Asset: """Downloads an asset. @@ -383,6 +397,8 @@ async def download_asset( messages: An optional queue to use for progress reporting clients: A async-safe cache of clients. If not provided, a new one will be created. + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Returns: Asset: The asset with an updated href @@ -422,6 +438,7 @@ async def download_asset( clean=config.clean, content_type=asset.media_type, messages=messages, + stream=stream, ) except Exception as error: if messages: diff --git a/src/stac_asset/client.py b/src/stac_asset/client.py index bef2f7f..3a7c275 100644 --- a/src/stac_asset/client.py +++ b/src/stac_asset/client.py @@ -41,6 +41,7 @@ async def open_url( url: URL, content_type: Optional[str] = None, messages: Optional[MessageQueue] = None, + stream: bool = True, ) -> AsyncIterator[bytes]: """Opens a url and yields an iterator over its bytes. @@ -51,6 +52,8 @@ async def open_url( content_type: The expected content type, to be checked by the client implementations messages: An optional queue to use for progress reporting + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Yields: AsyncIterator[bytes]: An iterator over chunks of the read file @@ -64,6 +67,7 @@ async def open_href( href: str, content_type: Optional[str] = None, messages: Optional[MessageQueue] = None, + stream: bool = True, ) -> AsyncIterator[bytes]: """Opens a href and yields an iterator over its bytes. @@ -71,12 +75,14 @@ async def open_href( href: The input href content_type: The expected content type messages: An optional queue to use for progress reporting + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Yields: AsyncIterator[bytes]: An iterator over chunks of the read file """ async for chunk in self.open_url( - URL(href), content_type=content_type, messages=messages + URL(href), content_type=content_type, messages=messages, stream=stream ): yield chunk @@ -87,6 +93,7 @@ async def download_href( clean: bool = True, content_type: Optional[str] = None, messages: Optional[MessageQueue] = None, + stream: bool = True, ) -> None: """Downloads a file to the local filesystem. @@ -96,11 +103,13 @@ async def download_href( clean: If an error occurs, delete the output file if it exists content_type: The expected content type messages: An optional queue to use for progress reporting + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory """ try: async with aiofiles.open(path, mode="wb") as f: async for chunk in self.open_href( - href, content_type=content_type, messages=messages + href, content_type=content_type, messages=messages, stream=stream ): await f.write(chunk) if messages: @@ -110,6 +119,7 @@ async def download_href( ) except QueueFull: pass + except Exception as err: path_as_path = Path(path) if clean and path_as_path.exists(): diff --git a/src/stac_asset/filesystem_client.py b/src/stac_asset/filesystem_client.py index d0dcd31..46eccc0 100644 --- a/src/stac_asset/filesystem_client.py +++ b/src/stac_asset/filesystem_client.py @@ -23,6 +23,7 @@ async def open_url( url: URL, content_type: Optional[str] = None, messages: Optional[MessageQueue] = None, + stream: bool = True, ) -> AsyncIterator[bytes]: """Iterates over data from a local url. @@ -31,6 +32,8 @@ async def open_url( content_type: The expected content type. Ignored by this client, because filesystems don't have content types. messages: An optional queue to use for progress reporting + stream: If enabled, it iterates over the bytes of the file; + otherwise, it reads the entire file into memory Yields: AsyncIterator[bytes]: An iterator over the file's bytes. @@ -47,8 +50,12 @@ async def open_url( if messages: await messages.put(OpenUrl(size=os.path.getsize(url.path), url=url)) async with aiofiles.open(url.path, "rb") as f: - async for chunk in f: - yield chunk + if stream: + async for chunk in f: + yield chunk + else: + content = await f.read() + yield content async def assert_href_exists(self, href: str) -> None: """Asserts that an href exists.""" diff --git a/src/stac_asset/http_client.py b/src/stac_asset/http_client.py index 9abe2b3..52abf13 100644 --- a/src/stac_asset/http_client.py +++ b/src/stac_asset/http_client.py @@ -140,6 +140,7 @@ async def open_url( url: URL, content_type: Optional[str] = None, messages: Optional[MessageQueue] = None, + stream: bool = True, ) -> AsyncIterator[bytes]: """Opens a url with this client's session and iterates over its bytes. @@ -147,6 +148,7 @@ async def open_url( url: The url to open content_type: The expected content type messages: An optional queue to use for progress reporting + stream: If enabled, it uses the aiohttp streaming API Yields: AsyncIterator[bytes]: An iterator over the file's bytes @@ -162,8 +164,12 @@ async def open_url( ) if messages: await messages.put(OpenUrl(url=url, size=response.content_length)) - async for chunk, _ in response.content.iter_chunks(): - yield chunk + if stream: + async for chunk, _ in response.content.iter_chunks(): + yield chunk + else: + content = await response.read() + yield content async def assert_href_exists(self, href: str) -> None: """Asserts that the href exists. diff --git a/src/stac_asset/planetary_computer_client.py b/src/stac_asset/planetary_computer_client.py index 8b32de1..42664ef 100644 --- a/src/stac_asset/planetary_computer_client.py +++ b/src/stac_asset/planetary_computer_client.py @@ -70,6 +70,7 @@ async def open_url( url: URL, content_type: Optional[str] = None, messages: Optional[Queue[Any]] = None, + stream: bool = True, ) -> AsyncIterator[bytes]: """Opens a url and iterates over its bytes. @@ -88,13 +89,14 @@ async def open_url( url: The url to open content_type: The expected content type messages: An optional queue to use for progress reporting + stream: If enabled, it uses the aiohttp streaming API Yields: AsyncIterator[bytes]: An iterator over the file's bytes """ url = await self._maybe_sign_url(url) async for chunk in super().open_url( - url, content_type=content_type, messages=messages + url, content_type=content_type, messages=messages, stream=stream ): yield chunk diff --git a/src/stac_asset/s3_client.py b/src/stac_asset/s3_client.py index 1488186..6d9ce68 100644 --- a/src/stac_asset/s3_client.py +++ b/src/stac_asset/s3_client.py @@ -87,6 +87,7 @@ async def open_url( url: URL, content_type: Optional[str] = None, messages: Optional[MessageQueue] = None, + stream: bool = True, ) -> AsyncIterator[bytes]: """Opens an s3 url and iterates over its bytes. @@ -94,6 +95,8 @@ async def open_url( url: The url to open content_type: The expected content type messages: An optional queue to use for progress reporting + stream: If enabled, it iterates over the bytes of the response; + otherwise, it reads the entire file into memory Yields: AsyncIterator[bytes]: An iterator over the file's bytes @@ -107,8 +110,12 @@ async def open_url( validate.content_type(response["ContentType"], content_type) if messages: await messages.put(OpenUrl(url=url, size=response["ContentLength"])) - async for chunk in response["Body"]: - yield chunk + if stream: + async for chunk in response["Body"]: + yield chunk + else: + content = await response["Body"].read() + yield content async def has_credentials(self) -> bool: """Returns true if the sessions has credentials."""