Skip to content

Commit

Permalink
Fix indexing error for 0-dimensional HDF5 datasets
Browse files Browse the repository at this point in the history
* Refactor ProfileResultHDF5Writer to be more readable
* Raise more informative error if writing to HDF5 fails
* Don't index 0-dimensional datasets (Fixes ICB-DCM#1205)
  • Loading branch information
dweindl committed Nov 21, 2023
1 parent bcdbd55 commit 5f4bbed
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pypesto/result/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class ProfileResult:
"""

def __init__(self):
self.list = []
self.list: list[list[ProfilerResult]] = []

def append_empty_profile_list(self) -> int:
"""Append an empty profile list to the list of profile lists.
Expand Down
10 changes: 7 additions & 3 deletions pypesto/store/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def write_string_array(f: h5py.Group, path: str, strings: Collection) -> None:
"""
dt = h5py.special_dtype(vlen=str)
dset = f.create_dataset(path, (len(strings),), dtype=dt)
dset[:] = [s.encode('utf8') for s in strings]
if len(strings):
dset[:] = [s.encode('utf8') for s in strings]


def write_float_array(
Expand All @@ -69,7 +70,9 @@ def write_float_array(
dset = f.create_dataset(path, (np.shape(values)), dtype=dtype)
else:
dset = f[path]
dset[:] = values

if len(values):
dset[:] = values


def write_int_array(
Expand All @@ -90,4 +93,5 @@ def write_int_array(
datatype
"""
dset = f.create_dataset(path, (len(values),), dtype=dtype)
dset[:] = values
if len(values):
dset[:] = values
39 changes: 27 additions & 12 deletions pypesto/store/save_to_hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import h5py
import numpy as np

from ..result import Result, SampleResult
from ..result import ProfilerResult, Result, SampleResult
from .hdf5 import write_array, write_float_array

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -234,20 +234,35 @@ def write(self, result: Result, overwrite: bool = False):
profile_grp = profiling_grp.require_group(str(profile_id))
for parameter_id, parameter_profile in enumerate(profile):
result_grp = profile_grp.require_group(str(parameter_id))
self._write_profiler_result(parameter_profile, result_grp)

if parameter_profile is None:
result_grp.attrs['IsNone'] = True
continue
result_grp.attrs['IsNone'] = False
for key in parameter_profile.keys():
if isinstance(parameter_profile[key], np.ndarray):
write_float_array(
result_grp, key, parameter_profile[key]
)
elif parameter_profile[key] is not None:
result_grp.attrs[key] = parameter_profile[key]
f.flush()

@staticmethod
def _write_profiler_result(
parameter_profile: Union[ProfilerResult, None], result_grp: h5py.Group
) -> None:
"""Write a single ProfilerResult to hdf5.
Writes a single profile for a single parameter to the provided HDF5 group.
"""
if parameter_profile is None:
result_grp.attrs['IsNone'] = True
return

result_grp.attrs['IsNone'] = False

for key, value in parameter_profile.items():
try:
if isinstance(value, np.ndarray):
write_float_array(result_grp, key, value)
elif value is not None:
result_grp.attrs[key] = value
except Exception as e:
raise ValueError(
f"Error writing {key} ({value}) to {result_grp}."
) from e


def write_result(
result: Result,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ all =
%(amici)s
%(petab)s
%(all_optimizers)s
%(mpi)s
%(mpi)s
%(pymc)s
%(aesara)s
%(jax)s
Expand Down

0 comments on commit 5f4bbed

Please sign in to comment.