Skip to content

Commit

Permalink
Add Hdf5History.from_history()
Browse files Browse the repository at this point in the history
Adds `Hdf5History.from_history()` which allows saving other histories as HDF5 later on.

Closes ICB-DCM#1196
  • Loading branch information
dweindl committed Nov 22, 2023
1 parent ffe5c6e commit 78cfbf0
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 6 deletions.
55 changes: 49 additions & 6 deletions pypesto/history/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import contextlib
import time
from functools import wraps
from pathlib import Path
from typing import Dict, Sequence, Tuple, Union

import h5py
Expand Down Expand Up @@ -53,6 +55,7 @@ def with_h5_file(mode: str):
raise ValueError(f"Mode must be one of {modes}")

def decorator(fun):
@wraps(fun)
def wrapper(self, *args, **kwargs):
# file already opened
if self._f is not None and (
Expand Down Expand Up @@ -104,12 +107,12 @@ class Hdf5History(HistoryBase):
def __init__(
self,
id: str,
file: str,
file: Union[str, Path],
options: Union[HistoryOptions, Dict] = None,
):
super().__init__(options=options)
self.id: str = id
self.file: str = file
self.file: str = str(file)

# filled during file access
self._f: Union[h5py.File, None] = None
Expand Down Expand Up @@ -139,10 +142,7 @@ def finalize(self, message: str = None, exitflag: str = None) -> None:
super().finalize()

# add message and exitflag to trace
f = self._f
if f'{HISTORY}/{self.id}/{MESSAGES}/' not in f:
f.create_group(f'{HISTORY}/{self.id}/{MESSAGES}/')
grp = f[f'{HISTORY}/{self.id}/{MESSAGES}/']
grp = self._f.require_group(f'{HISTORY}/{self.id}/{MESSAGES}/')
if message is not None:
grp.attrs[MESSAGE] = message
if exitflag is not None:
Expand Down Expand Up @@ -472,3 +472,46 @@ def _editable(self) -> bool:
except OSError:
# if something goes wrong, we assume the file is not editable
return False

@staticmethod
def from_history(
other: HistoryBase, file: Union[str, Path], id_: str
) -> "Hdf5History":
"""Write some History to HDF5."""
history = Hdf5History(file=file, id=id_)

try:
with h5py.File(history.file, mode="a") as f:
history._f = f
trace_group = history._require_group()
trace_group.attrs[N_FVAL] = other.n_fval
trace_group.attrs[N_GRAD] = other.n_grad
trace_group.attrs[N_HESS] = other.n_hess
trace_group.attrs[N_RES] = other.n_res
trace_group.attrs[N_SRES] = other.n_sres
trace_group.attrs[START_TIME] = other.start_time
trace_group.attrs[N_ITERATIONS] = (
len(other.get_time_trace())
if other.implements_trace()
else 0
)

if other.implements_trace():
for trace_key in (X, FVAL, GRAD, HESS, RES, SRES, TIME):
getter = getattr(other, f"get_{trace_key}_trace")
trace = getter()
for i, value in enumerate(trace):
trace_group.require_group(str(i))[
trace_key
] = value

group = f.require_group(f'{HISTORY}/{history.id}/{MESSAGES}/')
if other.message is not None:
group.attrs[MESSAGE] = other.message
if other.exitflag is not None:
group.attrs[EXITFLAG] = other.exitflag

finally:
history._f = None

return history
41 changes: 41 additions & 0 deletions test/base/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import pytest
import scipy.optimize as so
from numpy.testing import assert_array_almost_equal

import pypesto
import pypesto.optimize as optimize
Expand Down Expand Up @@ -716,3 +717,43 @@ def test_trim_history():
fval_trimmed_man.append(fval_i)
fval_current = fval_i
assert fval_trace_trimmed == fval_trimmed_man


def test_hd5_history_from_other(history: pypesto.HistoryBase):
"""Check that we can copy different histories to HDF5 and that the re-loaded history matches the original one."""
hdf5_file = tempfile.mkstemp(suffix='.h5')[1]
pypesto.Hdf5History.from_history(history, hdf5_file, id_="0")
copied = pypesto.Hdf5History(file=hdf5_file, id="0")

assert copied.n_fval == history.n_fval
assert copied.n_grad == history.n_grad
assert copied.n_hess == history.n_hess
assert copied.n_res == history.n_res
assert copied.n_sres == history.n_sres
assert copied.exitflag == history.exitflag
assert copied.message == history.message
assert copied.start_time == history.start_time

if history.implements_trace():
assert_array_almost_equal(copied.get_x_trace(), history.get_x_trace())
assert_array_almost_equal(
copied.get_fval_trace(), history.get_fval_trace()
)
assert_array_almost_equal(
copied.get_grad_trace(), history.get_grad_trace()
)
assert_array_almost_equal(
copied.get_time_trace(), history.get_time_trace()
)
assert_array_almost_equal(
copied.get_res_trace(), history.get_res_trace()
)
assert_array_almost_equal(
copied.get_sres_trace(), history.get_sres_trace()
)
assert_array_almost_equal(
copied.get_chi2_trace(), history.get_chi2_trace()
)
assert_array_almost_equal(
copied.get_schi2_trace(), history.get_schi2_trace()
)

0 comments on commit 78cfbf0

Please sign in to comment.