Skip to content

Commit

Permalink
Move arms_per_node validation to seperate method
Browse files Browse the repository at this point in the history
Summary: I realized this validation is cluttering its current method, so this just pulls it out into it's own validation helper

Reviewed By: Balandat

Differential Revision: D63881602
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 4, 2024
1 parent 4382399 commit fa524cf
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,21 +430,8 @@ def gen_with_multiple_nodes(
A list of ``GeneratorRuns`` for a single trial.
"""
# TODO: @mgarrard merge into gen method, just starting here to derisk
# Validate `arms_per_node` if specified, otherwise construct the default
# behavior with keys being node names and values being 1 to represent
# generating a single GR from each node.
n = self._get_n(experiment=experiment, n=n)
if arms_per_node is not None and not set(self.nodes_dict).issubset(
arms_per_node
):
raise UserInputError(
f"""
Each node defined in the GenerationStrategy must have an associated
number of arms to generate from that node defined in `arms_per_node`.
{arms_per_node} does not include all of {self.nodes_dict.keys()}. It
may be helpful to double check the spelling.
"""
)
self._validate_arms_per_node(arms_per_node=arms_per_node)
grs = []
continue_gen_for_trial = True
# TODO: @mgarrard update this when gen methods are merged
Expand Down Expand Up @@ -796,6 +783,25 @@ def _step_repr(self, step_str_rep: str) -> str:
step_str_rep += "])"
return step_str_rep

def _validate_arms_per_node(self, arms_per_node: dict[str, int] | None) -> None:
"""Validate that the arms_per_node argument is valid if it is provided.
Args:
arms_per_node: A map from node name to the number of arms to
generate from that node.
"""
if arms_per_node is not None and not set(self.nodes_dict).issubset(
arms_per_node
):
raise UserInputError(
f"""
Each node defined in the GenerationStrategy must have an associated
number of arms to generate from that node defined in `arms_per_node`.
{arms_per_node} does not include all of {self.nodes_dict.keys()}. It
may be helpful to double check the spelling.
"""
)

def _make_default_name(self) -> str:
"""Make a default name for this generation strategy; used when no name is passed
to the constructor. For node-based generation strategies, the name is
Expand Down

0 comments on commit fa524cf

Please sign in to comment.