From 0e1efddf7f3fb03edb4561266c53edf3788c152c Mon Sep 17 00:00:00 2001 From: Pete Gadomski Date: Mon, 14 Aug 2023 07:09:24 -0600 Subject: [PATCH] refactor: add DownloadStrategy --- CHANGELOG.md | 1 + src/stac_asset/__init__.py | 3 +- src/stac_asset/_cli.py | 5 ++- src/stac_asset/_download.py | 33 ++++++++++++----- src/stac_asset/client.py | 4 -- src/stac_asset/config.py | 8 ++-- src/stac_asset/strategy.py | 15 +++++++- tests/test_functions.py | 74 ++++++++++++++++++++++++++----------- 8 files changed, 101 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3aeeafe..f0206dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Retry configuration for S3 ([#47](https://github.com/stac-utils/stac-asset/pull/47)) - `Collection` download ([#50](https://github.com/stac-utils/stac-asset/pull/50)) - Progress reporting ([#55](https://github.com/stac-utils/stac-asset/pull/55)) +- `DownloadStrategy` ([#64](https://github.com/stac-utils/stac-asset/pull/64)) ### Changed diff --git a/src/stac_asset/__init__.py b/src/stac_asset/__init__.py index 795d54c..c3f7de0 100644 --- a/src/stac_asset/__init__.py +++ b/src/stac_asset/__init__.py @@ -29,7 +29,7 @@ from .http_client import HttpClient from .planetary_computer_client import PlanetaryComputerClient from .s3_client import S3Client -from .strategy import FileNameStrategy +from .strategy import DownloadStrategy, FileNameStrategy __all__ = [ "DownloadWarning", @@ -39,6 +39,7 @@ "Config", "ContentTypeError", "DownloadError", + "DownloadStrategy", "EarthdataClient", "FileNameStrategy", "FilesystemClient", diff --git a/src/stac_asset/_cli.py b/src/stac_asset/_cli.py index 558bead..8f5fdc7 100644 --- a/src/stac_asset/_cli.py +++ b/src/stac_asset/_cli.py @@ -12,7 +12,7 @@ import tqdm from pystac import Item, ItemCollection -from . import Config, _download, functions +from . import Config, DownloadStrategy, _download, functions from .config import DEFAULT_S3_MAX_ATTEMPTS, DEFAULT_S3_RETRY_MODE from .messages import ( ErrorAssetDownload, @@ -167,7 +167,8 @@ async def download_async( s3_requester_pays=s3_requester_pays, s3_retry_mode=s3_retry_mode, s3_max_attempts=s3_max_attempts, - warn=warn, + # TODO allow configuring of download strategy + download_strategy=DownloadStrategy.DELETE, ) if href is None or href == "-": diff --git a/src/stac_asset/_download.py b/src/stac_asset/_download.py index f0b4eb8..93f9979 100644 --- a/src/stac_asset/_download.py +++ b/src/stac_asset/_download.py @@ -9,7 +9,7 @@ from types import TracebackType from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union -from pystac import Asset, Collection, Item +from pystac import Asset, Collection, Item, STACObject from yarl import URL from .client import Client @@ -19,7 +19,7 @@ from .http_client import HttpClient from .planetary_computer_client import PlanetaryComputerClient from .s3_client import S3Client -from .strategy import FileNameStrategy +from .strategy import DownloadStrategy, FileNameStrategy # Needed until we drop Python 3.8 if TYPE_CHECKING: @@ -40,13 +40,20 @@ async def download( self, make_directory: bool, clean: bool, queue: Optional[AnyQueue] ) -> Union[Download, WrappedError]: try: - await self.client.download_asset( + self.asset = await self.client.download_asset( self.key, self.asset, self.path, make_directory, clean, queue ) except Exception as error: return WrappedError(self, error) - else: - return self + if "alternate" not in self.asset.extra_fields: + if not has_alternate_assets_extension(self.owner): + self.owner.stac_extensions.append( + "https://stac-extensions.github.io/alternate-assets/v1.1.0/schema.json" + ) + self.asset.extra_fields["alternate"] = {} + self.asset.extra_fields["alternate"]["from"] = {"href": self.asset.href} + self.asset.href = str(self.path) + return self class Downloads: @@ -112,11 +119,12 @@ async def download(self, queue: Optional[AnyQueue]) -> None: exceptions = set() for result in results: if isinstance(result, WrappedError): - del result.download.owner.assets[result.download.key] - if self.config.warn: - warnings.warn(str(result.error), DownloadWarning) - else: + if self.config.download_strategy == DownloadStrategy.ERROR: exceptions.add(result.error) + else: + if self.config.download_strategy == DownloadStrategy.DELETE: + del result.download.owner.assets[result.download.key] + warnings.warn(str(result.error), DownloadWarning) if exceptions: raise DownloadError(list(exceptions)) @@ -187,6 +195,13 @@ def guess_client_class_from_href(href: str) -> Type[Client]: raise ValueError(f"could not guess client class for href: {href}") +def has_alternate_assets_extension(stac_object: STACObject) -> bool: + return any( + extension.startswith("https://stac-extensions.github.io/alternate-assets") + for extension in stac_object.stac_extensions + ) + + class WrappedError: download: Download error: Exception diff --git a/src/stac_asset/client.py b/src/stac_asset/client.py index 17cc2ad..46f0b75 100644 --- a/src/stac_asset/client.py +++ b/src/stac_asset/client.py @@ -179,10 +179,6 @@ async def download_asset( if queue: await queue.put(FinishAssetDownload(key=key, href=href, path=path)) - if "alternate" not in asset.extra_fields: - asset.extra_fields["alternate"] = {} - asset.extra_fields["alternate"]["from"] = {"href": asset.href} - asset.href = str(path) return asset async def close(self) -> None: diff --git a/src/stac_asset/config.py b/src/stac_asset/config.py index 037a5f2..836d0c1 100644 --- a/src/stac_asset/config.py +++ b/src/stac_asset/config.py @@ -5,7 +5,7 @@ from typing import List, Optional from .errors import CannotIncludeAndExclude -from .strategy import FileNameStrategy +from .strategy import DownloadStrategy, FileNameStrategy DEFAULT_S3_REGION_NAME = "us-west-2" DEFAULT_S3_RETRY_MODE = "adaptive" @@ -22,6 +22,9 @@ class Config: asset_file_name_strategy: FileNameStrategy = FileNameStrategy.FILE_NAME """The file name strategy to use when downloading assets.""" + download_strategy: DownloadStrategy = DownloadStrategy.ERROR + """The strategy to use when errors occur during download.""" + exclude: List[str] = field(default_factory=list) """Assets to exclude from the download. @@ -46,9 +49,6 @@ class Config: If False, and the output directory does not exist, an error will be raised. """ - warn: bool = False - """When downloading, warn instead of erroring.""" - clean: bool = True """If true, clean up the downloaded file if it errors.""" diff --git a/src/stac_asset/strategy.py b/src/stac_asset/strategy.py index 43f4202..40cfc77 100644 --- a/src/stac_asset/strategy.py +++ b/src/stac_asset/strategy.py @@ -2,7 +2,7 @@ class FileNameStrategy(Enum): - """Strategy to use when downloading assets.""" + """Strategy to use for naming files.""" FILE_NAME = auto() """Save the asset with the file name in its href. @@ -13,3 +13,16 @@ class FileNameStrategy(Enum): KEY = auto() """Save the asset with its key as its file name.""" + + +class DownloadStrategy(Enum): + """Strategy to use when encountering errors during download.""" + + ERROR = auto() + """Throw an error if an asset cannot be downloaded.""" + + KEEP = auto() + """Warn, but keep the asset on the item.""" + + DELETE = auto() + """Warn, but delete the asset from the item.""" diff --git a/tests/test_functions.py b/tests/test_functions.py index 5c57b7b..f7a463e 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -11,6 +11,7 @@ CannotIncludeAndExclude, Config, DownloadError, + DownloadStrategy, DownloadWarning, FileNameStrategy, ) @@ -21,22 +22,68 @@ async def test_download_item(tmp_path: Path, item: Item) -> None: - item = await stac_asset.download_item(item, tmp_path, Config(file_name="item.json")) - assert Path(tmp_path / "item.json").exists(), item.get_self_href() + item = await stac_asset.download_item(item, tmp_path) + assert os.path.exists(tmp_path / "20201211_223832_CS2.jpg") asset = item.assets["data"] - assert asset.href == "./20201211_223832_CS2.jpg" + assert asset.href == str(tmp_path / "20201211_223832_CS2.jpg") + + +async def test_download_item_with_file_name(tmp_path: Path, item: Item) -> None: + await stac_asset.download_item(item, tmp_path, Config(file_name="item.json")) + item = Item.from_file(str(tmp_path / "item.json")) + assert item.assets["data"].href == "./20201211_223832_CS2.jpg" + + +async def test_download_missing_asset_error(tmp_path: Path, item: Item) -> None: + item.assets["does-not-exist"] = Asset("not-a-file.md5") + with pytest.raises(DownloadError): + await stac_asset.download_item( + item, tmp_path, Config(download_strategy=DownloadStrategy.ERROR) + ) + + +async def test_download_missing_asset_keep( + tmp_path: Path, item: Item, data_path: Path +) -> None: + item.assets["does-not-exist"] = Asset("not-a-file.md5") + with pytest.warns(DownloadWarning): + item = await stac_asset.download_item( + item, tmp_path, Config(download_strategy=DownloadStrategy.KEEP) + ) + assert item.assets["does-not-exist"].href == str(data_path / "not-a-file.md5") + + +async def test_download_missing_asset_delete(tmp_path: Path, item: Item) -> None: + item.assets["does-not-exist"] = Asset("not-a-file.md5") + with pytest.warns(DownloadWarning): + item = await stac_asset.download_item( + item, tmp_path, Config(download_strategy=DownloadStrategy.DELETE) + ) + assert "does-not-exist" not in item.assets async def test_download_item_collection( tmp_path: Path, item_collection: ItemCollection ) -> None: item_collection = await stac_asset.download_item_collection( - item_collection, tmp_path, Config(file_name="item-collection.json") + item_collection, tmp_path ) - assert os.path.exists(tmp_path / "item-collection.json") assert os.path.exists(tmp_path / "test-item" / "20201211_223832_CS2.jpg") asset = item_collection.items[0].assets["data"] - assert asset.href == "./test-item/20201211_223832_CS2.jpg" + assert asset.href == str(tmp_path / "test-item/20201211_223832_CS2.jpg") + + +async def test_download_item_collection_with_file_name( + tmp_path: Path, item_collection: ItemCollection +) -> None: + await stac_asset.download_item_collection( + item_collection, tmp_path, Config(file_name="item-collection.json") + ) + item_collection = ItemCollection.from_file(str(tmp_path / "item-collection.json")) + assert ( + item_collection.items[0].assets["data"].href + == "./test-item/20201211_223832_CS2.jpg" + ) async def test_download_collection(tmp_path: Path, collection: Collection) -> None: @@ -49,21 +96,6 @@ async def test_download_collection(tmp_path: Path, collection: Collection) -> No assert asset.href == "./20201211_223832_CS2.jpg" -async def test_item_download_404(tmp_path: Path, item: Item) -> None: - item.assets["missing-asset"] = Asset(href=str(Path(__file__).parent / "not-a-file")) - with pytest.raises(DownloadError): - await stac_asset.download_item(item, tmp_path) - assert not (tmp_path / "not-a-file").exists() - - -async def test_item_download_404_warn(tmp_path: Path, item: Item) -> None: - item.assets["missing-asset"] = Asset(href=str(Path(__file__).parent / "not-a-file")) - with pytest.warns(DownloadWarning): - item = await stac_asset.download_item(item, tmp_path, Config(warn=True)) - assert not (tmp_path / "not-a-file").exists() - assert "missing-asset" not in item.assets - - async def test_item_download_no_directory(tmp_path: Path, item: Item) -> None: with pytest.raises(DownloadError): await stac_asset.download_item(