diff --git a/ax/analysis/plotly/arm_effects/__init__.py b/ax/analysis/plotly/arm_effects/__init__.py new file mode 100644 index 00000000000..8e84578e63d --- /dev/null +++ b/ax/analysis/plotly/arm_effects/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.analysis.plotly.arm_effects.insample_effects import InSampleEffectsPlot +from ax.analysis.plotly.arm_effects.predicted_effects import PredictedEffectsPlot + +__all__ = ["PredictedEffectsPlot", "InSampleEffectsPlot"] diff --git a/ax/analysis/plotly/arm_effects/insample_effects.py b/ax/analysis/plotly/arm_effects/insample_effects.py new file mode 100644 index 00000000000..18350ebb734 --- /dev/null +++ b/ax/analysis/plotly/arm_effects/insample_effects.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import chain +from logging import Logger + +import pandas as pd +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.arm_effects.utils import ( + get_predictions_by_arm, + prepare_arm_effects_plot, +) + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.utils import is_predictive +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.core.generator_run import GeneratorRun +from ax.core.outcome_constraint import OutcomeConstraint +from ax.exceptions.core import DataRequiredError, UserInputError +from ax.modelbridge.base import ModelBridge +from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.modelbridge.registry import Models +from ax.modelbridge.transforms.derelativize import Derelativize +from ax.utils.common.logger import get_logger +from plotly import io as pio +from pyre_extensions import none_throws + +logger: Logger = get_logger(__name__) + + +class InSampleEffectsPlot(PlotlyAnalysis): + """ + Plotly Insample Effecs plot for a single metric on a single trial, with one point + per unique arm across all trials. The plot may either use modeled effects, or + raw / observed data. + + This plot is useful for understanding how arms compare to eachother for a given + metric. + + TODO: Allow trial index to be optional so we can plot all trials for non batch + experiments. + + The DataFrame computed will contain one row per arm and the following columns: + - source: In-sample or model key that geneerated the candidate + - arm_name: The name of the arm + - mean: The observed or predicted mean of the metric specified + - sem: The observed or predicted sem of the metric specified + - error_margin: The 95% CI of the metric specified for the arm + - size_column: The size of the circle in the plot, which represents + the probability that the arm is feasible (does not violate any + constraints). + - parameters: A string representation of the parameters for the arm + to be viewed in the tooltip. + - constraints_violated: A string representation of the probability + each constraint is violated for the arm, to be viewed in the tooltip. + """ + + def __init__( + self, metric_name: str, trial_index: int, use_modeled_effects: bool + ) -> None: + """ + Args: + metric_name: The name of the metric to plot. 
+ trial_index: The of the trial to plot arms for. + use_modeled_effects: Whether to use modeled effects or show + raw effects. + """ + + self.metric_name = metric_name + self.trial_index = trial_index + self.use_modeled_effects = use_modeled_effects + + def compute( + self, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, + ) -> PlotlyAnalysisCard: + if experiment is None: + raise UserInputError("InSampleEffectsPlot requires an Experiment.") + + model = _get_model( + experiment=experiment, + generation_strategy=generation_strategy, + use_modeled_effects=self.use_modeled_effects, + trial_index=self.trial_index, + metric_name=self.metric_name, + ) + + outcome_constraints = ( + [] + if experiment.optimization_config is None + else Derelativize() + .transform_optimization_config( + # TODO[T203521207]: move cloning into transform_optimization_config + optimization_config=none_throws(experiment.optimization_config).clone(), + modelbridge=model, + ) + .outcome_constraints + ) + df = _prepare_data( + experiment=experiment, + model=model, + outcome_constraints=outcome_constraints, + metric_name=self.metric_name, + trial_index=self.trial_index, + use_modeled_effects=self.use_modeled_effects, + ) + fig = prepare_arm_effects_plot( + df=df, metric_name=self.metric_name, outcome_constraints=outcome_constraints + ) + + if ( + experiment.optimization_config is None + or self.metric_name not in experiment.optimization_config.metrics + ): + level = AnalysisCardLevel.LOW + elif self.metric_name in experiment.optimization_config.objective.metric_names: + level = AnalysisCardLevel.HIGH + else: + level = AnalysisCardLevel.MID + + plot_type = "Modeled" if self.use_modeled_effects else "Raw" + subtitle = ( + "View a trial and its arms' " + f"{'predicted' if self.use_modeled_effects else 'observed'} " + "metric values" + ) + return PlotlyAnalysisCard( + name=f"{plot_type}EffectsPlot", + title=( + f"{plot_type} Effects for {self.metric_name} " + f"on trial {self.trial_index}" + ), + subtitle=subtitle, + level=level, + df=df, + blob=pio.to_json(fig), + attributes={"trial_index": self.trial_index}, + ) + + +def _get_max_observed_trial_index(model: ModelBridge) -> int | None: + """Returns the max observed trial index to appease multitask models for prediction + by giving fixed features. This is not necessarily accurate and should eventually + come from the generation strategy. + """ + observed_trial_indices = [ + obs.features.trial_index + for obs in model.get_training_data() + if obs.features.trial_index is not None + ] + if len(observed_trial_indices) == 0: + return None + return max(observed_trial_indices) + + +def _get_model( + experiment: Experiment, + generation_strategy: GenerationStrategyInterface | None, + use_modeled_effects: bool, + trial_index: int, + metric_name: str, +) -> ModelBridge: + """Get a model for predictions. + + Args: + experiment: Used to get the data for the model. + generation_strategy: Used to get the model if we want to use modeled effects + and the current model is predictive. + use_modeled_effects: Whether to use modeled effects. + trial_index: The trial index to get data for in training the model. + metric_name: The name of the metric we're plotting, which we validate has + data on the trial. + + Returns: + If use_modeled_effects is False, returns a Thompson model, which just predicts + from the data. + If use_modeled_effects is True, returns the current model on the generation + strategy if it is predictive. 
Otherwise, returns an empirical Bayes model. + """ + trial_data = experiment.lookup_data(trial_indices=[trial_index]) + if trial_data.filter(metric_names=[metric_name]).df.empty: + raise DataRequiredError( + f"Cannot plot effects for '{metric_name}' on trial {trial_index} " + "because it has no data. Either the data is not available yet, " + "or we encountered an error fetching it." + ) + if use_modeled_effects: + model = None + if isinstance(generation_strategy, GenerationStrategy): + if generation_strategy.model is None: + generation_strategy._fit_current_model(data=experiment.lookup_data()) + + model = none_throws(generation_strategy.model) + + if model is None or not is_predictive(model=model): + logger.info("Using empirical Bayes for predictions.") + return Models.EMPIRICAL_BAYES_THOMPSON( + experiment=experiment, data=trial_data + ) + + return model + else: + # This model just predicts raw data + return Models.THOMPSON( + data=trial_data, + search_space=experiment.search_space, + experiment=experiment, + ) + + +def _prepare_data( + experiment: Experiment, + model: ModelBridge, + outcome_constraints: list[OutcomeConstraint], + metric_name: str, + trial_index: int, + use_modeled_effects: bool, +) -> pd.DataFrame: + """Prepare data for plotting. Data should include columns for: + - source: In-sample or model key that geneerated the candidate + - arm_name: Name of the arm + - mean: Predicted metric value + - error_margin: 1.96 * predicted sem for plotting 95% CI + - **PARAMETER_NAME: The value of each parameter for the arm. Will be used + for the tooltip. + There will be one row for each arm in the model's training data and one for + each arm in the generator runs of the candidate trial. If an arm is in both + the training data and the candidate trial, it will only appear once for the + candidate trial. + + Args: + experiment: Experiment to plot + model: ModelBridge being used for prediction + outcome_constraints: Derelatives outcome constraints used for + assessing feasibility + metric_name: Name of metric to plot + trial_index: Optional trial index to plot. If not specified, will + plot the most recent non-abandoned trial with all observations. + """ + try: + trial = experiment.trials[trial_index] + except KeyError: + raise UserInputError( + f"Cannot plot effects for {trial_index} because " + f"it's missing from {experiment}." + ) + + status_quo_prediction = ( + [] + if experiment.status_quo is None + else [ + get_predictions_by_arm( + model=model, + metric_name=metric_name, + outcome_constraints=outcome_constraints, + gr=GeneratorRun( + arms=[experiment.status_quo], + model_key="Status Quo", + ), + ) + ] + ) + trial_predictions = [ + get_predictions_by_arm( + model=model, + metric_name=metric_name, + outcome_constraints=outcome_constraints, + gr=gr, + ) + for gr in trial.generator_runs + ] + + df = pd.DataFrame.from_records( + list( + chain( + *[ + *trial_predictions, + *status_quo_prediction, + ] + ) + ) + ) + df.drop_duplicates(subset="arm_name", keep="last", inplace=True) + return df diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/arm_effects/predicted_effects.py similarity index 52% rename from ax/analysis/plotly/predicted_effects.py rename to ax/analysis/plotly/arm_effects/predicted_effects.py index 10ef4a3f7bb..ac07edc6c83 100644 --- a/ax/analysis/plotly/predicted_effects.py +++ b/ax/analysis/plotly/arm_effects/predicted_effects.py @@ -4,33 +4,54 @@ # LICENSE file in the root directory of this source tree. 
from itertools import chain -from typing import Any import pandas as pd from ax.analysis.analysis import AnalysisCardLevel -from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard -from ax.analysis.plotly.utils import ( - format_constraint_violated_probabilities, - get_constraint_violated_probabilities, +from ax.analysis.plotly.arm_effects.utils import ( + get_predictions_by_arm, + prepare_arm_effects_plot, ) + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.utils import is_predictive from ax.core import OutcomeConstraint from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface -from ax.core.generator_run import GeneratorRun -from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError from ax.modelbridge.base import ModelBridge from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.prediction_utils import predict_at_point from ax.modelbridge.transforms.derelativize import Derelativize from ax.utils.common.typeutils import checked_cast -from plotly import express as px, graph_objects as go, io as pio +from plotly import io as pio from pyre_extensions import none_throws class PredictedEffectsPlot(PlotlyAnalysis): + """ + Plotly Predicted Effecs plot for a single metric, with one point per unique arm + across all trials. It plots all observed points, as well as predictions for the + most recently generated trial. + + This plot is useful for understanding how arms in a candidate trial can be expected + to perform. + + The DataFrame computed will contain one row per arm and the following columns: + - source: In-sample or model key that geneerated the candidate + - arm_name: The name of the arm + - mean: The observed or predicted mean of the metric specified + - sem: The observed or predicted sem of the metric specified + - error_margin: The 95% CI of the metric specified for the arm + - size_column: The size of the circle in the plot, which represents + the probability that the arm is feasible (does not violate any + constraints). + - parameters: A string representation of the parameters for the arm + to be viewed in the tooltip. + - constraints_violated: A string representation of the probability + each constraint is violated for the arm, to be viewed in the tooltip. + """ + def __init__(self, metric_name: str) -> None: """ Args: @@ -38,7 +59,6 @@ def __init__(self, metric_name: str) -> None: will be used. Note that the metric cannot be inferred for multi-objective or scalarized-objective experiments. """ - self.metric_name = metric_name def compute( @@ -74,12 +94,20 @@ def compute( generation_strategy._fit_current_model(data=experiment.lookup_data()) model = none_throws(generation_strategy.model) + if not is_predictive(model=model): + raise UserInputError( + "PredictedEffectsPlot requires a GenerationStrategy which is " + "in a state where the current model supports prediction. The current " + f"model is {model._model_key} and does not support prediction." 
+ ) + outcome_constraints = ( [] if experiment.optimization_config is None else Derelativize() .transform_optimization_config( - optimization_config=none_throws(experiment.optimization_config), + # TODO[T203521207]: move cloning into transform_optimization_config + optimization_config=none_throws(experiment.optimization_config).clone(), modelbridge=model, ) .outcome_constraints @@ -90,7 +118,9 @@ def compute( candidate_trial=candidate_trial, outcome_constraints=outcome_constraints, ) - fig = _prepare_plot(df=df, metric_name=self.metric_name) + fig = prepare_arm_effects_plot( + df=df, metric_name=self.metric_name, outcome_constraints=outcome_constraints + ) if ( experiment.optimization_config is None @@ -113,102 +143,6 @@ def compute( ) -def _get_predictions( - model: ModelBridge, - metric_name: str, - outcome_constraints: list[OutcomeConstraint], - gr: GeneratorRun | None = None, -) -> list[dict[str, Any]]: - trial_index = ( - _get_max_observed_trial_index(model) - if model.status_quo is None - else model.status_quo.features.trial_index - ) - if gr is None: - observations = model.get_training_data() - features = [o.features for o in observations] - arm_names = [o.arm_name for o in observations] - for feature in features: - feature.trial_index = trial_index - else: - features = [ - ObservationFeatures(parameters=arm.parameters, trial_index=trial_index) - for arm in gr.arms - ] - arm_names = [a.name for a in gr.arms] - try: - predictions = [ - predict_at_point( - model=model, - obsf=obsf, - metric_names={metric_name}.union( - {constraint.metric.name for constraint in outcome_constraints} - ), - ) - for obsf in features - ] - except NotImplementedError: - raise UserInputError( - "PredictedEffectsPlot requires a GenerationStrategy which is " - "in a state where the current model supports prediction. The current " - f"model is {model._model_key} and does not support prediction." - ) - constraints_violated_by_constraint = get_constraint_violated_probabilities( - predictions=predictions, - outcome_constraints=outcome_constraints, - ) - probabilities_not_feasible = constraints_violated_by_constraint.pop( - "any_constraint_violated" - ) - constraints_violated = [ - { - c: constraints_violated_by_constraint[c][i] - for c in constraints_violated_by_constraint - } - for i in range(len(features)) - ] - - for i in range(len(features)): - if ( - model.status_quo is not None - and features[i].parameters - == none_throws(model.status_quo).features.parameters - ): - probabilities_not_feasible[i] = 0 - constraints_violated[i] = {} - return [ - { - "source": "In-sample" if gr is None else gr._model_key, - "arm_name": arm_names[i], - "mean": predictions[i][0][metric_name], - "sem": predictions[i][1][metric_name], - "error_margin": 1.96 * predictions[i][1][metric_name], - "constraints_violated": format_constraint_violated_probabilities( - constraints_violated[i] - ), - "size_column": 100 - probabilities_not_feasible[i] * 100, - "parameters": "
" - + "
".join([f"{k}: {v}" for k, v in features[i].parameters.items()]), - } - for i in range(len(features)) - ] - - -def _get_max_observed_trial_index(model: ModelBridge) -> int | None: - """Returns the max observed trial index to appease multitask models for prediction - by giving fixed features. This is not necessarily accurate and should eventually - come from the generation strategy. - """ - observed_trial_indices = [ - obs.features.trial_index - for obs in model.get_training_data() - if obs.features.trial_index is not None - ] - if len(observed_trial_indices) == 0: - return None - return max(observed_trial_indices) - - def _prepare_data( model: ModelBridge, metric_name: str, @@ -236,7 +170,7 @@ def _prepare_data( list( chain( *[ - _get_predictions( + get_predictions_by_arm( model=model, metric_name=metric_name, outcome_constraints=outcome_constraints, @@ -245,7 +179,7 @@ def _prepare_data( [] if candidate_trial is None else [ - _get_predictions( + get_predictions_by_arm( model=model, metric_name=metric_name, outcome_constraints=outcome_constraints, @@ -260,41 +194,3 @@ def _prepare_data( ) df.drop_duplicates(subset="arm_name", keep="last", inplace=True) return df - - -def _get_parameter_columns(df: pd.DataFrame) -> dict[str, bool]: - """Get the names of the columns that represent parameters in df.""" - return { - col: (col not in ["source", "error_margin", "size_column"]) - for col in df.columns - } - - -def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure: - """Prepare a plotly figure for the predicted effects based on the data in df.""" - fig = px.scatter( - df, - x="arm_name", - y="mean", - error_y="error_margin", - color="source", - hover_data=_get_parameter_columns(df), - size="size_column", - size_max=10, - ) - if "status_quo" in df["arm_name"].values: - fig.add_hline( - y=df[df["arm_name"] == "status_quo"]["mean"].iloc[0], - line_width=1, - line_color="red", - ) - fig.update_layout( - xaxis={ - "tickangle": 45, - }, - ) - for trace in fig.data: - if trace.marker.symbol == "x": - trace.marker.size = 11 # Larger size for 'x' - - return fig diff --git a/ax/analysis/plotly/arm_effects/utils.py b/ax/analysis/plotly/arm_effects/utils.py new file mode 100644 index 00000000000..e0d68a5d494 --- /dev/null +++ b/ax/analysis/plotly/arm_effects/utils.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import pandas as pd +from ax.analysis.plotly.utils import ( + format_constraint_violated_probabilities, + get_constraint_violated_probabilities, +) +from ax.core.generator_run import GeneratorRun +from ax.core.observation import ObservationFeatures +from ax.core.outcome_constraint import OutcomeConstraint +from ax.core.types import TParameterization +from ax.exceptions.core import UserInputError +from ax.modelbridge.base import ModelBridge +from ax.modelbridge.prediction_utils import predict_at_point +from plotly import express as px, graph_objects as go +from pyre_extensions import none_throws + + +def format_parameters_for_effects_by_arm_plot( + parameters: TParameterization, max_num_params: int = 5 +) -> str: + """Format the parameters for tooltips in the predicted or insample + effects plot.""" + parameter_items = list(parameters.items())[:max_num_params] + string = "
" + "
".join([f"{k}: {v}" for k, v in parameter_items]) + if len(parameter_items) < len(parameters): + string += "
..." + return string + + +def prepare_arm_effects_plot( + df: pd.DataFrame, metric_name: str, outcome_constraints: list[OutcomeConstraint] +) -> go.Figure: + """Prepare a plotly figure for the predicted effects based on the data in df. + + Args: + metric_name: The name of the metric to plot. + outcome_constraints: The outcome constraints for the experiment used to + determine if the metric is a constraint, and if so, what the bound is + so the bound can be rendered in the plot. + df: A dataframe of data to plot with the following columns: + - source: In-sample or model key that geneerated the candidate + - arm_name: The name of the arm + - mean: The observed or predicted mean of the metric specified + - sem: The observed or predicted sem of the metric specified + - error_margin: The 95% CI of the metric specified for the arm + - size_column: The size of the circle in the plot, which represents + the probability that the arm is feasible (does not violate any + constraints). + - parameters: A string representation of the parameters for the arm + to be viewed in the tooltip. + - constraints_violated: A string representation of the probability + each constraint is violated for the arm, to be viewed in the tooltip. + """ + fig = px.scatter( + df, + x="arm_name", + y="mean", + error_y="error_margin", + color="source", + # TODO: can we format this by callable or string template? + hover_data=_get_parameter_columns(df), + size="size_column", + size_max=10, + ) + _add_style_to_effects_by_arm_plot( + fig=fig, df=df, metric_name=metric_name, outcome_constraints=outcome_constraints + ) + return fig + + +def _get_parameter_columns(df: pd.DataFrame) -> dict[str, bool]: + """Get the names of the columns that represent parameters in df.""" + return { + col: (col not in ["source", "error_margin", "size_column"]) + for col in df.columns + } + + +def _add_style_to_effects_by_arm_plot( + fig: go.Figure, + df: pd.DataFrame, + metric_name: str, + outcome_constraints: list[OutcomeConstraint], +) -> None: + """Add style to a plotly figure for predicted or insample effects. + + - If we have a status quo, we add a solid red line at the status quo mean. + - If the metric is a constraint, we add a dashed red line at the constraint + bound. + - Make the x-axis (arm name) tick angle 45 degrees. + """ + if "status_quo" in df["arm_name"].values: + fig.add_hline( + y=df[df["arm_name"] == "status_quo"]["mean"].iloc[0], + line_width=1, + line_color="red", + ) + for constraint in outcome_constraints: + if constraint.metric.name == metric_name: + assert not constraint.relative + fig.add_hline( + y=constraint.bound, + line_width=1, + line_color="red", + line_dash="dash", + ) + fig.update_layout( + xaxis={ + "tickangle": 45, + }, + ) + + +def _get_trial_index_for_predictions(model: ModelBridge) -> int | None: + """Returns status quo features index if defined on the model. Otherwise, returns + the max observed trial index to appease multitask models for prediction + by giving fixed features. The max index is not necessarily accurate and should + eventually come from the generation strategy, but at least gives consistent + predictions accross trials. 
+ """ + if model.status_quo is None: + observed_trial_indices = [ + obs.features.trial_index + for obs in model.get_training_data() + if obs.features.trial_index is not None + ] + if len(observed_trial_indices) == 0: + return None + return max(observed_trial_indices) + + return model.status_quo.features.trial_index + + +def get_predictions_by_arm( + model: ModelBridge, + metric_name: str, + outcome_constraints: list[OutcomeConstraint], + gr: GeneratorRun | None = None, +) -> list[dict[str, Any]]: + trial_index = _get_trial_index_for_predictions(model) + if gr is None: + observations = model.get_training_data() + features = [o.features for o in observations] + arm_names = [o.arm_name for o in observations] + for feature in features: + feature.trial_index = trial_index + else: + features = [ + ObservationFeatures(parameters=arm.parameters, trial_index=trial_index) + for arm in gr.arms + ] + arm_names = [a.name for a in gr.arms] + try: + predictions = [ + predict_at_point( + model=model, + obsf=obsf, + metric_names={metric_name}.union( + {constraint.metric.name for constraint in outcome_constraints} + ), + ) + for obsf in features + ] + except NotImplementedError: + raise UserInputError( + "This plot requires a GenerationStrategy which is " + "in a state where the current model supports prediction. The current " + f"model is {model._model_key} and does not support prediction." + ) + constraints_violated_by_constraint = get_constraint_violated_probabilities( + predictions=predictions, + outcome_constraints=outcome_constraints, + ) + probabilities_not_feasible = constraints_violated_by_constraint.pop( + "any_constraint_violated" + ) + constraints_violated = [ + { + c: constraints_violated_by_constraint[c][i] + for c in constraints_violated_by_constraint + } + for i in range(len(features)) + ] + + for i in range(len(features)): + if ( + model.status_quo is not None + and features[i].parameters + == none_throws(model.status_quo).features.parameters + ): + probabilities_not_feasible[i] = 0 + constraints_violated[i] = {} + return [ + { + "source": "In-sample" if gr is None else gr._model_key, + "arm_name": arm_names[i], + "mean": predictions[i][0][metric_name], + "sem": predictions[i][1][metric_name], + "error_margin": 1.96 * predictions[i][1][metric_name], + "constraints_violated": format_constraint_violated_probabilities( + constraints_violated[i] + ), + "size_column": 100 - probabilities_not_feasible[i] * 100, + "parameters": format_parameters_for_effects_by_arm_plot( + parameters=features[i].parameters + ), + } + for i in range(len(features)) + ] diff --git a/ax/analysis/plotly/tests/test_insample_effects.py b/ax/analysis/plotly/tests/test_insample_effects.py new file mode 100644 index 00000000000..25fa6de3702 --- /dev/null +++ b/ax/analysis/plotly/tests/test_insample_effects.py @@ -0,0 +1,472 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from unittest.mock import patch + +import torch + +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.arm_effects.insample_effects import InSampleEffectsPlot +from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm +from ax.core.base_trial import TrialStatus +from ax.exceptions.core import DataRequiredError, UserInputError +from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy +from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.prediction_utils import predict_at_point +from ax.modelbridge.registry import Models +from ax.modelbridge.transition_criterion import MaxTrials +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import ( + get_branin_experiment, + get_branin_metric, + get_branin_outcome_constraint, +) +from ax.utils.testing.mock import fast_botorch_optimize +from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds +from pyre_extensions import none_throws + + +class TestInsampleEffectsPlot(TestCase): + + def setUp(self) -> None: + super().setUp() + self.generation_strategy = GenerationStrategy( + nodes=[ + GenerationNode( + node_name="Sobol", + model_specs=[ModelSpec(model_enum=Models.SOBOL)], + transition_criteria=[ + MaxTrials( + threshold=1, + transition_to="GPEI", + ) + ], + ), + GenerationNode( + node_name="GPEI", + model_specs=[ + ModelSpec( + model_enum=Models.BOTORCH_MODULAR, + ), + ], + transition_criteria=[ + MaxTrials( + threshold=1, + transition_to="MTGP", + only_in_statuses=[ + TrialStatus.RUNNING, + TrialStatus.COMPLETED, + TrialStatus.EARLY_STOPPED, + ], + ) + ], + ), + GenerationNode( + node_name="MTGP", + model_specs=[ + ModelSpec( + model_enum=Models.ST_MTGP, + ), + ], + ), + ], + ) + + def test_compute_for_requires_an_exp(self) -> None: + analysis = InSampleEffectsPlot( + metric_name="branin", trial_index=0, use_modeled_effects=True + ) + + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + + @fast_botorch_optimize + def test_compute_uses_gs_model_if_possible(self) -> None: + # GIVEN an experiment and GS with a Botorch model + experiment = get_branin_experiment(with_status_quo=True) + generation_strategy = self.generation_strategy + experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) + ).set_status_quo_with_weight( + status_quo=experiment.status_quo, weight=1.0 + ).mark_completed( + unsafe=True + ) + experiment.fetch_data() + generation_strategy.gen_with_multiple_nodes(experiment=experiment, n=10) + # Ensure the current model is Botorch + self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch") + # WHEN we compute the analysis + analysis = InSampleEffectsPlot( + metric_name="branin", trial_index=0, use_modeled_effects=True + ) + with patch( + f"{get_predictions_by_arm.__module__}.predict_at_point", + wraps=predict_at_point, + ) as predict_at_point_spy: + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) + # THEN it uses the model from the GS + models_used_for_prediction = [ + call[1]["model"]._model_key for call in predict_at_point_spy.call_args_list + ] + self.assertTrue( + [all(m == "BoTorch" for m in models_used_for_prediction)], + models_used_for_prediction, + ) + # AND THEN it has predictions for all arms + trial = experiment.trials[0] + self.assertEqual( + len(card.df), + len(trial.arms), + ) + for arm in trial.arms: + self.assertIn(arm.name, 
card.df["arm_name"].unique()) + + def test_compute_modeled_can_use_ebts_for_gs_with_non_predictive_model( + self, + ) -> None: + # GIVEN an experiment and GS with a non Botorch model + experiment = get_branin_experiment() + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) + ).mark_completed(unsafe=True) + experiment.fetch_data() + # Ensure the current model is not Botorch + self.assertEqual(none_throws(generation_strategy.model)._model_key, "Sobol") + # WHEN we compute the analysis + analysis = InSampleEffectsPlot( + metric_name="branin", trial_index=0, use_modeled_effects=True + ) + with patch( + f"{get_predictions_by_arm.__module__}.predict_at_point", + wraps=predict_at_point, + ) as predict_at_point_spy: + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) + # THEN it uses the empirical bayes model + models_used_for_prediction = [ + call[1]["model"]._model_key for call in predict_at_point_spy.call_args_list + ] + self.assertTrue( + [all(m == "EB" for m in models_used_for_prediction)], + models_used_for_prediction, + ) + # AND THEN it has predictions for all arms + trial = experiment.trials[0] + self.assertEqual( + len(card.df), + len(trial.arms), + ) + for arm in trial.arms: + self.assertIn(arm.name, card.df["arm_name"].unique()) + + # AND THEN the card is labeled correctly + self.assertEqual(card.name, "ModeledEffectsPlot") + self.assertEqual(card.title, "Modeled Effects for branin on trial 0") + self.assertEqual( + card.subtitle, "View a trial and its arms' predicted metric values" + ) + # high because it's on objective + self.assertEqual(card.level, AnalysisCardLevel.HIGH) + + def test_compute_modeled_can_use_ebts_for_no_gs(self) -> None: + # GIVEN an experiment with a trial with data + experiment = get_branin_experiment() + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) + ).mark_completed(unsafe=True) + experiment.fetch_data() + # WHEN we compute the analysis + analysis = InSampleEffectsPlot( + metric_name="branin", trial_index=0, use_modeled_effects=True + ) + with patch( + f"{get_predictions_by_arm.__module__}.predict_at_point", + wraps=predict_at_point, + ) as predict_at_point_spy: + card = analysis.compute(experiment=experiment, generation_strategy=None) + # THEN it uses the empirical bayes model + models_used_for_prediction = [ + call[1]["model"]._model_key for call in predict_at_point_spy.call_args_list + ] + self.assertTrue( + [all(m == "EB" for m in models_used_for_prediction)], + models_used_for_prediction, + ) + # AND THEN it has predictions for all arms + trial = experiment.trials[0] + self.assertEqual( + len(card.df), + len(trial.arms), + ) + for arm in trial.arms: + self.assertIn(arm.name, card.df["arm_name"].unique()) + + # AND THEN the card is labeled correctly + self.assertEqual(card.name, "ModeledEffectsPlot") + self.assertEqual(card.title, "Modeled Effects for branin on trial 0") + self.assertEqual( + card.subtitle, "View a trial and its arms' predicted metric values" + ) + # high because it's on objective + self.assertEqual(card.level, AnalysisCardLevel.HIGH) + + def test_compute_unmodeled_uses_thompson(self) -> None: + # GIVEN an experiment with a trial with data + experiment = 
get_branin_experiment() + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) + ).mark_completed(unsafe=True) + experiment.fetch_data() + # WHEN we compute the analysis + analysis = InSampleEffectsPlot( + metric_name="branin", trial_index=0, use_modeled_effects=False + ) + with patch( + f"{get_predictions_by_arm.__module__}.predict_at_point", + wraps=predict_at_point, + ) as predict_at_point_spy: + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) + # THEN it uses the thompson model + models_used_for_prediction = [ + call[1]["model"]._model_key for call in predict_at_point_spy.call_args_list + ] + self.assertTrue( + [all(m == "Thompson" for m in models_used_for_prediction)], + models_used_for_prediction, + ) + # AND THEN it has predictions for all arms + trial = experiment.trials[0] + data_df = experiment.lookup_data(trial_indices=[trial.index]).df + self.assertEqual( + len(card.df), + len(trial.arms), + ) + for arm in trial.arms: + self.assertIn(arm.name, card.df["arm_name"].unique()) + self.assertAlmostEqual( + card.df.loc[card.df["arm_name"] == arm.name, "mean"].item(), + data_df.loc[data_df["arm_name"] == arm.name, "mean"].item(), + ) + self.assertAlmostEqual( + card.df.loc[card.df["arm_name"] == arm.name, "sem"].item(), + data_df.loc[data_df["arm_name"] == arm.name, "sem"].item(), + ) + + # AND THEN the card is labeled correctly + self.assertEqual(card.name, "RawEffectsPlot") + self.assertEqual(card.title, "Raw Effects for branin on trial 0") + self.assertEqual( + card.subtitle, "View a trial and its arms' observed metric values" + ) + # high because it's on objective + self.assertEqual(card.level, AnalysisCardLevel.HIGH) + + def test_compute_requires_data_for_the_metric_on_the_trial_without_a_model( + self, + ) -> None: + # GIVEN an experiment with a trial with no data + experiment = get_branin_experiment() + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) + ).mark_completed(unsafe=True) + self.assertTrue(experiment.lookup_data().df.empty) + # WHEN we compute the analysis + analysis = InSampleEffectsPlot( + metric_name="branin", + trial_index=0, + use_modeled_effects=False, + ) + with self.assertRaisesRegex( + DataRequiredError, + "Cannot plot effects for 'branin' on trial 0 because it has no data.", + ): + analysis.compute(experiment=experiment, generation_strategy=None) + # THEN it raises an error + + @fast_botorch_optimize + def test_compute_requires_data_for_the_metric_on_the_trial_with_a_model( + self, + ) -> None: + # GIVEN an experiment and GS with a Botorch model + experiment = get_branin_experiment(with_status_quo=True) + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) + ).set_status_quo_with_weight( + status_quo=experiment.status_quo, weight=1.0 + ).mark_completed( + unsafe=True + ) + experiment.fetch_data() + # AND GIVEN the experiment has a trial with no data + empty_trial = experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ), + ) + # Ensure the current model 
is Botorch + self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch") + self.assertTrue( + experiment.lookup_data(trial_indices=[empty_trial.index]).df.empty + ) + # WHEN we compute the analysis + analysis = InSampleEffectsPlot( + metric_name="branin", + trial_index=empty_trial.index, + use_modeled_effects=True, + ) + with self.assertRaisesRegex( + DataRequiredError, + ( + f"Cannot plot effects for 'branin' on trial {empty_trial.index} " + "because it has no data." + ), + ): + analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) + # THEN it raises an error + + @fast_botorch_optimize + def test_constraints(self) -> None: + # GIVEN an experiment with metrics and batch trials + experiment = get_branin_experiment(with_status_quo=True) + none_throws(experiment.optimization_config).outcome_constraints = [ + get_branin_outcome_constraint(name="constraint_branin_1"), + get_branin_outcome_constraint(name="constraint_branin_2"), + ] + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + trial = experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ), + ) + trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) + trial.mark_completed(unsafe=True) + experiment.fetch_data() + trial = experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ), + ) + trial.set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) + # WHEN we compute the analysis and constraints are violated + analysis = InSampleEffectsPlot( + metric_name="branin", trial_index=0, use_modeled_effects=True + ) + with self.subTest("violated"): + with patch( + f"{compute_log_prob_feas_from_bounds.__module__}.log_ndtr", + side_effect=lambda t: torch.as_tensor([[0.25]] * t.size()[0]).log(), + ): + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) + # THEN it marks that constraints are violated for the non-SQ arms + non_sq_df = card.df[card.df["arm_name"] != "status_quo"] + sq_row = card.df[card.df["arm_name"] == "status_quo"] + self.assertTrue( + all(non_sq_df["constraints_violated"] != "No constraints violated"), + non_sq_df["constraints_violated"], + ) + self.assertTrue( + all( + non_sq_df["constraints_violated"] + == ( + "
<br />  constraint_branin_1: 75.0% chance violated" + "<br />
constraint_branin_2: 75.0% chance violated" + ) + ), + str(non_sq_df["constraints_violated"][0]), + ) + # AND THEN it marks that constraints are not violated for the SQ + self.assertEqual(sq_row["size_column"].iloc[0], 100) + self.assertEqual( + sq_row["constraints_violated"].iloc[0], "No constraints violated" + ) + + # WHEN we compute the analysis and constraints are violated + with self.subTest("not violated"): + with patch( + f"{compute_log_prob_feas_from_bounds.__module__}.log_ndtr", + side_effect=lambda t: torch.as_tensor([[1]] * t.size()[0]).log(), + ): + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) + # THEN it marks that constraints are not violated + self.assertTrue( + all(card.df["constraints_violated"] == "No constraints violated"), + str(card.df["constraints_violated"]), + ) + + # AND THEN it has not modified the constraints + opt_config = none_throws(experiment.optimization_config) + self.assertTrue(opt_config.outcome_constraints[0].relative) + self.assertTrue(opt_config.outcome_constraints[1].relative) + + def test_level(self) -> None: + # GIVEN an experiment with metrics and batch trials + experiment = get_branin_experiment(with_status_quo=True) + none_throws(experiment.optimization_config).outcome_constraints = [ + get_branin_outcome_constraint(name="constraint_branin"), + ] + experiment.add_tracking_metric(get_branin_metric(name="tracking_branin")) + generation_strategy = self.generation_strategy + generation_strategy.experiment = experiment + trial = experiment.new_batch_trial( + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ), + ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) + trial.mark_completed(unsafe=True) + experiment.fetch_data() + + metric_to_level = { + "branin": AnalysisCardLevel.HIGH, + "constraint_branin": AnalysisCardLevel.MID, + "tracking_branin": AnalysisCardLevel.LOW, + } + + for metric, level in metric_to_level.items(): + with self.subTest("objective is high"): + # WHEN we compute the analysis for an objective + analysis = InSampleEffectsPlot( + # trial_index and use_modeled_effects don't affect the level + metric_name=metric, + trial_index=0, + use_modeled_effects=False, + ) + card = analysis.compute(experiment=experiment) + # THEN the card has the correct level + self.assertEqual(card.level, level) diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index 8374750cb3f..b63d0677a8a 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -8,7 +8,8 @@ import torch from ax.analysis.analysis import AnalysisCardLevel -from ax.analysis.plotly.predicted_effects import PredictedEffectsPlot +from ax.analysis.plotly.arm_effects.predicted_effects import PredictedEffectsPlot +from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm from ax.core.base_trial import TrialStatus from ax.core.observation import ObservationFeatures from ax.core.trial import Trial @@ -32,7 +33,7 @@ from pyre_extensions import none_throws -class TestParallelCoordinatesPlot(TestCase): +class TestPredictedEffectsPlot(TestCase): def setUp(self) -> None: super().setUp() self.generation_strategy = GenerationStrategy( @@ -230,7 +231,7 @@ def test_compute_multitask(self) -> None: # WHEN we compute the analysis analysis = PredictedEffectsPlot(metric_name="branin") with patch( - f"{PredictedEffectsPlot.__module__}.predict_at_point", + 
f"{get_predictions_by_arm.__module__}.predict_at_point", wraps=predict_at_point, ) as predict_at_point_spy: card = analysis.compute( @@ -404,3 +405,8 @@ def test_constraints(self) -> None: all(card.df["constraints_violated"] == "No constraints violated"), str(card.df["constraints_violated"]), ) + + # AND THEN it has not modified the constraints + opt_config = none_throws(experiment.optimization_config) + self.assertTrue(opt_config.outcome_constraints[0].relative) + self.assertTrue(opt_config.outcome_constraints[1].relative) diff --git a/ax/analysis/plotly/utils.py b/ax/analysis/plotly/utils.py index 85cdac7309c..5d37a1fdd78 100644 --- a/ax/analysis/plotly/utils.py +++ b/ax/analysis/plotly/utils.py @@ -7,6 +7,7 @@ import torch from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint from ax.exceptions.core import UserInputError +from ax.modelbridge.base import ModelBridge from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds # Because normal distributions have long tails, every arm has a non-zero @@ -118,3 +119,18 @@ def format_constraint_violated_probabilities( constraints_violated_str = "
" + constraints_violated_str return constraints_violated_str + + +def is_predictive(model: ModelBridge) -> bool: + """Check if a model is predictive. Basically, we're checking if + predict() is implemented. + + NOTE: This does not mean it's capable of out of sample prediction. + """ + try: + model.predict(observation_features=[]) + except NotImplementedError: + return False + except Exception: + return True + return True diff --git a/sphinx/source/analysis.rst b/sphinx/source/analysis.rst index f52e6681bbb..aae4a02a03d 100644 --- a/sphinx/source/analysis.rst +++ b/sphinx/source/analysis.rst @@ -39,6 +39,14 @@ Healthcheck Analysis :undoc-members: :show-inheritance: +InSample Effects Analysis +~~~~~~~~~~~~~~~ + +.. automodule:: ax.analysis.plotly.arm_effects.insample_effects + :members: + :undoc-members: + :show-inheritance: + Parallel Coordinates Analysis ~~~~~~~~~~~~~~~ @@ -50,7 +58,15 @@ Parallel Coordinates Analysis Predicted Effects Analysis ~~~~~~~~~~~~~~~ -.. automodule:: ax.analysis.plotly.predicted_effects +.. automodule:: ax.analysis.plotly.arm_effects.predicted_effects + :members: + :undoc-members: + :show-inheritance: + +Plotly Arm Effects Utils +~~~~~~~~~~~~~~~ + +.. automodule:: ax.analysis.plotly.arm_effects.utils :members: :undoc-members: :show-inheritance: