Skip to content

Commit

Permalink
refactor: add DownloadStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Aug 14, 2023
1 parent 17c9b52 commit 0e1efdd
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Retry configuration for S3 ([#47](https://github.com/stac-utils/stac-asset/pull/47))
- `Collection` download ([#50](https://github.com/stac-utils/stac-asset/pull/50))
- Progress reporting ([#55](https://github.com/stac-utils/stac-asset/pull/55))
- `DownloadStrategy` ([#64](https://github.com/stac-utils/stac-asset/pull/64))

### Changed

Expand Down
3 changes: 2 additions & 1 deletion src/stac_asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .http_client import HttpClient
from .planetary_computer_client import PlanetaryComputerClient
from .s3_client import S3Client
from .strategy import FileNameStrategy
from .strategy import DownloadStrategy, FileNameStrategy

__all__ = [
"DownloadWarning",
Expand All @@ -39,6 +39,7 @@
"Config",
"ContentTypeError",
"DownloadError",
"DownloadStrategy",
"EarthdataClient",
"FileNameStrategy",
"FilesystemClient",
Expand Down
5 changes: 3 additions & 2 deletions src/stac_asset/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tqdm
from pystac import Item, ItemCollection

from . import Config, _download, functions
from . import Config, DownloadStrategy, _download, functions
from .config import DEFAULT_S3_MAX_ATTEMPTS, DEFAULT_S3_RETRY_MODE
from .messages import (
ErrorAssetDownload,
Expand Down Expand Up @@ -167,7 +167,8 @@ async def download_async(
s3_requester_pays=s3_requester_pays,
s3_retry_mode=s3_retry_mode,
s3_max_attempts=s3_max_attempts,
warn=warn,
# TODO allow configuring of download strategy
download_strategy=DownloadStrategy.DELETE,
)

if href is None or href == "-":
Expand Down
33 changes: 24 additions & 9 deletions src/stac_asset/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union

from pystac import Asset, Collection, Item
from pystac import Asset, Collection, Item, STACObject
from yarl import URL

from .client import Client
Expand All @@ -19,7 +19,7 @@
from .http_client import HttpClient
from .planetary_computer_client import PlanetaryComputerClient
from .s3_client import S3Client
from .strategy import FileNameStrategy
from .strategy import DownloadStrategy, FileNameStrategy

# Needed until we drop Python 3.8
if TYPE_CHECKING:
Expand All @@ -40,13 +40,20 @@ async def download(
self, make_directory: bool, clean: bool, queue: Optional[AnyQueue]
) -> Union[Download, WrappedError]:
try:
await self.client.download_asset(
self.asset = await self.client.download_asset(
self.key, self.asset, self.path, make_directory, clean, queue
)
except Exception as error:
return WrappedError(self, error)
else:
return self
if "alternate" not in self.asset.extra_fields:
if not has_alternate_assets_extension(self.owner):
self.owner.stac_extensions.append(
"https://stac-extensions.github.io/alternate-assets/v1.1.0/schema.json"
)
self.asset.extra_fields["alternate"] = {}
self.asset.extra_fields["alternate"]["from"] = {"href": self.asset.href}
self.asset.href = str(self.path)
return self


class Downloads:
Expand Down Expand Up @@ -112,11 +119,12 @@ async def download(self, queue: Optional[AnyQueue]) -> None:
exceptions = set()
for result in results:
if isinstance(result, WrappedError):
del result.download.owner.assets[result.download.key]
if self.config.warn:
warnings.warn(str(result.error), DownloadWarning)
else:
if self.config.download_strategy == DownloadStrategy.ERROR:
exceptions.add(result.error)
else:
if self.config.download_strategy == DownloadStrategy.DELETE:
del result.download.owner.assets[result.download.key]
warnings.warn(str(result.error), DownloadWarning)
if exceptions:
raise DownloadError(list(exceptions))

Expand Down Expand Up @@ -187,6 +195,13 @@ def guess_client_class_from_href(href: str) -> Type[Client]:
raise ValueError(f"could not guess client class for href: {href}")


def has_alternate_assets_extension(stac_object: STACObject) -> bool:
return any(
extension.startswith("https://stac-extensions.github.io/alternate-assets")
for extension in stac_object.stac_extensions
)


class WrappedError:
download: Download
error: Exception
Expand Down
4 changes: 0 additions & 4 deletions src/stac_asset/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,6 @@ async def download_asset(

if queue:
await queue.put(FinishAssetDownload(key=key, href=href, path=path))
if "alternate" not in asset.extra_fields:
asset.extra_fields["alternate"] = {}
asset.extra_fields["alternate"]["from"] = {"href": asset.href}
asset.href = str(path)
return asset

async def close(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/stac_asset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Optional

from .errors import CannotIncludeAndExclude
from .strategy import FileNameStrategy
from .strategy import DownloadStrategy, FileNameStrategy

DEFAULT_S3_REGION_NAME = "us-west-2"
DEFAULT_S3_RETRY_MODE = "adaptive"
Expand All @@ -22,6 +22,9 @@ class Config:
asset_file_name_strategy: FileNameStrategy = FileNameStrategy.FILE_NAME
"""The file name strategy to use when downloading assets."""

download_strategy: DownloadStrategy = DownloadStrategy.ERROR
"""The strategy to use when errors occur during download."""

exclude: List[str] = field(default_factory=list)
"""Assets to exclude from the download.
Expand All @@ -46,9 +49,6 @@ class Config:
If False, and the output directory does not exist, an error will be raised.
"""

warn: bool = False
"""When downloading, warn instead of erroring."""

clean: bool = True
"""If true, clean up the downloaded file if it errors."""

Expand Down
15 changes: 14 additions & 1 deletion src/stac_asset/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class FileNameStrategy(Enum):
"""Strategy to use when downloading assets."""
"""Strategy to use for naming files."""

FILE_NAME = auto()
"""Save the asset with the file name in its href.
Expand All @@ -13,3 +13,16 @@ class FileNameStrategy(Enum):

KEY = auto()
"""Save the asset with its key as its file name."""


class DownloadStrategy(Enum):
"""Strategy to use when encountering errors during download."""

ERROR = auto()
"""Throw an error if an asset cannot be downloaded."""

KEEP = auto()
"""Warn, but keep the asset on the item."""

DELETE = auto()
"""Warn, but delete the asset from the item."""
74 changes: 53 additions & 21 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
CannotIncludeAndExclude,
Config,
DownloadError,
DownloadStrategy,
DownloadWarning,
FileNameStrategy,
)
Expand All @@ -21,22 +22,68 @@


async def test_download_item(tmp_path: Path, item: Item) -> None:
item = await stac_asset.download_item(item, tmp_path, Config(file_name="item.json"))
assert Path(tmp_path / "item.json").exists(), item.get_self_href()
item = await stac_asset.download_item(item, tmp_path)
assert os.path.exists(tmp_path / "20201211_223832_CS2.jpg")
asset = item.assets["data"]
assert asset.href == "./20201211_223832_CS2.jpg"
assert asset.href == str(tmp_path / "20201211_223832_CS2.jpg")


async def test_download_item_with_file_name(tmp_path: Path, item: Item) -> None:
await stac_asset.download_item(item, tmp_path, Config(file_name="item.json"))
item = Item.from_file(str(tmp_path / "item.json"))
assert item.assets["data"].href == "./20201211_223832_CS2.jpg"


async def test_download_missing_asset_error(tmp_path: Path, item: Item) -> None:
item.assets["does-not-exist"] = Asset("not-a-file.md5")
with pytest.raises(DownloadError):
await stac_asset.download_item(
item, tmp_path, Config(download_strategy=DownloadStrategy.ERROR)
)


async def test_download_missing_asset_keep(
tmp_path: Path, item: Item, data_path: Path
) -> None:
item.assets["does-not-exist"] = Asset("not-a-file.md5")
with pytest.warns(DownloadWarning):
item = await stac_asset.download_item(
item, tmp_path, Config(download_strategy=DownloadStrategy.KEEP)
)
assert item.assets["does-not-exist"].href == str(data_path / "not-a-file.md5")


async def test_download_missing_asset_delete(tmp_path: Path, item: Item) -> None:
item.assets["does-not-exist"] = Asset("not-a-file.md5")
with pytest.warns(DownloadWarning):
item = await stac_asset.download_item(
item, tmp_path, Config(download_strategy=DownloadStrategy.DELETE)
)
assert "does-not-exist" not in item.assets


async def test_download_item_collection(
tmp_path: Path, item_collection: ItemCollection
) -> None:
item_collection = await stac_asset.download_item_collection(
item_collection, tmp_path, Config(file_name="item-collection.json")
item_collection, tmp_path
)
assert os.path.exists(tmp_path / "item-collection.json")
assert os.path.exists(tmp_path / "test-item" / "20201211_223832_CS2.jpg")
asset = item_collection.items[0].assets["data"]
assert asset.href == "./test-item/20201211_223832_CS2.jpg"
assert asset.href == str(tmp_path / "test-item/20201211_223832_CS2.jpg")


async def test_download_item_collection_with_file_name(
tmp_path: Path, item_collection: ItemCollection
) -> None:
await stac_asset.download_item_collection(
item_collection, tmp_path, Config(file_name="item-collection.json")
)
item_collection = ItemCollection.from_file(str(tmp_path / "item-collection.json"))
assert (
item_collection.items[0].assets["data"].href
== "./test-item/20201211_223832_CS2.jpg"
)


async def test_download_collection(tmp_path: Path, collection: Collection) -> None:
Expand All @@ -49,21 +96,6 @@ async def test_download_collection(tmp_path: Path, collection: Collection) -> No
assert asset.href == "./20201211_223832_CS2.jpg"


async def test_item_download_404(tmp_path: Path, item: Item) -> None:
item.assets["missing-asset"] = Asset(href=str(Path(__file__).parent / "not-a-file"))
with pytest.raises(DownloadError):
await stac_asset.download_item(item, tmp_path)
assert not (tmp_path / "not-a-file").exists()


async def test_item_download_404_warn(tmp_path: Path, item: Item) -> None:
item.assets["missing-asset"] = Asset(href=str(Path(__file__).parent / "not-a-file"))
with pytest.warns(DownloadWarning):
item = await stac_asset.download_item(item, tmp_path, Config(warn=True))
assert not (tmp_path / "not-a-file").exists()
assert "missing-asset" not in item.assets


async def test_item_download_no_directory(tmp_path: Path, item: Item) -> None:
with pytest.raises(DownloadError):
await stac_asset.download_item(
Expand Down

0 comments on commit 0e1efdd

Please sign in to comment.