Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve runhistory type checks #706

Merged
merged 3 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions smac/intensification/intensification.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,14 +604,15 @@ def _get_next_inc_run(self,
Max time for a given instance/seed pair

"""
# Line 5
next_instance = self.rs.choice(available_insts)
# Line 5 - and avoid https://github.com/numpy/numpy/issues/10791
_idx = self.rs.choice(len(available_insts))
next_instance = available_insts[_idx]

# Line 6
if self.deterministic:
next_seed = 0
else:
next_seed = self.rs.randint(low=0, high=MAXINT, size=1)[0]
next_seed = int(self.rs.randint(low=0, high=MAXINT, size=1)[0])

# Line 7
self.logger.debug(
Expand Down
2 changes: 1 addition & 1 deletion smac/intensification/simple_intensifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_next_run(self,
config=challenger,
instance=self.instances[-1],
instance_specific="0",
seed=0 if self.deterministic else self.rs.randint(low=0, high=MAXINT, size=1)[0],
seed=0 if self.deterministic else int(self.rs.randint(low=0, high=MAXINT, size=1)[0]),
cutoff=self.cutoff,
capped=False,
budget=0.0,
Expand Down
2 changes: 1 addition & 1 deletion smac/intensification/successive_halving.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(self,
if self.deterministic:
seeds = [0]
else:
seeds = self.rs.randint(low=0, high=MAXINT, size=self.n_seeds)
seeds = [int(s) for s in self.rs.randint(low=0, high=MAXINT, size=self.n_seeds)]
if self.n_seeds == 1:
self.logger.warning('The target algorithm is specified to be non deterministic, '
'but number of seeds to evaluate are set to 1. '
Expand Down
33 changes: 32 additions & 1 deletion smac/runhistory/runhistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,24 @@ def add(
else:
config_id = typing.cast(int, config_id_tmp)

# Construct keys and values for the data dictionary
k = RunKey(config_id, instance_id, seed, budget)
v = RunValue(cost, time, status, starttime, endtime, additional_info)
# Construct keys and values for the data dictionary
for key, value in (
('config', config.get_dictionary()),
('config_id', config_id),
('instance_id', instance_id),
('seed', seed),
('budget', budget),
('cost', cost),
('time', time),
('status', status),
('starttime', starttime),
('endtime', endtime),
('additional_info', additional_info),
('origin', config.origin),
):
self._check_json_serializable(key, value, EnumEncoder, k, v)

# Each runkey is supposed to be used only once. Repeated tries to add
# the same runkey will be ignored silently if not capped.
Expand All @@ -268,6 +283,22 @@ def add(
# overwrite if censored with a larger cutoff
self._add(k, v, status, origin)

def _check_json_serializable(
    self,
    key: str,
    obj: typing.Any,
    encoder: typing.Type[json.JSONEncoder],
    runkey: RunKey,
    runvalue: RunValue
) -> None:
    """Verify that a single value can be JSON encoded with ``encoder``.

    Parameters
    ----------
    key : str
        Name of the field being checked; only used in the error message.
    obj : typing.Any
        The value that is passed to ``json.dumps``.
    encoder : typing.Type[json.JSONEncoder]
        Encoder class handed to ``json.dumps`` through its ``cls`` argument.
    runkey : RunKey
        Only reported in the error message.
    runvalue : RunValue
        Only reported in the error message.

    Raises
    ------
    ValueError
        If encoding ``obj`` raises any exception. The original exception is
        chained as the cause of the ``ValueError``.
    """
    try:
        json.dumps(obj, cls=encoder)
    except Exception as encoding_error:
        # Whatever the encoder raised is re-reported as a ValueError that
        # names the offending field and the run it belongs to.
        message_args = (key, str(obj), type(obj), runkey, runvalue)
        raise ValueError(
            "Cannot add %s: %s of type %s to runhistory because it raises an error during JSON encoding, "
            "please see the error above.\nRunKey: %s\nRunValue %s" % message_args
        ) from encoding_error

def _add(self, k: RunKey, v: RunValue, status: StatusType,
origin: DataOrigin) -> None:
"""Actual function to add new entry to data structures
Expand Down
19 changes: 1 addition & 18 deletions smac/tae/execute_func.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import math
import time
import json
import traceback
import typing

Expand Down Expand Up @@ -215,28 +214,12 @@ def run(self, config: Configuration,
cost = result
except Exception as e:
self.logger.exception(e)
cost, result = self.cost_for_crash, self.cost_for_crash
status = StatusType.CRASHED
cost = self.cost_for_crash
additional_run_info = {}

runtime = time.time() - start_time

# check serializability of results
try:
json.dumps(cost)
except TypeError as e:
self.logger.exception(e)
raise TypeError("Target Algorithm returned 'cost' {} (type {}) but it is not serializable. "
"Please ensure all objects returned are JSON serializable.".format(result, type(result))) \
from e
try:
json.dumps(additional_run_info)
except TypeError as e:
self.logger.exception(e)
raise TypeError("Target Algorithm returned 'additional_run_info' ({}) with some non-serializable items. "
"Please ensure all objects returned are JSON serializable.".format(additional_run_info)) \
from e

if status == StatusType.SUCCESS and not isinstance(result, (int, float)):
status = StatusType.CRASHED
cost = self.cost_for_crash
Expand Down
25 changes: 25 additions & 0 deletions test/test_runhistory/test_runhistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from ConfigSpace import Configuration, ConfigurationSpace
from ConfigSpace.hyperparameters import UniformIntegerHyperparameter
import numpy as np
import pynisher

from smac.tae import StatusType
from smac.runhistory.runhistory import RunHistory
Expand Down Expand Up @@ -251,6 +253,29 @@ def test_json_origin(self):

os.remove(path)

def test_add_json_serializable(self):
    """Check that ``RunHistory.add`` rejects values that fail JSON encoding."""
    runhistory = RunHistory()
    configuration = get_config_space().sample_configuration()

    # Plain Python values are accepted, with ``None`` as well as with an
    # empty dict as additional info.
    for additional_info in (None, {}):
        runhistory.add(configuration, 0.0, 0.0, StatusType.SUCCESS, None, None, 0.0, 0.0, 0.0, additional_info)

    # A numpy scalar passed as the cost is reported under the name 'cost'.
    cost_pattern = (
        r"Cannot add cost: 0\.0 of type <class 'numpy\.float32'> to runhistory because "
        r"it raises an error during JSON encoding"
    )
    with self.assertRaisesRegex(ValueError, cost_pattern):
        runhistory.add(configuration, np.float32(0.0), 0.0, StatusType.SUCCESS, None, None, 0.0, 0.0, 0.0, None)

    # A class object inside the additional info is reported under the name
    # 'additional_info'.
    info_pattern = (
        r"Cannot add additional_info: \{'error': <class 'pynisher\.limit_function_call\.AnythingException'>\} "
        r"of type <class 'dict'> to runhistory because it raises an error during JSON encoding"
    )
    with self.assertRaisesRegex(ValueError, info_pattern):
        runhistory.add(configuration, 0.0, 0.0, StatusType.SUCCESS, None, None, 0.0, 0.0, 0.0,
                       {'error': pynisher.AnythingException})


# Run the test cases when this module is executed as a script.
if __name__ == "__main__":
unittest.main()
17 changes: 0 additions & 17 deletions test/test_tae/test_exec_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,3 @@ def target(x):
return x**2
taf = ExecuteTAFuncDict(ta=target, stats=self.stats)
self.assertRaises(ValueError, taf.run, config=2, cutoff=65536)

def test_non_serializable(self):
    """Check that ``run`` raises ``TypeError`` for non-serializable target output."""
    expected_msg = "Please ensure all objects returned are JSON serializable."

    def cost_not_serializable(x):
        # The returned cost is a numpy integer.
        return np.int32(x)

    def info_not_serializable(x):
        # The additional info holds a numpy integer.
        return x, {'x': np.int32(x)}

    for target in (cost_not_serializable, info_not_serializable):
        taf = ExecuteTAFuncDict(ta=target, stats=self.stats)
        with self.assertRaisesRegex(TypeError, expected_msg):
            taf.run(config=2)