Skip to content

Commit

Permalink
Search local shards directly to find shard to evict
Browse files: browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Sep 27, 2024
1 parent 79c2dfc commit 6b521d2
Showing 1 changed file with 9 additions and 25 deletions.
34 changes: 9 additions & 25 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,31 +1087,15 @@ def _evict_coldest_shard(self) -> None:
This method is called internally by ``prepare_shard`` to clear space for more downloads.
"""
while True:
# Find the shard with the oldest last access time.
shard_id = int(self._shard_access_times.numpy().argmin())

# Check the shard's last access time. If it is NEVER, there are no downloaded shards to
# evict. If any shards are currently being downloaded, wait, else raise an error.
if self._shard_access_times[shard_id] == NEVER:
if (self._shard_states.numpy() == _ShardState.PREPARING).any():
sleep(TICK)
continue
else:
raise ValueError(
f'Tried to evict a shard {shard_id}, but no shards are present to evict ' +
f'(cache usage {self.cache_usage} of {self.cache_limit})')

# The shard has a valid timestamp. Now, verify that it is actually present. There is an
# edge case where it may not be present (see the note in get_item()). If not present,
# pick the next lowest shard.
if self._shard_states[shard_id] != _ShardState.LOCAL:
self._shard_access_times[shard_id] = NEVER
continue

# Break on success.
break

states = self._shard_states.numpy()
access_times = self._shard_access_times.numpy()
# Filter indices to include only local shards
indices = np.where(states == 3)[0]
if indices.size == 0:
raise ValueError('Could not evict because no local shards.')
local_times = access_times[indices]
# Find local shard with oldest last access time
shard_id = indices[np.argmin(local_times)]
# Evict that shard.
self._evict_shard(shard_id)

Expand Down

0 comments on commit 6b521d2

Please sign in to comment.