Skip to content

Commit

Permalink
feat: progress reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Aug 8, 2023
1 parent 74ba33a commit f3a7494
Show file tree
Hide file tree
Showing 15 changed files with 371 additions and 39 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
- name: Update pip
run: pip install -U pip
- name: Install with development dependencies
run: pip install .[cli,dev]
- name: Check with pre-commit
Expand All @@ -48,6 +50,8 @@ jobs:
with:
python-version: "3.11"
cache: "pip"
- name: Update pip
run: pip install -U pip
- name: Install with development dependencies
run: pip install .[cli,dev]
- name: Install minimum versions of dependencies
Expand All @@ -66,6 +70,8 @@ jobs:
with:
python-version: "3.11"
cache: "pip"
- name: Update pip
run: pip install -U pip
- name: Install with development dependencies
run: pip install .[cli,dev]
- name: Test with coverage
Expand Down
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ repos:
- pystac
- pytest
- types-aiofiles
- types-python-dateutil
- types-tqdm
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.278"
hooks:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- `Client.from_config` and `Client.close` ([#46](https://github.com/stac-utils/stac-asset/pull/46))
- Retry configuration for S3 ([#47](https://github.com/stac-utils/stac-asset/pull/47))
- `Collection` download ([#50](https://github.com/stac-utils/stac-asset/pull/50))
- Progress reporting ([#55](https://github.com/stac-utils/stac-asset/pull/55))

### Changed

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ dependencies = [
"aiobotocore>=2.5.0",
"aiohttp>=3.8.4",
"pystac>=1.7.3",
"python-dateutil>=2.7.0",
"yarl>=1.9.2",
]

[project.optional-dependencies]
cli = ["click~=8.1.5", "click-logging~=1.0.1"]
cli = ["click~=8.1.5", "click-logging~=1.0.1", "tqdm~=4.65.1"]
dev = [
"black~=23.3",
"mypy~=1.3",
Expand All @@ -35,6 +36,8 @@ dev = [
"pytest-cov~=4.1",
"ruff==0.0.282",
"types-aiofiles~=23.1",
"types-python-dateutil~=2.8.19",
"types-tqdm~=4.65.0",
]
docs = ["pydata-sphinx-theme~=0.13", "sphinx~=7.0"]

Expand Down
151 changes: 138 additions & 13 deletions src/stac_asset/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,36 @@
import logging
import os
import sys
from typing import List, Optional, Union
from asyncio import Queue
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import click
import click_logging
import tqdm
from pystac import Item, ItemCollection

from . import Config, functions
from .config import DEFAULT_S3_MAX_ATTEMPTS, DEFAULT_S3_RETRY_MODE
from .messages import (
ErrorAssetDownload,
FinishAssetDownload,
OpenUrl,
StartAssetDownload,
WriteChunk,
)

logger = logging.getLogger(__name__)
click_logging.basic_config(logger)

# Needed until we drop Python 3.8
#
# Type checkers (TYPE_CHECKING is true) see the parametrized generic forms,
# while at runtime the plain, unsubscripted classes are bound to the same
# names, so annotations using ``AnyQueue`` / ``Tqdm`` never subscript the
# classes when the module is actually executed.
# NOTE(review): presumably ``asyncio.Queue[...]`` (and ``tqdm.tqdm[...]``,
# which looks generic only in the type stubs) raise at runtime on 3.8 --
# confirm before removing the ``else`` branch.
if TYPE_CHECKING:
    AnyQueue = Queue[Any]
    Tqdm = tqdm.tqdm[Any]
else:
    AnyQueue = Queue
    Tqdm = tqdm.tqdm


@click.group()
def cli() -> None:
Expand Down Expand Up @@ -111,6 +129,36 @@ def download(
$ stac-asset download -i asset-key-to-include item.json
"""
asyncio.run(
download_async(
href,
directory,
alternate_assets,
include,
exclude,
file_name,
quiet,
s3_requester_pays,
s3_retry_mode,
s3_max_attempts,
warn,
)
)


async def download_async(
href: Optional[str],
directory: Optional[str],
alternate_assets: List[str],
include: List[str],
exclude: List[str],
file_name: Optional[str],
quiet: bool,
s3_requester_pays: bool,
s3_retry_mode: str,
s3_max_attempts: int,
warn: bool,
) -> None:
config = Config(
alternate_assets=alternate_assets,
include=include,
Expand All @@ -125,39 +173,58 @@ def download(
if href is None or href == "-":
input_dict = json.load(sys.stdin)
else:
input_dict = json.loads(asyncio.run(read_file(href, config)))
input_dict = json.loads(await read_file(href, config))
if directory is None:
directory = os.getcwd()
directory_str = os.getcwd()
else:
directory_str = str(directory)

if quiet:
queue = None
else:
queue = Queue()

type_ = input_dict.get("type")
if type_ is None:
print("ERROR: missing 'type' field on input dictionary", file=sys.stderr)
if not quiet:
print("ERROR: missing 'type' field on input dictionary", file=sys.stderr)
sys.exit(1)
elif type_ == "Feature":
item = Item.from_dict(input_dict)
if href:
item.set_self_href(href)
item.make_asset_hrefs_absolute()
output: Union[Item, ItemCollection] = asyncio.run(
functions.download_item(

async def download() -> Union[Item, ItemCollection]:
return await functions.download_item(
item,
directory,
directory_str,
config=config,
queue=queue,
)
)

elif type_ == "FeatureCollection":
item_collection = ItemCollection.from_dict(input_dict)
output = asyncio.run(
functions.download_item_collection(

async def download() -> Union[Item, ItemCollection]:
return await functions.download_item_collection(
item_collection,
directory,
directory_str,
config=config,
queue=queue,
)
)

else:
print(f"ERROR: unsupported 'type' field: {type_}", file=sys.stderr)
if not quiet:
print(f"ERROR: unsupported 'type' field: {type_}", file=sys.stderr)
sys.exit(2)

task = asyncio.create_task(report_progress(queue))
output = await download()
if queue:
await queue.put(None)
await task

if not quiet:
json.dump(output.to_dict(transform_hrefs=False), sys.stdout)

Expand All @@ -170,3 +237,61 @@ async def read_file(href: str, config: Config) -> bytes:
async for chunk in client.open_href(href):
data += chunk
return data


async def report_progress(queue: Optional[AnyQueue]) -> None:
    """Drive terminal progress bars from the download messages on ``queue``.

    One tqdm bar is shown per in-flight asset download, keyed by the asset
    href. Messages are consumed until a ``None`` sentinel is received:

    - ``StartAssetDownload``: create a bar described as ``"<item_id> [<key>]"``
      (or just the key when the message has no item id).
    - ``OpenUrl``: if the message carries a size, reset the matching bar with
      that size as its total.
    - ``WriteChunk``: advance the matching bar by the chunk size.
    - ``FinishAssetDownload`` / ``ErrorAssetDownload``: close the matching bar
      and free its terminal line for the next download.
    - ``None``: stop reporting.

    Any other message, and any message whose href is not being tracked, is
    ignored.

    Args:
        queue: The queue to read messages from. If ``None``, the coroutine
            returns immediately without doing anything.
    """
    if queue is None:
        return
    downloads: Dict[str, Download] = dict()
    # Terminal line offset occupied by each active download's bar.
    positions: Dict[str, int] = dict()

    def close(href: str) -> None:
        """Close and forget the bar for ``href``, releasing its position."""
        positions.pop(href, None)
        download = downloads.pop(href, None)
        if download:
            download.progress_bar.close()

    try:
        while True:
            message = await queue.get()
            if message is None:
                return
            if isinstance(message, StartAssetDownload):
                # A repeated start for the same href would otherwise orphan
                # the earlier bar without ever closing it.
                close(message.href)
                # Reuse the lowest line offset not held by an active bar, so
                # bars do not drift further down the terminal as downloads
                # complete. ``len(taken) + 1`` candidates always include a
                # free one.
                taken = set(positions.values())
                position = next(i for i in range(len(taken) + 1) if i not in taken)
                progress_bar = tqdm.tqdm(
                    position=position,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    leave=False,
                )
                if message.item_id:
                    description = f"{message.item_id} [{message.key}]"
                else:
                    description = message.key
                progress_bar.set_description_str(description)
                positions[message.href] = position
                downloads[message.href] = Download(
                    key=message.key,
                    item_id=message.item_id,
                    href=message.href,
                    path=str(message.path),
                    progress_bar=progress_bar,
                )
            elif isinstance(message, OpenUrl):
                # NOTE(review): the lookup only succeeds when the opened url,
                # as a string, equals the href the download was started with.
                download = downloads.get(str(message.url))
                if download and message.size:
                    download.progress_bar.reset(total=message.size)
            elif isinstance(message, (FinishAssetDownload, ErrorAssetDownload)):
                close(message.href)
            elif isinstance(message, WriteChunk):
                download = downloads.get(message.href)
                if download:
                    download.progress_bar.update(message.size)
    finally:
        # Runs on the sentinel, on cancellation, and on unexpected errors, so
        # no bar is left open on the terminal.
        for download in downloads.values():
            download.progress_bar.close()


@dataclass
class Download:
    """Record kept by the progress reporter for a single asset download."""

    # Asset key, copied from the start-of-download message.
    key: str
    # Id of the item owning the asset, if any; when set it prefixes the
    # progress bar description.
    item_id: Optional[str]
    # Asset href; also the dictionary key under which this record is stored.
    href: str
    # Path from the start-of-download message, converted to a string.
    # NOTE(review): presumably the local destination path -- confirm.
    path: str
    # The tqdm progress bar displaying this download.
    progress_bar: Tqdm
Loading

0 comments on commit f3a7494

Please sign in to comment.