Skip to content

Commit

Permalink
fix: better semaphore usage for concurrency
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Jul 23, 2024
1 parent 9ed2b9b commit 47cec93
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

- `--http-timeout` option to the CLI ([#196](https://github.com/stac-utils/stac-asset/pull/196))
- More info to CLI error reporting ([#200](https://github.com/stac-utils/stac-asset/pull/200))
- `--max-concurrent-downloads` option to the CLI ([#204](https://github.com/stac-utils/stac-asset/pull/204))

### Fixed

- Expand the list of exceptions on which we should retry for HTTP ([#195](https://github.com/stac-utils/stac-asset/pull/195))
- `SkipAssetDownload` docstring ([#199](https://github.com/stac-utils/stac-asset/pull/199))
- Fast failing when we hit `max_concurrent_downloads` ([#204](https://github.com/stac-utils/stac-asset/pull/204))

## [0.4.1] - 2024-07-17

Expand Down
25 changes: 6 additions & 19 deletions src/stac_asset/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
from typing import (
AsyncIterator,
List,
Optional,
Set,
Type,
Union,
)
from typing import AsyncIterator, List, Optional, Set, Type, Union

import pystac.utils
from pystac import Asset, Collection, Item, ItemCollection, Link, STACError
Expand Down Expand Up @@ -75,11 +68,6 @@ async def download(


class Downloads:
clients: Clients
config: Config
downloads: List[Download]
semaphore: Semaphore

def __init__(
self,
config: Config,
Expand All @@ -88,7 +76,7 @@ def __init__(
) -> None:
config.validate()
self.config = config
self.downloads = list()
self.downloads: List[Download] = list()
self.clients = Clients(config, clients)
self.semaphore = Semaphore(max_concurrent_downloads)

Expand Down Expand Up @@ -149,9 +137,7 @@ async def add(
async def download(self, messages: Optional[MessageQueue]) -> None:
tasks: Set[Task[Union[Download, WrappedError]]] = set()
for download in self.downloads:
# wait to acquire the semaphore before starting a new download task
await self.semaphore.acquire()
task = asyncio.create_task(self._download_with_release(download, messages))
task = asyncio.create_task(self.download_with_lock(download, messages))
tasks.add(task)
task.add_done_callback(tasks.discard)

Expand Down Expand Up @@ -182,9 +168,10 @@ async def download(self, messages: Optional[MessageQueue]) -> None:
if exceptions:
raise DownloadError(exceptions)

async def _download_with_release(
async def download_with_lock(
self, download: Download, messages: Optional[MessageQueue]
) -> Download | WrappedError:
) -> Union[Download, WrappedError]:
await self.semaphore.acquire()
try:
return await download.download(messages=messages)
finally:
Expand Down

0 comments on commit 47cec93

Please sign in to comment.