Skip to content

Commit

Permalink
Add non-streaming download option
Browse files Browse the repository at this point in the history
  • Loading branch information
drnextgis authored and gadomski committed Jul 29, 2024
1 parent 3981a58 commit 74298df
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased]

- Support for non-streaming downloads and the corresponding `--no-stream` flag to the CLI ([#208](https://github.com/stac-utils/stac-asset/pull/208))

## [0.4.2] - 2024-07-28

### Added
Expand Down
12 changes: 12 additions & 0 deletions src/stac_asset/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def cli() -> None:
help="The maximum number of downloads that can be active at one time",
default=_functions.DEFAULT_MAX_CONCURRENT_DOWNLOADS,
)
@click.option(
"--no-stream",
help="Disable chunked reading of assets",
default=False,
is_flag=True,
show_default=True,
)
# TODO add option to disable content type checking
def download(
href: Optional[str],
Expand All @@ -165,6 +172,7 @@ def download(
fail_fast: bool,
overwrite: bool,
max_concurrent_downloads: int,
no_stream: bool,
) -> None:
"""Download STAC assets from an item or item collection.
Expand Down Expand Up @@ -204,6 +212,7 @@ def download(
fail_fast=fail_fast,
overwrite=overwrite,
max_concurrent_downloads=max_concurrent_downloads,
no_stream=no_stream,
)
)

Expand All @@ -226,6 +235,7 @@ async def download_async(
fail_fast: bool,
overwrite: bool,
max_concurrent_downloads: int,
no_stream: bool,
) -> None:
config = Config(
alternate_assets=alternate_assets,
Expand Down Expand Up @@ -272,6 +282,7 @@ async def download() -> Union[Item, ItemCollection]:
config=config,
messages=messages,
max_concurrent_downloads=max_concurrent_downloads,
stream=not no_stream,
)

elif type_ == "FeatureCollection":
Expand All @@ -286,6 +297,7 @@ async def download() -> Union[Item, ItemCollection]:
config=config,
messages=messages,
max_concurrent_downloads=max_concurrent_downloads,
stream=not no_stream,
)

else:
Expand Down
35 changes: 26 additions & 9 deletions src/stac_asset/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class Download:
config: Config

async def download(
self,
messages: Optional[MessageQueue],
self, messages: Optional[MessageQueue], stream: bool = True
) -> Union[Download, WrappedError]:
if not os.path.exists(self.path) or self.config.overwrite:
try:
Expand All @@ -53,6 +52,7 @@ async def download(
config=self.config,
messages=messages,
clients=self.clients,
stream=stream,
)
except Exception as error:
if self.config.fail_fast:
Expand Down Expand Up @@ -134,10 +134,14 @@ async def add(
else:
stac_object.assets = assets

async def download(self, messages: Optional[MessageQueue]) -> None:
async def download(
self, messages: Optional[MessageQueue], stream: bool = True
) -> None:
tasks: Set[Task[Union[Download, WrappedError]]] = set()
for download in self.downloads:
task = asyncio.create_task(self.download_with_lock(download, messages))
task = asyncio.create_task(
self.download_with_lock(download, messages, stream)
)
tasks.add(task)
task.add_done_callback(tasks.discard)

Expand Down Expand Up @@ -169,11 +173,11 @@ async def download(self, messages: Optional[MessageQueue]) -> None:
raise DownloadError(exceptions)

async def download_with_lock(
self, download: Download, messages: Optional[MessageQueue]
self, download: Download, messages: Optional[MessageQueue], stream: bool = True
) -> Union[Download, WrappedError]:
await self.semaphore.acquire()
try:
return await download.download(messages=messages)
return await download.download(messages=messages, stream=stream)
finally:
self.semaphore.release()

Expand Down Expand Up @@ -208,6 +212,7 @@ async def download_item(
clients: Optional[List[Client]] = None,
keep_non_downloaded: bool = False,
max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS,
stream: bool = True,
) -> Item:
"""Downloads an item to the local filesystem.
Expand All @@ -225,6 +230,8 @@ async def download_item(
downloaded.
max_concurrent_downloads: The maximum number of downloads that can be
active at one time.
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
Returns:
Item: The `~pystac.Item`, with the updated asset hrefs and self href.
Expand All @@ -241,7 +248,7 @@ async def download_item(
max_concurrent_downloads=max_concurrent_downloads,
) as downloads:
await downloads.add(item, Path(directory), file_name, keep_non_downloaded)
await downloads.download(messages)
await downloads.download(messages, stream)

self_href = item.get_self_href()
if self_href:
Expand All @@ -263,6 +270,7 @@ async def download_collection(
clients: Optional[List[Client]] = None,
keep_non_downloaded: bool = False,
max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS,
stream: bool = True,
) -> Collection:
"""Downloads a collection to the local filesystem.
Expand All @@ -281,6 +289,8 @@ async def download_collection(
downloaded.
max_concurrent_downloads: The maximum number of downloads that can be
active at one time.
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
Returns:
Collection: The collection, with updated asset hrefs
Expand All @@ -294,7 +304,7 @@ async def download_collection(
max_concurrent_downloads=max_concurrent_downloads,
) as downloads:
await downloads.add(collection, Path(directory), file_name, keep_non_downloaded)
await downloads.download(messages)
await downloads.download(messages, stream)

self_href = collection.get_self_href()
if self_href:
Expand All @@ -316,6 +326,7 @@ async def download_item_collection(
clients: Optional[List[Client]] = None,
keep_non_downloaded: bool = False,
max_concurrent_downloads: int = DEFAULT_MAX_CONCURRENT_DOWNLOADS,
stream: bool = True,
) -> ItemCollection:
"""Downloads an item collection to the local filesystem.
Expand All @@ -333,6 +344,8 @@ async def download_item_collection(
downloaded.
max_concurrent_downloads: The maximum number of downloads that can be
active at one time.
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
Returns:
ItemCollection: The item collection, with updated asset hrefs
Expand All @@ -352,7 +365,7 @@ async def download_item_collection(
item.set_self_href(None)
root = Path(directory) / layout_template.substitute(item)
await downloads.add(item, root, None, keep_non_downloaded)
await downloads.download(messages)
await downloads.download(messages, stream)
if file_name:
dest_href = Path(directory) / file_name
for item in item_collection.items:
Expand All @@ -372,6 +385,7 @@ async def download_asset(
config: Config,
messages: Optional[MessageQueue] = None,
clients: Optional[Clients] = None,
stream: bool = True,
) -> Asset:
"""Downloads an asset.
Expand All @@ -383,6 +397,8 @@ async def download_asset(
messages: An optional queue to use for progress reporting
clients: A async-safe cache of clients. If not provided, a new one
will be created.
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
Returns:
Asset: The asset with an updated href
Expand Down Expand Up @@ -422,6 +438,7 @@ async def download_asset(
clean=config.clean,
content_type=asset.media_type,
messages=messages,
stream=stream,
)
except Exception as error:
if messages:
Expand Down
14 changes: 12 additions & 2 deletions src/stac_asset/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async def open_url(
url: URL,
content_type: Optional[str] = None,
messages: Optional[MessageQueue] = None,
stream: bool = True,
) -> AsyncIterator[bytes]:
"""Opens a url and yields an iterator over its bytes.
Expand All @@ -51,6 +52,8 @@ async def open_url(
content_type: The expected content type, to be checked by the client
implementations
messages: An optional queue to use for progress reporting
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
Yields:
AsyncIterator[bytes]: An iterator over chunks of the read file
Expand All @@ -64,19 +67,22 @@ async def open_href(
href: str,
content_type: Optional[str] = None,
messages: Optional[MessageQueue] = None,
stream: bool = True,
) -> AsyncIterator[bytes]:
"""Opens a href and yields an iterator over its bytes.
Args:
href: The input href
content_type: The expected content type
messages: An optional queue to use for progress reporting
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
Yields:
AsyncIterator[bytes]: An iterator over chunks of the read file
"""
async for chunk in self.open_url(
URL(href), content_type=content_type, messages=messages
URL(href), content_type=content_type, messages=messages, stream=stream
):
yield chunk

Expand All @@ -87,6 +93,7 @@ async def download_href(
clean: bool = True,
content_type: Optional[str] = None,
messages: Optional[MessageQueue] = None,
stream: bool = True,
) -> None:
"""Downloads a file to the local filesystem.
Expand All @@ -96,11 +103,13 @@ async def download_href(
clean: If an error occurs, delete the output file if it exists
content_type: The expected content type
messages: An optional queue to use for progress reporting
stream: If enabled, it iterates over the bytes of the response;
otherwise, it reads the entire file into memory
"""
try:
async with aiofiles.open(path, mode="wb") as f:
async for chunk in self.open_href(
href, content_type=content_type, messages=messages
href, content_type=content_type, messages=messages, stream=stream
):
await f.write(chunk)
if messages:
Expand All @@ -110,6 +119,7 @@ async def download_href(
)
except QueueFull:
pass

except Exception as err:
path_as_path = Path(path)
if clean and path_as_path.exists():
Expand Down
11 changes: 9 additions & 2 deletions src/stac_asset/filesystem_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ async def open_url(
url: URL,
content_type: Optional[str] = None,
messages: Optional[MessageQueue] = None,
stream: bool = True,
) -> AsyncIterator[bytes]:
"""Iterates over data from a local url.
Expand All @@ -31,6 +32,8 @@ async def open_url(
content_type: The expected content type. Ignored by this client,
because filesystems don't have content types.
messages: An optional queue to use for progress reporting
stream: If enabled, it iterates over the bytes of the file;
otherwise, it reads the entire file into memory
Yields:
AsyncIterator[bytes]: An iterator over the file's bytes.
Expand All @@ -47,8 +50,12 @@ async def open_url(
if messages:
await messages.put(OpenUrl(size=os.path.getsize(url.path), url=url))
async with aiofiles.open(url.path, "rb") as f:
async for chunk in f:
yield chunk
if stream:
async for chunk in f:
yield chunk
else:
content = await f.read()
yield content

async def assert_href_exists(self, href: str) -> None:
"""Asserts that an href exists."""
Expand Down
10 changes: 8 additions & 2 deletions src/stac_asset/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,15 @@ async def open_url(
url: URL,
content_type: Optional[str] = None,
messages: Optional[MessageQueue] = None,
stream: bool = True,
) -> AsyncIterator[bytes]:
"""Opens a url with this client's session and iterates over its bytes.
Args:
url: The url to open
content_type: The expected content type
messages: An optional queue to use for progress reporting
stream: If enabled, it uses the aiohttp streaming API
Yields:
AsyncIterator[bytes]: An iterator over the file's bytes
Expand All @@ -162,8 +164,12 @@ async def open_url(
)
if messages:
await messages.put(OpenUrl(url=url, size=response.content_length))
async for chunk, _ in response.content.iter_chunks():
yield chunk
if stream:
async for chunk, _ in response.content.iter_chunks():
yield chunk
else:
content = await response.read()
yield content

async def assert_href_exists(self, href: str) -> None:
"""Asserts that the href exists.
Expand Down
4 changes: 3 additions & 1 deletion src/stac_asset/planetary_computer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ async def open_url(
url: URL,
content_type: Optional[str] = None,
messages: Optional[Queue[Any]] = None,
stream: bool = True,
) -> AsyncIterator[bytes]:
"""Opens a url and iterates over its bytes.
Expand All @@ -88,13 +89,14 @@ async def open_url(
url: The url to open
content_type: The expected content type
messages: An optional queue to use for progress reporting
stream: If enabled, it uses the aiohttp streaming API
Yields:
AsyncIterator[bytes]: An iterator over the file's bytes
"""
url = await self._maybe_sign_url(url)
async for chunk in super().open_url(
url, content_type=content_type, messages=messages
url, content_type=content_type, messages=messages, stream=stream
):
yield chunk

Expand Down
Loading

0 comments on commit 74298df

Please sign in to comment.