Skip to content

Commit

Permalink
Feature hdf5hist pj (#564)
Browse files Browse the repository at this point in the history
* added history saving and reading

* updated read

* test store history

* test store history

* pep8 standard

* included hdf5_storage.ipynb

* updated notebook

* updated notebook

* brought other history classes up to date with develop

* added storge hdf5 notebook to example.rst

* added storgae_hdf5 to toctree

* updated docstring

* removed history comment

* added Hdf5History to tests

* pep8 standard

* remove TODOs and cleanup instance checks

Co-authored-by: MerktSimon <simon.merkt@uni-bonn.de>
Co-authored-by: FFroehlich <fabian@schaluck.com>
  • Loading branch information
3 people authored Feb 11, 2021
1 parent e9a5d5e commit 3e765d9
Show file tree
Hide file tree
Showing 8 changed files with 693 additions and 35 deletions.
3 changes: 3 additions & 0 deletions doc/example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The following examples cover typical use cases and should help get a better idea
example/sampling_diagnostics.ipynb
example/synthetic_data.ipynb
example/prior_definition.ipynb
example/hdf5_storage.ipynb

Download the examples as notebooks
----------------------------------
Expand All @@ -30,6 +31,8 @@ Download the examples as notebooks
* :download:`Sampling diagnostics <example/sampling_diagnostics.ipynb>`
* :download:`Synthetic data <example/synthetic_data.ipynb>`
* :download:`Prior definition <example/prior_definition.ipynb>`
* :download:`hdf5 storage <example/hdf5_storage.ipynb>`


.. Note::
Some of the notebooks have extra dependencies.
360 changes: 360 additions & 0 deletions doc/example/hdf5_storage.ipynb

Large diffs are not rendered by default.

224 changes: 219 additions & 5 deletions pypesto/objective/history.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
import pandas as pd
import numbers
import h5py
import copy
import time
import os
import abc
from typing import Any, Dict, List, Tuple, Sequence, Union
from typing import Any, Dict, List, Optional, Tuple, Sequence, Union

from .constants import (
MODE_FUN, MODE_RES, FVAL, GRAD, HESS, RES, SRES, CHI2, SCHI2, TIME,
Expand Down Expand Up @@ -769,6 +770,7 @@ def __init__(self,
super().__init__(options=options)
self.id = id
self.file = file
self._generate_hdf5_group()

def update(
self,
Expand All @@ -777,12 +779,224 @@ def update(
mode: str,
result: ResultDict
) -> None:
# TODO implement
raise NotImplementedError()
super().update(x, sensi_orders, mode, result)
self._update_trace(x, sensi_orders, mode, result)

def get_history_directory(self):
return self.file

def finalize(self):
# TODO implement
raise NotImplementedError()
super().finalize()

@staticmethod
def load(id: str,
file: str):
"""Loads the History object from memory."""
loaded_h5history = Hdf5History(id, file)
loaded_h5history._recover_options(file)
return loaded_h5history

def _recover_options(self, file: str):
"""
Recovers options when loading the hdf5 history from memory
by testing which entries were recorded.
"""
trace_record = self._check_for_not_nan_entries(X)
trace_record_grad = self._check_for_not_nan_entries(GRAD)
trace_record_hess = self._check_for_not_nan_entries(HESS)
trace_record_res = self._check_for_not_nan_entries(RES)
trace_record_sres = self._check_for_not_nan_entries(SRES)
trace_record_chi2 = self._check_for_not_nan_entries(CHI2)
trace_record_schi2 = self._check_for_not_nan_entries(SCHI2)
storage_file = file

restored_history_options = \
HistoryOptions(trace_record=trace_record,
trace_record_grad=trace_record_grad,
trace_record_hess=trace_record_hess,
trace_record_res=trace_record_res,
trace_record_sres=trace_record_sres,
trace_record_chi2=trace_record_chi2,
trace_record_schi2=trace_record_schi2,
trace_save_iter=self.trace_save_iter,
storage_file=storage_file)

self.options = restored_history_options

def _check_for_not_nan_entries(self, hdf5_group: str) -> bool:
"""Checks if there exist not-nan entries stored for a given group"""
group = self._get_hdf5_entries(hdf5_group)

for entry in group:
if not (entry is None or np.all(np.isnan(entry))):
return True

return False

# overwrite _update_counts
def _update_counts(self,
sensi_orders: Tuple[int, ...],
mode: str):
"""
Update the counters in the hdf5
"""
with h5py.File(self.file, 'a') as f:

if mode == MODE_FUN:
if 0 in sensi_orders:
f[f'history/{self.id}/trace/'].attrs[
'n_fval'] += 1
if 1 in sensi_orders:
f[f'history/{self.id}/trace/'].attrs[
'n_grad'] += 1
if 2 in sensi_orders:
f[f'history/{self.id}/trace/'].attrs[
'n_hess'] += 1
elif mode == MODE_RES:
if 0 in sensi_orders:
f[f'history/{self.id}/trace/'].attrs[
'n_res'] += 1
if 1 in sensi_orders:
f[f'history/{self.id}/trace/'].attrs[
'n_sres'] += 1

@property
def n_fval(self) -> int:
with h5py.File(self.file, 'r') as f:
return f[f'history/{self.id}/trace/'].attrs['n_fval']

@property
def n_grad(self) -> int:
with h5py.File(self.file, 'r') as f:
return f[f'history/{self.id}/trace/'].attrs['n_grad']

@property
def n_hess(self) -> int:
with h5py.File(self.file, 'r') as f:
return f[f'history/{self.id}/trace/'].attrs['n_hess']

@property
def n_res(self) -> int:
with h5py.File(self.file, 'r') as f:
return f[f'history/{self.id}/trace/'].attrs['n_res']

@property
def n_sres(self) -> int:
with h5py.File(self.file, 'r') as f:
return f[f'history/{self.id}/trace/'].attrs['n_sres']

@property
def trace_save_iter(self):
with h5py.File(self.file, 'r') as f:
return f[f'history/{self.id}/trace/']\
.attrs['trace_save_iter']

def _update_trace(self,
x: np.ndarray,
sensi_orders: Tuple[int],
mode: str,
result: ResultDict):
"""
Update and possibly store the trace.
"""

if not self.options.trace_record:
return

# extract function values
ret = extract_values(mode, result, self.options)

used_time = time.time() - self._start_time

values = {
TIME: used_time,
X: x,
FVAL: ret[FVAL],
GRAD: ret[GRAD],
RES: ret[RES],
SRES: ret[SRES],
CHI2: ret[CHI2],
HESS: ret[HESS],
}

with h5py.File(self.file, 'a') as f:

iteration = f[f'history/{self.id}/trace/'].attrs[
'n_iterations']

for key in values.keys():
if values[key] is not None:
f[f'history/{self.id}/trace/'
f'{str(iteration)}/{key}'] = values[key]

f[f'history/{self.id}/trace/'].attrs[
'n_iterations'] += 1

def _generate_hdf5_group(self, f: h5py.File = None):
"""
Generates the group in the hdf5 file, if it does not exist yet.
"""
try:
with h5py.File(self.file, 'a') as f:
if f'history/{self.id}/trace/' not in f:
grp = f.create_group(f'history/{self.id}/trace/')
grp.attrs['n_iterations'] = 0
grp.attrs['n_fval'] = 0
grp.attrs['n_grad'] = 0
grp.attrs['n_hess'] = 0
grp.attrs['n_res'] = 0
grp.attrs['n_sres'] = 0
grp.attrs['trace_save_iter'] = self.options.trace_save_iter
except OSError:
pass

def _get_hdf5_entries(self, entry_id: str) -> Sequence:
"""
returns the entries for the key entry_id.
"""
trace_result = []

with h5py.File(self.file, 'r') as f:

n_iterations = f[f'history/'
f'{self.id}/trace/'].attrs['n_iterations']

for iteration in range(n_iterations):
try:
entry = np.array(f[f'history/{self.id}/trace'
f'/{str(iteration)}/{entry_id}'])
trace_result.append(entry)
except KeyError:
trace_result.append(None)

return trace_result

def get_x_trace(self) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(X)

def get_fval_trace(self) -> Sequence[float]:
return self._get_hdf5_entries(FVAL)

def get_grad_trace(self) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(GRAD)

def get_hess_trace(self) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(HESS)

def get_res_trace(self) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(RES)

def get_sres_trace(self) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(SRES)

def get_chi2_trace(self) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(CHI2)

def get_schi2_trace(self, t: Optional[int] = None) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(SCHI2)

def get_time_trace(self, t: Optional[int] = None) -> Sequence[np.ndarray]:
return self._get_hdf5_entries(TIME)


class OptimizerHistory:
Expand Down
15 changes: 13 additions & 2 deletions pypesto/store/read_from_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ..profile.result import ProfilerResult
from ..sample.result import McmcPtResult
from ..problem import Problem
from ..objective import Objective, ObjectiveBase
from ..objective import Objective, ObjectiveBase, Hdf5History
import numpy as np
import logging

Expand Down Expand Up @@ -44,6 +44,7 @@ def read_hdf5_profile(f: h5py.File,


def read_hdf5_optimization(f: h5py.File,
file_name: str,
opt_id: str) -> 'OptimizerResult':
"""
Read HDF5 results per start.
Expand All @@ -52,13 +53,21 @@ def read_hdf5_optimization(f: h5py.File,
-------------
f:
The HDF5 result file
file_name:
The name of the HDF5 file, needed to create HDF5History
opt_id:
Specifies the start that is read from the HDF5 file
"""

result = OptimizerResult()

for optimization_key in result.keys():
if optimization_key == 'history':
if optimization_key in f:
result['history'] = Hdf5History(id=opt_id,
file=file_name)
result['history']._recover_options(file_name)
continue
if optimization_key in f[f'/optimization/results/{opt_id}']:
result[optimization_key] = \
f[f'/optimization/results/{opt_id}/{optimization_key}'][:]
Expand Down Expand Up @@ -154,7 +163,9 @@ def read(self) -> Result:
self.results.problem = problem_reader.read()

for opt_id in f['/optimization/results']:
result = read_hdf5_optimization(f, opt_id)
result = read_hdf5_optimization(f,
self.storage_filename,
opt_id)
self.results.optimize_result.append(result)
self.results.optimize_result.sort()
return self.results
Expand Down
3 changes: 2 additions & 1 deletion pypesto/store/save_to_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ def write(self, result: Result, overwrite=False):
for start in result.optimize_result.list:
start_id = start['id']
start_grp = get_or_create_group(results_grp, start_id)
start['history'] = None # TOOD temporary fix
for key in start.keys():
if key == 'history':
continue
if isinstance(start[key], np.ndarray):
write_float_array(start_grp, key, start[key])
elif start[key] is not None:
Expand Down
Loading

0 comments on commit 3e765d9

Please sign in to comment.