Skip to content

Commit

Permalink
Merge pull request #152 from optimas-org/bug/fix_data_type
Browse files Browse the repository at this point in the history
Fix bug with parameter type in Ax generators
  • Loading branch information
RemiLehe authored Dec 11, 2023
2 parents 914139c + d6bafef commit 058acea
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimas/evaluators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def get_sim_specs(
# May be a 1D array.
"in": [var.name for var in varying_parameters],
"out": (
[(obj.name, float) for obj in objectives]
[(obj.name, obj.dtype) for obj in objectives]
# f is the single float output that LibEnsemble minimizes.
+ [(par.name, par.dtype) for par in analyzed_parameters]
# input parameters
+ [(var.name, float) for var in varying_parameters]
+ [(var.name, var.dtype) for var in varying_parameters]
),
"user": {
"n_procs": self._n_procs,
Expand Down
1 change: 1 addition & 0 deletions optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _create_ax_parameters(self) -> List:
"bounds": [var.lower_bound, var.upper_bound],
"is_fidelity": var.is_fidelity,
"target_value": var.fidelity_target_value,
"value_type": var.dtype.__name__,
}
)
if var.is_fixed:
Expand Down
44 changes: 44 additions & 0 deletions tests/test_ax_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,46 @@ def test_ax_single_fidelity():
np.save("./tests_output/ax_sf_history", exploration._libe_history.H)


def test_ax_single_fidelity_int():
"""
Test that an exploration with a single-fidelity generator runs
correctly with an integer parameter.
"""

var1 = VaryingParameter("x0", -50.0, 5.0, dtype=int)
var2 = VaryingParameter("x1", -5.0, 15.0)
obj = Objective("f", minimize=False)

gen = AxSingleFidelityGenerator(
varying_parameters=[var1, var2], objectives=[obj]
)
ev = FunctionEvaluator(function=eval_func_sf)
exploration = Exploration(
generator=gen,
evaluator=ev,
max_evals=10,
sim_workers=2,
exploration_dir_path="./tests_output/test_ax_single_fidelity_int",
)

# Get reference to original AxClient.
ax_client = gen._ax_client
assert ax_client.experiment.search_space.parameters["x0"].python_type == int

# Run exploration.
exploration.run()

# Check that the generator has been updated.
assert gen.n_completed_trials == exploration.history.shape[0]

# Check that the original ax client has been updated.
n_ax_trials = ax_client.get_trials_data_frame().shape[0]
assert n_ax_trials == exploration.history.shape[0]

# Check correct variable type.
assert exploration.history["x0"].to_numpy().dtype == int


def test_ax_single_fidelity_moo():
"""
Test that an exploration with a multi-objective single-fidelity generator
Expand Down Expand Up @@ -246,6 +286,9 @@ def test_ax_single_fidelity_updated_params():
# Update range of x0 and run 10 evals.
var1.update_range(-20.0, 0.0)
gen.update_parameter(var1)
# Make sure we have an evaluation in the new range (it currently fails
# otherwise).
exploration.evaluate_trials([{"x0": -10.0, "x1": 10.0}])
exploration.run(n_evals=10)
assert all(exploration.history["x0"][-10:] >= -20)
assert all(exploration.history["x0"][-10:] <= 0.0)
Expand Down Expand Up @@ -567,6 +610,7 @@ def test_ax_service_init():

if __name__ == "__main__":
test_ax_single_fidelity()
test_ax_single_fidelity_int()
test_ax_single_fidelity_moo()
test_ax_single_fidelity_fb()
test_ax_single_fidelity_moo_fb()
Expand Down

0 comments on commit 058acea

Please sign in to comment.