Skip to content

Commit

Permalink
Update gen_for_multiple_trials_with_multiple_models to call into gen_…
Browse files Browse the repository at this point in the history
…with_multiple_nodes (#2822)

Summary:

This should be a no-op on the logic, since we aren't actually using quickbo anywhere, but this diff makes gen_with_multiple_nodes a sub-method of gen_multiple_trials_with_multiple_models and makes the later the primary entry point into GS. We update the method in all locations it's defined.

We also ensure that pending points are handled correctly by gen_multiple_trials_with_multiple_models, and ensure backwards compatability with all live GS.

Differential Revision: D63657844
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 4, 2024
1 parent 331da3a commit 4382399
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 54 deletions.
70 changes: 62 additions & 8 deletions ax/core/generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

from abc import ABC, abstractmethod

from typing import Any

from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import AxError, UnsupportedError
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none
Expand Down Expand Up @@ -43,10 +46,10 @@ def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
data: Data | None = None,
# TODO[drfreund, danielcohennyc, mgarrard]: Update the format of the arguments
# below as we find the right one.
num_generator_runs: int = 1,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
n: int | None = None,
num_trials: int = 1,
arms_per_node: dict[str, int] | None = None,
) -> list[list[GeneratorRun]]:
"""Produce ``GeneratorRun``-s for multiple trials at once with the possibility
of joining ``GeneratorRun``-s from multiple models into one ``BatchTrial``.
Expand All @@ -60,14 +63,22 @@ def gen_for_multiple_trials_with_multiple_models(
data: Optional data to be passed to the underlying model's ``gen``, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the ``experiment``.
n: Integer representing how many trials should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from ``n``.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
n: Integer representing how many total arms should be in the generator
runs produced by this method. NOTE: Some underlying models may ignore
the `n` and produce a model-determined number of arms. In that
case this method will also output generator runs with number of
arms that can differ from `n`.
num_trials: Number of trials to generate generator runs for in this call.
If not provided, defaults to 1.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.
Returns:
A list of lists of ``GeneratorRun``-s. Each outer list item represents
Expand All @@ -79,6 +90,49 @@ def gen_for_multiple_trials_with_multiple_models(
# are currently running / being evaluated/
...

def _gen_multiple(
self,
experiment: Experiment,
num_generator_runs: int,
data: Data | None = None,
n: int = 1,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
**model_gen_kwargs: Any,
) -> list[GeneratorRun]:
"""Produce multiple generator runs at once, to be made into multiple
trials on the experiment.
NOTE: This is used to ensure that maximum parallelism and number
of trials per node are not violated when producing many generator
runs from this generation strategy in a row. Without this function,
if one generates multiple generator runs without first making any
of them into running trials, generation strategy cannot enforce that it only
produces as many generator runs as are allowed by the parallelism
limit and the limit on number of trials in current node.
Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
n: Integer representing how many arms should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from ``n``.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
model_gen_kwargs: Keyword arguments that are passed through to
``GenerationNode.gen``, which will pass them through to
``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``.
"""
...

@abstractmethod
def clone_reset(self) -> GenerationStrategyInterface:
"""Returns a clone of this generation strategy with all state reset."""
Expand Down
7 changes: 4 additions & 3 deletions ax/core/tests/test_generation_strategy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import AxError, UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment, SpecialGenerationStrategy
Expand All @@ -20,10 +21,10 @@ def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
data: Data | None = None,
# TODO[drfreund, danielcohennyc, mgarrard]: Update the format of the arguments
# below as we find the right one.
num_generator_runs: int = 1,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
n: int | None = None,
num_trials: int = 1,
arms_per_node: dict[str, int] | None = None,
) -> list[list[GeneratorRun]]:
raise NotImplementedError

Expand Down
101 changes: 71 additions & 30 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,9 @@ def gen_with_multiple_nodes(
experiment: Experiment,
data: Data | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
arms_per_node: dict[str, int] | None = None,
n: int | None = None,
fixed_features: ObservationFeatures | None = None,
arms_per_node: dict[str, int] | None = None,
) -> list[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
``BatchTrial``, and if producing a ``BatchTrial`` allows for multiple
Expand All @@ -411,15 +412,19 @@ def gen_with_multiple_nodes(
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node.
n: Integer representing how many arms should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the `n` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from `n`.
fixed_features: An optional set of ``ObservationFeatures`` that will be
passed down to the underlying models.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.
Returns:
A list of ``GeneratorRuns`` for a single trial.
Expand Down Expand Up @@ -476,6 +481,7 @@ def gen_with_multiple_nodes(
data=data,
n=arms_from_node,
pending_observations=pending_observations,
fixed_features=fixed_features,
)
)
# ensure that the points generated from each node are marked as pending
Expand All @@ -493,49 +499,82 @@ def gen_with_multiple_nodes(
def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
num_generator_runs: int,
data: Data | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
n: int | None = None,
num_trials: int = 1,
arms_per_node: dict[str, int] | None = None,
) -> list[list[GeneratorRun]]:
"""Produce GeneratorRuns for multiple trials at once with the possibility of
ensembling, or using multiple models per trial, getting multiple
GeneratorRuns per trial.
NOTE: This method is in development. Please do not use it yet.
using multiple models per trial, getting multiple GeneratorRuns per trial.
Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
experiment: ``Experiment``, for which the generation strategy is producing
a new generator run in the course of ``gen``, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
data: Optional data to be passed to the underlying model's ``gen``, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
n: Integer representing how many trials should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from ``n``.
generator run. By default, data is all data on the ``experiment``.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
n: Integer representing how many total arms should be in the generator
runs produced by this method. NOTE: Some underlying models may ignore
the `n` and produce a model-determined number of arms. In that
case this method will also output generator runs with number of
arms that can differ from `n`.
num_trials: Number of trials to generate generator runs for in this call.
If not provided, defaults to 1.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.
Returns:
A list of lists of lists generator runs. Each outer list represents
a trial being suggested and each inner list represents a generator
run for that trial.
"""
# TODO: use gen_with_multiple_nodes() and get `n` there
n = self._get_n(experiment=experiment, n=n)
grs = self._gen_multiple(
experiment=experiment,
num_generator_runs=num_generator_runs,
data=data,
n=n,
pending_observations=get_pending_observation_features_based_on_trial_status(
trial_grs = []
pending_observations = (
get_pending_observation_features_based_on_trial_status(
experiment=experiment
),
fixed_features=get_fixed_features_from_experiment(experiment=experiment),
)
or {}
if pending_observations is None
else deepcopy(pending_observations)
)
return [[gr] for gr in grs]
gr_limit = self._curr.generator_run_limit(raise_generation_errors=False)
if gr_limit == -1:
num_trials = max(num_trials, 1)
else:
num_trials = max(min(num_trials, gr_limit), 1)
for _i in range(num_trials):
trial_grs.append(
self.gen_with_multiple_nodes(
experiment=experiment,
data=data,
n=n,
pending_observations=pending_observations,
arms_per_node=arms_per_node,
fixed_features=get_fixed_features_from_experiment(
experiment=experiment
),
)
)

extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# pass in the most recently generated grs each time to avoid
# duplication
generator_runs=trial_grs[-1],
)
return trial_grs

def current_generator_run_limit(
self,
Expand Down Expand Up @@ -604,6 +643,8 @@ def _unset_non_persistent_state_fields(self) -> None:
self._model = None
for s in self._nodes:
s._model_spec_to_gen_from = None
if not self.is_node_based:
s._previous_node_name = None

@step_based_gs_only
def _validate_and_set_step_sequence(self, steps: list[GenerationStep]) -> None:
Expand Down
18 changes: 7 additions & 11 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ def test_gen_for_multiple_uses_total_concurrent_arms_for_a_default(
gs = self.sobol_GS
gs.experiment = exp
exp._properties[Keys.EXPERIMENT_TOTAL_CONCURRENT_ARMS.value] = 3
grs = gs.gen_for_multiple_trials_with_multiple_models(exp, num_generator_runs=2)
grs = gs.gen_for_multiple_trials_with_multiple_models(exp, num_trials=2)
self.assertEqual(len(grs), 2)
for gr_list in grs:
self.assertEqual(len(gr_list), 1)
Expand All @@ -939,26 +939,21 @@ def test_gen_for_multiple_uses_total_concurrent_arms_for_a_default(
def test_gen_for_multiple_trials_with_multiple_models(self) -> None:
exp = get_experiment_with_multi_objective()
sobol_MBM_gs = self.sobol_MBM_step_GS
sobol_MBM_gs.experiment = exp
with mock_patch_method_original(
mock_path=f"{ModelSpec.__module__}.ModelSpec.gen",
original_method=ModelSpec.gen,
) as model_spec_gen_mock, mock_patch_method_original(
mock_path=f"{ModelSpec.__module__}.ModelSpec.fit",
original_method=ModelSpec.fit,
) as model_spec_fit_mock:
) as model_spec_gen_mock:
# Generate first four Sobol GRs (one more to gen after that if
# first four become trials.
grs = sobol_MBM_gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp, num_generator_runs=3
experiment=exp, num_trials=3
)
self.assertEqual(len(grs), 3)
for gr in grs:
self.assertEqual(len(gr), 1)
self.assertIsInstance(gr[0], GeneratorRun)

# We should only fit once; refitting for each `gen` would be
# wasteful as there is no new data.
model_spec_fit_mock.assert_called_once()
self.assertEqual(model_spec_gen_mock.call_count, 3)
pending_in_each_gen = enumerate(
args_and_kwargs.kwargs.get("pending_observations")
Expand Down Expand Up @@ -988,7 +983,7 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None:

grs = sobol_MBM_gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp,
num_generator_runs=3,
num_trials=3,
)
self.assertEqual(len(grs), 2)
for gr in grs:
Expand Down Expand Up @@ -1040,10 +1035,11 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features(
),
]
)
gs.experiment = exp
for _ in range(3):
grs = gs.gen_for_multiple_trials_with_multiple_models(
experiment=exp,
num_generator_runs=1,
num_trials=1,
n=2,
)
exp.new_batch_trial(generator_runs=grs[0]).mark_running(
Expand Down
3 changes: 2 additions & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1882,10 +1882,11 @@ def _gen_new_trials_from_generation_strategy(
``_gen_multiple`` method of the scheduler's ``generation_strategy``, taking
into account any ``pending`` observations.
"""
self.generation_strategy.experiment = self.experiment
# TODO: pass self.trial_type to GS.gen for multi-type experiments
return self.generation_strategy.gen_for_multiple_trials_with_multiple_models(
experiment=self.experiment,
num_generator_runs=num_trials,
num_trials=num_trials,
n=n,
)

Expand Down
4 changes: 3 additions & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2443,9 +2443,11 @@ def __init__(self) -> None:
def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
num_generator_runs: int,
data: Data | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
n: int | None = None,
num_trials: int = 1,
arms_per_node: dict[str, int] | None = None,
) -> list[list[GeneratorRun]]:
return []

Expand Down

0 comments on commit 4382399

Please sign in to comment.