Skip to content

Commit

Permalink
Refactor getter functions (#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg authored Feb 1, 2023
1 parent 0f252a1 commit 66ac796
Show file tree
Hide file tree
Showing 7 changed files with 435 additions and 116 deletions.
23 changes: 12 additions & 11 deletions src/estimagic/optimization/tranquilo/filter_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from scipy.linalg import qr_multiply

from estimagic.optimization.tranquilo.clustering import cluster
from estimagic.optimization.tranquilo.get_component import get_component
from estimagic.optimization.tranquilo.models import n_second_order_terms
from estimagic.optimization.tranquilo.volume import get_radius_after_volume_scaling


def get_sample_filter(sample_filter="keep_all"):
def get_sample_filter(sample_filter="keep_all", user_options=None):
"""Get filter function with partialled options.
The filter function is applied to points inside the current trustregion before
Expand All @@ -31,31 +32,31 @@ def get_sample_filter(sample_filter="keep_all"):
"drop_pounders": drop_collinear_pounders,
}

if isinstance(sample_filter, str) and sample_filter in built_in_filters:
out = built_in_filters[sample_filter]
elif callable(sample_filter):
out = sample_filter
else:
raise ValueError()
out = get_component(
name_or_func=sample_filter,
component_name="sample_filter",
func_dict=built_in_filters,
user_options=user_options,
)

return out


def discard_all(xs, indices, state, target_size): # noqa: ARG001
def discard_all(state):
return state.x.reshape(1, -1), np.array([state.index])


def keep_all(xs, indices, state, target_size): # noqa: ARG001
def keep_all(xs, indices):
return xs, indices


def keep_sphere(xs, indices, state, target_size): # noqa: ARG001
def keep_sphere(xs, indices, state):
dists = np.linalg.norm(xs - state.trustregion.center, axis=1)
keep = dists <= state.trustregion.radius
return xs[keep], indices[keep]


def drop_collinear_pounders(xs, indices, state, target_size): # noqa: ARG001
def drop_collinear_pounders(xs, indices, state):
"""Drop collinear points using pounders filtering."""
if xs.shape[0] <= xs.shape[1] + 1:
filtered_xs, filtered_indices = xs, indices
Expand Down
66 changes: 17 additions & 49 deletions src/estimagic/optimization/tranquilo/fit_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import inspect
import warnings
from functools import partial

import numpy as np
from numba import njit
from scipy.linalg import qr_multiply

from estimagic.optimization.tranquilo.get_component import get_component
from estimagic.optimization.tranquilo.models import (
ModelInfo,
VectorModel,
Expand Down Expand Up @@ -38,8 +37,6 @@ def get_fitter(fitter, user_options=None, model_info=None):
callable: The partialled fit method that only depends on x and y.
"""
if user_options is None:
user_options = {}
if model_info is None:
model_info = ModelInfo()

Expand All @@ -49,57 +46,30 @@ def get_fitter(fitter, user_options=None, model_info=None):
"powell": fit_powell,
}

if isinstance(fitter, str) and fitter in built_in_fitters:
_fitter = built_in_fitters[fitter]
_fitter_name = fitter
elif callable(fitter):
_fitter = fitter
_fitter_name = getattr(fitter, "__name__", "your fitter")
else:
raise ValueError(
f"Invalid fitter: {fitter}. Must be one of {list(built_in_fitters)} or a "
"callable."
)

default_options = {
"l2_penalty_linear": 0,
"l2_penalty_square": 0.1,
"model_info": model_info,
}

all_options = {**default_options, **user_options}

args = set(inspect.signature(_fitter).parameters)

if not {"x", "y", "model_info"}.issubset(args):
raise ValueError(
"fit method needs to take 'x', 'y' and 'model_info' as the first three "
"arguments."
)

not_options = {"x", "y", "model_info"}
if isinstance(_fitter, partial):
partialed_in = set(_fitter.args).union(set(_fitter.keywords))
not_options = not_options | partialed_in
mandatory_arguments = ["x", "y", "model_info"]

valid_options = args - not_options

reduced = {key: val for key, val in all_options.items() if key in valid_options}

ignored = {
key: val for key, val in user_options.items() if key not in valid_options
}

if ignored:
warnings.warn(
"The following options were ignored because they are not compatible with "
f"{_fitter_name}:\n\n {ignored}"
)
_raw_fitter = get_component(
name_or_func=fitter,
component_name="fitter",
func_dict=built_in_fitters,
default_options=default_options,
user_options=user_options,
mandatory_signature=mandatory_arguments,
)

out = partial(
_fitter_template, fitter=_fitter, model_info=model_info, options=reduced
fitter = partial(
_fitter_template,
fitter=_raw_fitter,
model_info=model_info,
)

return out
return fitter


def _fitter_template(
Expand All @@ -108,7 +78,6 @@ def _fitter_template(
weights=None,
fitter=None,
model_info=None,
options=None,
):
"""Fit a model to data.
Expand All @@ -122,7 +91,6 @@ def _fitter_template(
``x``, second ``y`` and third ``model_info``.
model_info (ModelInfo): Information that describes the functional form of
the model.
options (dict): Options for the fit method.
Returns:
VectorModel or ScalarModel: Results container.
Expand All @@ -138,7 +106,7 @@ def _fitter_template(
y = y * _root_weights
x = x * _root_weights

coef = fitter(x, y, model_info, **options)
coef = fitter(x, y)

# results processing
intercepts, linear_terms, square_terms = np.split(coef, (1, n_params + 1), axis=1)
Expand Down
Loading

0 comments on commit 66ac796

Please sign in to comment.