Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hdf5History.from_history() #1211

Merged
merged 6 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 72 additions & 11 deletions pypesto/history/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import contextlib
import time
from functools import wraps
from pathlib import Path
from typing import Dict, Sequence, Tuple, Union

import h5py
Expand Down Expand Up @@ -53,6 +55,7 @@ def with_h5_file(mode: str):
raise ValueError(f"Mode must be one of {modes}")

def decorator(fun):
@wraps(fun)
def wrapper(self, *args, **kwargs):
# file already opened
if self._f is not None and (
Expand All @@ -76,6 +79,7 @@ def wrapper(self, *args, **kwargs):
def check_editable(fun):
"""Check if the history is editable."""

@wraps(fun)
def wrapper(self, *args, **kwargs):
if not self.editable:
raise ValueError(
Expand Down Expand Up @@ -104,12 +108,12 @@ class Hdf5History(HistoryBase):
def __init__(
self,
id: str,
file: str,
file: Union[str, Path],
options: Union[HistoryOptions, Dict] = None,
):
super().__init__(options=options)
self.id: str = id
self.file: str = file
self.file: str = str(file)

# filled during file access
self._f: Union[h5py.File, None] = None
Expand Down Expand Up @@ -139,10 +143,7 @@ def finalize(self, message: str = None, exitflag: str = None) -> None:
super().finalize()

# add message and exitflag to trace
f = self._f
if f'{HISTORY}/{self.id}/{MESSAGES}/' not in f:
f.create_group(f'{HISTORY}/{self.id}/{MESSAGES}/')
grp = f[f'{HISTORY}/{self.id}/{MESSAGES}/']
grp = self._f.require_group(f'{HISTORY}/{self.id}/{MESSAGES}/')
if message is not None:
grp.attrs[MESSAGE] = message
if exitflag is not None:
Expand Down Expand Up @@ -454,11 +455,6 @@ def _editable(self) -> bool:
"""
Check whether the id is already existent in the file.

Parameters
----------
file:
HDF5 file name.

Returns
-------
True if the file is editable, False otherwise.
Expand All @@ -472,3 +468,68 @@ def _editable(self) -> bool:
except OSError:
# if something goes wrong, we assume the file is not editable
return False

@staticmethod
def from_history(
other: HistoryBase,
file: Union[str, Path],
id_: str,
overwrite: bool = False,
) -> "Hdf5History":
"""Write some History to HDF5.

Parameters
----------
other:
History to be copied to HDF5.
file:
HDF5 file to write to (append or create).
id_:
ID of the history.
overwrite:
Whether to overwrite an existing history with the same id.
"""
history = Hdf5History(file=file, id=id_)
history._f = h5py.File(history.file, mode="a")

try:
if f"{HISTORY}/{history.id}" in history._f:
if overwrite:
del history._f[f"{HISTORY}/{history.id}"]
else:
raise RuntimeError(
f"ID {history.id} already exists in file {file}."
)

trace_group = history._require_group()
trace_group.attrs[N_FVAL] = other.n_fval
trace_group.attrs[N_GRAD] = other.n_grad
trace_group.attrs[N_HESS] = other.n_hess
trace_group.attrs[N_RES] = other.n_res
trace_group.attrs[N_SRES] = other.n_sres
trace_group.attrs[START_TIME] = other.start_time
trace_group.attrs[N_ITERATIONS] = (
len(other.get_time_trace()) if other.implements_trace() else 0
)

group = trace_group.parent.require_group(MESSAGES)
if other.message is not None:
group.attrs[MESSAGE] = other.message
if other.exitflag is not None:
group.attrs[EXITFLAG] = other.exitflag

if not other.implements_trace():
return history

for trace_key in (X, FVAL, GRAD, HESS, RES, SRES, TIME):
getter = getattr(other, f"get_{trace_key}_trace")
trace = getter()
for iteration, value in enumerate(trace):
trace_group.require_group(str(iteration))[
trace_key
] = value
finally:
history._f.close()
history._f = None

return history
49 changes: 49 additions & 0 deletions test/base/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import pytest
import scipy.optimize as so
from numpy.testing import assert_array_almost_equal

import pypesto
import pypesto.optimize as optimize
Expand Down Expand Up @@ -716,3 +717,51 @@ def test_trim_history():
fval_trimmed_man.append(fval_i)
fval_current = fval_i
assert fval_trace_trimmed == fval_trimmed_man


def test_hd5_history_from_other(history: pypesto.HistoryBase):
"""Check that we can copy different histories to HDF5 and that the re-loaded history matches the original one."""
hdf5_file = tempfile.mkstemp(suffix='.h5')[1]
pypesto.Hdf5History.from_history(history, hdf5_file, id_="0")

# write a second time to test `overwrite` argument
with pytest.raises(RuntimeError, match="already exists"):
pypesto.Hdf5History.from_history(
history, hdf5_file, id_="0", overwrite=False
)
copied = pypesto.Hdf5History.from_history(
history, hdf5_file, id_="0", overwrite=True
)

assert copied.n_fval == history.n_fval
assert copied.n_grad == history.n_grad
assert copied.n_hess == history.n_hess
assert copied.n_res == history.n_res
assert copied.n_sres == history.n_sres
assert copied.exitflag == history.exitflag
assert copied.message == history.message
assert copied.start_time == history.start_time

if history.implements_trace():
assert_array_almost_equal(copied.get_x_trace(), history.get_x_trace())
assert_array_almost_equal(
copied.get_fval_trace(), history.get_fval_trace()
)
assert_array_almost_equal(
copied.get_grad_trace(), history.get_grad_trace()
)
assert_array_almost_equal(
copied.get_time_trace(), history.get_time_trace()
)
assert_array_almost_equal(
copied.get_res_trace(), history.get_res_trace()
)
assert_array_almost_equal(
copied.get_sres_trace(), history.get_sres_trace()
)
assert_array_almost_equal(
copied.get_chi2_trace(), history.get_chi2_trace()
)
assert_array_almost_equal(
copied.get_schi2_trace(), history.get_schi2_trace()
)