Skip to content

Commit

Permalink
Validate cached_statpoing when read from disk.
Browse files Browse the repository at this point in the history
This has the added benefit of validating all statepoints that are
added to the cache. I needed to add a validate argument to update_cache
because one of the unit tests relies on adding invalid statepoints
to the cache.
  • Loading branch information
joaander committed Feb 12, 2024
1 parent 7578993 commit 6b30785
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
42 changes: 28 additions & 14 deletions signac/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,13 +856,20 @@ def _register(self, id_, statepoint):
"""
self._sp_cache[id_] = statepoint

def _get_statepoint_from_workspace(self, job_id):
def _get_statepoint_from_workspace(self, job_id, validate=True):
"""Attempt to read the state point from the workspace.
Parameters
----------
job_id : str
Identifier of the job.
validate : bool
When True, validate that any statepoint read from disk matches the job_id.
Raises
------
:class:`signac.errors.JobsCorruptedError`
When one or more jobs are identified as corrupted.
"""
# Performance-critical path. We can rely on the project workspace, job
Expand All @@ -871,7 +878,11 @@ def _get_statepoint_from_workspace(self, job_id):
fn_statepoint = os.sep.join((self.workspace, job_id, Job.FN_STATE_POINT))
try:
with open(fn_statepoint, "rb") as statepoint_file:
return json.loads(statepoint_file.read().decode())
statepoint = json.loads(statepoint_file.read().decode())
if validate and calc_id(statepoint) != job_id:
raise JobsCorruptedError([job_id])

return statepoint
except (OSError, ValueError) as error:
if os.path.isdir(os.sep.join((self.workspace, job_id))):
logger.error(
Expand All @@ -882,7 +893,7 @@ def _get_statepoint_from_workspace(self, job_id):
raise JobsCorruptedError([job_id])
raise KeyError(job_id)

def _get_statepoint(self, job_id):
def _get_statepoint(self, job_id, validate=True):
"""Get the state point associated with a job id.
The state point is retrieved from the internal cache, from
Expand All @@ -892,6 +903,9 @@ def _get_statepoint(self, job_id):
----------
job_id : str
A job id to get the state point for.
validate : bool
When True, validate that any statepoint read from disk matches the job_id.
Returns
-------
Expand Down Expand Up @@ -926,7 +940,7 @@ def _get_statepoint(self, job_id):
"updating the cache by running `signac update-cache`."
)
self._sp_cache_warned = True
statepoint = self._get_statepoint_from_workspace(job_id)
statepoint = self._get_statepoint_from_workspace(job_id, validate)
# Update the project's state point cache from this cache miss
self._sp_cache[job_id] = statepoint
return statepoint
Expand Down Expand Up @@ -1258,11 +1272,7 @@ def check(self):
logger.info("Checking workspace for corruption...")
for job_id in self._find_job_ids():
try:
statepoint = self._get_statepoint(job_id)
if calc_id(statepoint) != job_id:
corrupted.append(job_id)
else:
self.open_job(statepoint).init()
statepoint = self._get_statepoint_from_workspace(job_id)
except JobsCorruptedError as error:
corrupted.extend(error.job_ids)
if corrupted:
Expand Down Expand Up @@ -1298,7 +1308,7 @@ def repair(self, job_ids=None):
for job_id in job_ids:
try:
# First, check if we can look up the state point.
statepoint = self._get_statepoint(job_id)
statepoint = self._get_statepoint(job_id, validate=False)
# Check if state point and id correspond.
correct_id = calc_id(statepoint)
if correct_id != job_id:
Expand Down Expand Up @@ -1379,7 +1389,7 @@ def _build_index(self, include_job_document=False):
raise
yield job_id, doc

def _update_in_memory_cache(self):
def _update_in_memory_cache(self, validate=False):
"""Update the in-memory state point cache to reflect the workspace."""
logger.debug("Updating in-memory cache...")
start = time.time()
Expand All @@ -1392,7 +1402,7 @@ def _update_in_memory_cache(self):
del self._sp_cache[id_]

def _add(id_):
self._sp_cache[id_] = self._get_statepoint_from_workspace(id_)
self._sp_cache[id_] = self._get_statepoint_from_workspace(id_, validate)

to_add_chunks = _split_and_print_progress(
iterable=list(to_add),
Expand All @@ -1419,7 +1429,7 @@ def _remove_persistent_cache_file(self):
if error.errno != errno.ENOENT:
raise error

def update_cache(self):
def update_cache(self, validate=True):
"""Update the persistent state point cache.
This function updates a persistent state point cache, which
Expand All @@ -1428,12 +1438,16 @@ def update_cache(self):
to be significantly faster after calling this function, especially
for large data spaces.
Parameters
----------
validate : bool
When True, validate that any statepoint read from disk matches the job_id.
"""
logger.info("Update cache...")
start = time.time()
cache = self._read_cache()
cached_ids = set(self._sp_cache)
self._update_in_memory_cache()
self._update_in_memory_cache(validate)
if cache is None or set(cache) != cached_ids:
fn_cache = self.fn(self.FN_CACHE)
fn_cache_tmp = fn_cache + "~"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2102,7 +2102,7 @@ class UpdateCacheAfterInitJob(signac.job.Job):

def init(self, *args, **kwargs):
job = super().init(*args, **kwargs)
self._project.update_cache()
self._project.update_cache(validate=False)
return job


Expand Down

0 comments on commit 6b30785

Please sign in to comment.