Skip to content

Commit

Permalink
Support fixed features in Service API (facebook#2372)
Browse files Browse the repository at this point in the history
Summary:

Add the possibility of specifying some `FixedFeatures` as `fixed_features` in `AxClient.get_next_trial` and `AxClient.get_next_trials` which is currently only possible with the developer API.

Reviewed By: saitcakmak

Differential Revision: D56068035
  • Loading branch information
Cesar-Cardoso authored and facebook-github-bot committed Apr 17, 2024
1 parent 36ff37a commit 7ba2431
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
47 changes: 29 additions & 18 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,10 @@ def set_search_space(
wrap_error_message_in=CHOLESKY_ERROR_ANNOTATION,
)
def get_next_trial(
self, ttl_seconds: Optional[int] = None, force: bool = False
self,
ttl_seconds: Optional[int] = None,
force: bool = False,
fixed_features: Optional[FixedFeatures] = None,
) -> Tuple[TParameterization, int]:
"""
Generate trial with the next set of parameters to try in the iteration process.
Expand All @@ -508,6 +511,9 @@ def get_next_trial(
failed properly.
force: If set to True, this function will bypass the global stopping
strategy's decision and generate a new trial anyway.
fixed_features: A FixedFeatures object containing any
features that should be fixed at specified values during
generation.
Returns:
Tuple of trial parameterization, trial index
Expand All @@ -530,7 +536,10 @@ def get_next_trial(

try:
trial = self.experiment.new_trial(
generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
generator_run=self._gen_new_generator_run(
fixed_features=fixed_features
),
ttl_seconds=ttl_seconds,
)
except MaxParallelismReachedException as e:
if self._early_stopping_strategy is not None:
Expand Down Expand Up @@ -580,7 +589,10 @@ def get_current_trial_generation_limit(self) -> Tuple[int, bool]:
return self.generation_strategy.current_generator_run_limit()

def get_next_trials(
self, max_trials: int, ttl_seconds: Optional[int] = None
self,
max_trials: int,
ttl_seconds: Optional[int] = None,
fixed_features: Optional[FixedFeatures] = None,
) -> Tuple[Dict[int, TParameterization], bool]:
"""Generate as many trials as currently possible.
Expand All @@ -597,6 +609,9 @@ def get_next_trials(
ttl_seconds: If specified, will consider the trial failed after this
many seconds. Used to detect dead trials that were not marked
failed properly.
fixed_features: A FixedFeatures object containing any
features that should be fixed at specified values during
generation.
Returns: two-item tuple of:
- mapping from trial indices to parameterizations in those trials,
Expand All @@ -616,7 +631,9 @@ def get_next_trials(
trials_dict = {}
for _ in range(max_trials):
try:
params, trial_index = self.get_next_trial(ttl_seconds=ttl_seconds)
params, trial_index = self.get_next_trial(
ttl_seconds=ttl_seconds, fixed_features=fixed_features
)
trials_dict[trial_index] = params
except OptimizationComplete as err:
logger.info(
Expand Down Expand Up @@ -1744,20 +1761,16 @@ def _save_generation_strategy_to_db_if_possible(
suppress_all_errors=suppress_all_errors,
)

def _get_last_completed_trial_index(self) -> int:
# infer last completed trial as the trial_index to use
# TODO: use Experiment.completed_trials once D46484953 lands.
completed_indices = [
t.index for t in self.experiment.trials_by_status[TrialStatus.COMPLETED]
]
completed_indices.append(0) # handle case of no completed trials
return max(completed_indices)

def _gen_new_generator_run(self, n: int = 1) -> GeneratorRun:
def _gen_new_generator_run(
self, n: int = 1, fixed_features: Optional[FixedFeatures] = None
) -> GeneratorRun:
"""Generate new generator run for this experiment.
Args:
n: Number of arms to generate.
fixed_features: A FixedFeatures object containing any
features that should be fixed at specified values during
generation.
"""
# If random seed is not set for this optimization, context manager does
# nothing; otherwise, it sets the random seed for torch, but only for the
Expand All @@ -1767,10 +1780,8 @@ def _gen_new_generator_run(self, n: int = 1) -> GeneratorRun:
# stochasticity.

fixed_feats = InstantiationBase.make_fixed_observation_features(
fixed_features=FixedFeatures(
parameters={}, trial_index=self._get_last_completed_trial_index()
)
)
fixed_features=fixed_features
) if fixed_features else None
with manual_seed(seed=self._random_seed):
return not_none(self.generation_strategy).gen(
experiment=self.experiment,
Expand Down
19 changes: 14 additions & 5 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
observed_pareto,
predicted_pareto,
)
from ax.service.utils.instantiation import FixedFeatures
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.encoder import Encoder
Expand Down Expand Up @@ -2847,11 +2848,19 @@ def test_gen_fixed_features(self) -> None:
with mock.patch.object(
GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen
) as mock_gen:
params, idx = ax_client.get_next_trial()
call_kwargs = mock_gen.call_args_list[0][1]
ff = call_kwargs["fixed_features"]
self.assertEqual(ff.parameters, {})
self.assertEqual(ff.trial_index, 0)
with self.subTest("fixed_features is None"):
params, idx = ax_client.get_next_trial()
call_kwargs = mock_gen.call_args_list[0][1]
ff = call_kwargs["fixed_features"]
self.assertEqual(ff.parameters, {})
self.assertEqual(ff.trial_index, 0)
with self.subTest("fixed_features is set"):
fixed_features = FixedFeatures(parameters={"x": 0.0, "y": 5.0})
params, idx = ax_client.get_next_trial(fixed_features=fixed_features)
call_kwargs = mock_gen.call_args_list[1][1]
ff = call_kwargs["fixed_features"]
self.assertEqual(ff.parameters, fixed_features.parameters)
self.assertEqual(ff.trial_index, 0)

def test_get_optimization_trace_discard_infeasible_trials(self) -> None:
ax_client = AxClient()
Expand Down

0 comments on commit 7ba2431

Please sign in to comment.