Skip to content

Commit

Permalink
Shard evict fix (#795)
Browse files Browse the repository at this point in the history
Co-authored-by: Cory Stephenson <cory.stephenson@databricks.com>
  • Loading branch information
snarayan21 and corystephenson-db authored Oct 1, 2024
1 parent 93d3f50 commit 8f9602e
Showing 1 changed file with 12 additions and 27 deletions.
39 changes: 12 additions & 27 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,40 +1080,25 @@ def _evict_shard(self, shard_id: int) -> None:
raise RuntimeError(f'Negative cache usage: {self.cache_usage}.')

def _evict_coldest_shard(self) -> None:
"""Evict the coldeset (i.e., least recently accessed) shard.
"""Evict the coldest (i.e., least recently accessed) local shard.
Assumes you hold ``__cache_filelock``, preventing anyone else from modifying the cache. We
expect that shard deletions are very fast.
This method is called internally by ``prepare_shard`` to clear space for more downloads.
"""
while True:
# Find the shard with the oldest last access time.
shard_id = int(self._shard_access_times.numpy().argmin())

# Check the shard's last access time. If it is NEVER, there are no downloaded shards to
# evict. If any shards are currently being downloaded, wait, else raise an error.
if self._shard_access_times[shard_id] == NEVER:
if (self._shard_states.numpy() == _ShardState.PREPARING).any():
sleep(TICK)
continue
else:
raise ValueError(
f'Tried to evict a shard {shard_id}, but no shards are present to evict ' +
f'(cache usage {self.cache_usage} of {self.cache_limit})')

# The shard has a valid timestamp. Now, verify that it is actually present. There is an
# edge case where it may not be present (see the note in get_item()). If not present,
# pick the next lowest shard.
if self._shard_states[shard_id] != _ShardState.LOCAL:
self._shard_access_times[shard_id] = NEVER
continue

# Break on success.
break

states = self._shard_states.numpy()
access_times = self._shard_access_times.numpy()
# Filter shard ids to include only local shards
local_shard_ids = np.where(states == _ShardState.LOCAL)[0]
if local_shard_ids.size == 0:
raise ValueError('Attempted shard eviction, but there are no shards present locally ' +
'to evict. Your cache limit may be too low.')
local_shard_times = access_times[local_shard_ids]
# Find local shard with oldest last access time
coldest_shard_id = local_shard_ids[np.argmin(local_shard_times)]
# Evict that shard.
self._evict_shard(shard_id)
self._evict_shard(coldest_shard_id)

def evict_shard(self, shard_id: int) -> None:
"""Evict the given shard.
Expand Down

0 comments on commit 8f9602e

Please sign in to comment.