From 8f9602e3fa64835586b774655342080d8f142917 Mon Sep 17 00:00:00 2001
From: Saaketh Narayan
Date: Tue, 1 Oct 2024 10:05:56 -0700
Subject: [PATCH] Shard evict fix (#795)

Co-authored-by: Cory Stephenson
---
 streaming/base/dataset.py | 39 ++++++++++++---------------------------
 1 file changed, 12 insertions(+), 27 deletions(-)

diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index 774503682..722f4de44 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -1080,40 +1080,25 @@ def _evict_shard(self, shard_id: int) -> None:
             raise RuntimeError(f'Negative cache usage: {self.cache_usage}.')
 
     def _evict_coldest_shard(self) -> None:
-        """Evict the coldeset (i.e., least recently accessed) shard.
+        """Evict the coldest (i.e., least recently accessed) local shard.
 
         Assumes you hold ``__cache_filelock``, preventing anyone else from modifying the cache. We
         expect that shard deletions are very fast.
 
         This method is called internally by ``prepare_shard`` to clear space for more downloads.
         """
-        while True:
-            # Find the shard with the oldest last access time.
-            shard_id = int(self._shard_access_times.numpy().argmin())
-
-            # Check the shard's last access time. If it is NEVER, there are no downloaded shards to
-            # evict. If any shards are currently being downloaded, wait, else raise an error.
-            if self._shard_access_times[shard_id] == NEVER:
-                if (self._shard_states.numpy() == _ShardState.PREPARING).any():
-                    sleep(TICK)
-                    continue
-                else:
-                    raise ValueError(
-                        f'Tried to evict a shard {shard_id}, but no shards are present to evict ' +
-                        f'(cache usage {self.cache_usage} of {self.cache_limit})')
-
-            # The shard has a valid timestamp. Now, verify that it is actually present. There is an
-            # edge case where it may not be present (see the note in get_item()). If not present,
-            # pick the next lowest shard.
-            if self._shard_states[shard_id] != _ShardState.LOCAL:
-                self._shard_access_times[shard_id] = NEVER
-                continue
-
-            # Break on success.
-            break
-
+        states = self._shard_states.numpy()
+        access_times = self._shard_access_times.numpy()
+        # Filter shard ids to include only local shards
+        local_shard_ids = np.where(states == _ShardState.LOCAL)[0]
+        if local_shard_ids.size == 0:
+            raise ValueError('Attempted shard eviction, but there are no shards present locally ' +
+                             'to evict. Your cache limit may be too low.')
+        local_shard_times = access_times[local_shard_ids]
+        # Find local shard with oldest last access time
+        coldest_shard_id = local_shard_ids[np.argmin(local_shard_times)]
         # Evict that shard.
-        self._evict_shard(shard_id)
+        self._evict_shard(coldest_shard_id)
 
     def evict_shard(self, shard_id: int) -> None:
         """Evict the given shard.