diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 774503682..ea159fade 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1087,31 +1087,15 @@ def _evict_coldest_shard(self) -> None: This method is called internally by ``prepare_shard`` to clear space for more downloads. """ - while True: - # Find the shard with the oldest last access time. - shard_id = int(self._shard_access_times.numpy().argmin()) - - # Check the shard's last access time. If it is NEVER, there are no downloaded shards to - # evict. If any shards are currently being downloaded, wait, else raise an error. - if self._shard_access_times[shard_id] == NEVER: - if (self._shard_states.numpy() == _ShardState.PREPARING).any(): - sleep(TICK) - continue - else: - raise ValueError( - f'Tried to evict a shard {shard_id}, but no shards are present to evict ' + - f'(cache usage {self.cache_usage} of {self.cache_limit})') - - # The shard has a valid timestamp. Now, verify that it is actually present. There is an - # edge case where it may not be present (see the note in get_item()). If not present, - # pick the next lowest shard. - if self._shard_states[shard_id] != _ShardState.LOCAL: - self._shard_access_times[shard_id] = NEVER - continue - - # Break on success. - break - + states = self._shard_states.numpy() + access_times = self._shard_access_times.numpy() + # Filter indices to include only local shards + indices = np.where(states == 3)[0] + if indices.size == 0: + raise ValueError('Could not evict because no local shards.') + local_times = access_times[indices] + # Find local shard with oldest last access time + shard_id = indices[np.argmin(local_times)] # Evict that shard. self._evict_shard(shard_id)