Skip to content

Commit

Permalink
Merge pull request #206 from optimas-org/feature/support_ax_0.4
Browse files Browse the repository at this point in the history
Add support for Ax `0.4.0`
  • Loading branch information
AngelFP authored May 10, 2024
2 parents 305190f + aa76960 commit 36dfdb9
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 179 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unix-openmpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, '3.10', 3.11]
python-version: ['3.10', 3.11]

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, '3.10', 3.11]
python-version: ['3.10', 3.11]

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion doc/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dependencies:
- pip
- pip:
- -e ..
- ax-platform >= 0.3.5
- ax-platform >= 0.4.0
- autodoc_pydantic >= 2.0.1
- ipykernel
- matplotlib
Expand Down
2 changes: 1 addition & 1 deletion optimas/generators/ax/import_error_dummy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def __init__(self, *args, **kwargs) -> None:
raise RuntimeError(
"You need to install ax-platform, in order "
"to use Ax-based generators in optimas.\n"
"e.g. with `pip install ax-platform >= 0.3.5`"
"e.g. with `pip install ax-platform >= 0.4.0`"
)
21 changes: 6 additions & 15 deletions optimas/generators/ax/service/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import os

import torch
from packaging import version
from ax.version import version as ax_version
from ax.core.observation import ObservationFeatures
from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import (
InstantiationBase,
ObjectiveProperties,
FixedFeatures,
)
from ax.modelbridge.registry import Models
from ax.modelbridge.generation_strategy import (
Expand All @@ -25,8 +24,6 @@
TrialStatus,
)
from optimas.generators.ax.base import AxGenerator
from optimas.generators.base import Generator
from .custom_ax import CustomAxClient as AxClient


class AxServiceGenerator(AxGenerator):
Expand Down Expand Up @@ -130,15 +127,9 @@ def __init__(
def _ask(self, trials: List[Trial]) -> List[Trial]:
"""Fill in the parameter values of the requested trials."""
for trial in trials:
try:
parameters, trial_id = self._ax_client.get_next_trial(
fixed_features=self._fixed_features
)
# Occurs when not using a CustomAxClient (i.e., when the AxClient
# is provided by the user using an AxClientGenerator). In that
# case, there is also no need to support FixedFeatures.
except TypeError:
parameters, trial_id = self._ax_client.get_next_trial()
parameters, trial_id = self._ax_client.get_next_trial(
fixed_features=self._fixed_features
)
trial.parameter_values = [
parameters.get(var.name) for var in self._varying_parameters
]
Expand Down Expand Up @@ -237,7 +228,7 @@ def _create_ax_parameters(self) -> List:
if var.is_fixed:
fixed_parameters[var.name] = var.default_value
# Store fixed parameters as fixed features.
self._fixed_features = ObservationFeatures(fixed_parameters)
self._fixed_features = FixedFeatures(fixed_parameters)
return parameters

def _create_ax_objectives(self) -> Dict[str, ObjectiveProperties]:
Expand Down
152 changes: 0 additions & 152 deletions optimas/generators/ax/service/custom_ax.py

This file was deleted.

19 changes: 15 additions & 4 deletions optimas/generators/ax/service/multi_fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from typing import List, Optional, Dict

from botorch.acquisition.knowledge_gradient import (
qMultiFidelityKnowledgeGradient,
)
from ax.utils.common.constants import Keys
from ax.modelbridge.generation_strategy import GenerationStep
from ax.modelbridge.registry import Models

Expand Down Expand Up @@ -106,8 +110,8 @@ def _create_generation_steps(
self, bo_model_kwargs: Dict
) -> List[GenerationStep]:
"""Create generation steps for multifidelity optimization."""
# Add cost intercept to model kwargs.
bo_model_kwargs["cost_intercept"] = self.fidel_cost_intercept
# Add acquisition function to model kwargs.
bo_model_kwargs["botorch_acqf_class"] = qMultiFidelityKnowledgeGradient

# Make generation strategy.
steps = []
Expand All @@ -120,10 +124,17 @@ def _create_generation_steps(
# Continue indefinitely with GPKG.
steps.append(
GenerationStep(
model=Models.GPKG,
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs=bo_model_kwargs,
)
model_gen_kwargs={
"model_gen_options": {
Keys.ACQF_KWARGS: {
Keys.COST_INTERCEPT: self.fidel_cost_intercept
}
}
},
),
)

return steps
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ authors = [
{name = 'Optimas Developers', email = 'angel.ferran.pousa@desy.de'},
]
readme = 'README.md'
requires-python = '>=3.9'
requires-python = '>=3.10'
keywords = ['optimization', 'scale', 'bayesian']
license = {text = 'BSD-3-Clause-LBNL'}
classifiers = [
'Development Status :: 3 - Alpha',
'Intended Audience :: Science/Research',
'Topic :: Scientific/Engineering',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
]
Expand All @@ -34,11 +33,11 @@ dynamic = ['version']
test = [
'flake8',
'pytest',
'ax-platform >= 0.3.5',
'ax-platform >= 0.4.0',
'matplotlib',
]
all = [
'ax-platform >= 0.3.5',
'ax-platform >= 0.4.0',
]

[project.urls]
Expand Down

0 comments on commit 36dfdb9

Please sign in to comment.