Skip to content

Commit

Permalink
Tranquilo Refactoring (#445)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Mar 16, 2023
1 parent 51fd8b6 commit 5df5c9c
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 275 deletions.
7 changes: 4 additions & 3 deletions src/estimagic/optimization/tranquilo/aggregate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ def aggregator_identity(vector_model):
2. model_type: quadratic
"""
n_params = vector_model.linear_terms.size
intercept = float(vector_model.intercepts)
linear_terms = np.squeeze(vector_model.linear_terms)
linear_terms = vector_model.linear_terms.flatten()
if vector_model.square_terms is None:
square_terms = np.zeros((len(linear_terms), len(linear_terms)))
square_terms = np.zeros((n_params, n_params))
else:
square_terms = np.squeeze(vector_model.square_terms)
square_terms = vector_model.square_terms.reshape(n_params, n_params)
return intercept, linear_terms, square_terms


Expand Down
80 changes: 0 additions & 80 deletions src/estimagic/optimization/tranquilo/geometry.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,5 @@
from functools import partial

import numpy as np

from estimagic.optimization.tranquilo.region import Region
from estimagic.optimization.tranquilo.sample_points import get_sampler


def get_geometry_checker_pair(
checker, reference_sampler, n_params, n_simulations=200, bounds=None
):
"""Get a geometry checker.
Args:
checker (str or Dict[callable]): Name of a geometry checker method or a
dictionary with entries 'quality_calculator' and 'cutoff_simulator'.
- 'quality_calculator': A callable that takes as argument a sample and
returns a measure on the quality of the geometry of the sample.
- 'cutoff_simulator': A callable that takes as argument 'n_samples',
'n_params', 'reference_sampler' and 'rng'.
reference_sampler (str): Either "box" or "ball", corresponding to comparison
samples drawn inside a box or a ball, respectively.
n_params (int): Number of parameters.
n_simulations (int): Number of simulations for the mean calculation.
bounds (Bounds): The parameter bounds. See module bounds.py.
Returns:
callable: The sample quality calculator.
callable: The quality cutoff simulator.
"""
if reference_sampler not in {"box", "ball"}:
raise ValueError("reference_sampler need to be either 'box' or 'ball'.")

built_in_checker = {
"d_optimality": {
"quality_calculator": log_d_quality_calculator,
"cutoff_simulator": log_d_cutoff_simulator,
},
}

_checker = built_in_checker[checker]

quality_calculator = _checker["quality_calculator"]
cutoff_simulator = partial(
_checker["cutoff_simulator"],
reference_sampler=reference_sampler,
bounds=bounds,
n_params=n_params,
n_simulations=n_simulations,
)
return quality_calculator, cutoff_simulator


def log_d_cutoff_simulator(
n_samples, rng, reference_sampler, bounds, n_params, n_simulations
):
"""Simulate the mean logarithm of the d-optimality criterion.
Args:
n_samples (int): Size of the sample.
rng (np.random.Generator): The random number generator.
reference_sampler (str): Either "box" or "ball", corresponding to comparison
samples drawn inside a box or a ball, respectively.
bounds (Bounds): The parameter bounds. See module bounds.py.
n_params (int): Dimensionality of the sample.
n_simulations (int): Number of simulations for the mean calculation.
Returns:
float: The simulated mean logarithm of the d-optimality criterion.
"""
_sampler = get_sampler(reference_sampler)
trustregion = Region(center=np.zeros(n_params), radius=1.0, bounds=bounds)
sampler = partial(_sampler, trustregion=trustregion)
raw = []
for _ in range(n_simulations):
x = sampler(n_points=n_samples, rng=rng)
raw.append(log_d_quality_calculator(x, trustregion))
out = np.nanmean(raw)
return out


def log_d_quality_calculator(sample, trustregion):
"""Logarithm of the d-optimality criterion.
Expand Down
5 changes: 4 additions & 1 deletion src/estimagic/optimization/tranquilo/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def map_from_unit(self, x: np.ndarray) -> np.ndarray:
if self.shape == "sphere":
out = _map_from_unit_sphere(x, center=self.center, radius=self.radius)
else:
out = _map_from_unit_cube(x, cube_bounds=self.cube_bounds)
cube_bounds = self.cube_bounds
out = _map_from_unit_cube(x, cube_bounds=cube_bounds)
# Bounds may not be satisfied exactly due to numerical inaccuracies.
out = np.clip(out, cube_bounds.lower, cube_bounds.upper)
return out

# make it behave like a NamedTuple
Expand Down
Loading

0 comments on commit 5df5c9c

Please sign in to comment.