diff --git a/src/estimagic/optimization/tranquilo/filter_points.py b/src/estimagic/optimization/tranquilo/filter_points.py index 544e40b6f..a2cddd571 100644 --- a/src/estimagic/optimization/tranquilo/filter_points.py +++ b/src/estimagic/optimization/tranquilo/filter_points.py @@ -3,11 +3,12 @@ from scipy.linalg import qr_multiply from estimagic.optimization.tranquilo.clustering import cluster +from estimagic.optimization.tranquilo.get_component import get_component from estimagic.optimization.tranquilo.models import n_second_order_terms from estimagic.optimization.tranquilo.volume import get_radius_after_volume_scaling -def get_sample_filter(sample_filter="keep_all"): +def get_sample_filter(sample_filter="keep_all", user_options=None): """Get filter function with partialled options. The filter function is applied to points inside the current trustregion before @@ -31,31 +32,31 @@ def get_sample_filter(sample_filter="keep_all"): "drop_pounders": drop_collinear_pounders, } - if isinstance(sample_filter, str) and sample_filter in built_in_filters: - out = built_in_filters[sample_filter] - elif callable(sample_filter): - out = sample_filter - else: - raise ValueError() + out = get_component( + name_or_func=sample_filter, + component_name="sample_filter", + func_dict=built_in_filters, + user_options=user_options, + ) return out -def discard_all(xs, indices, state, target_size): # noqa: ARG001 +def discard_all(state): return state.x.reshape(1, -1), np.array([state.index]) -def keep_all(xs, indices, state, target_size): # noqa: ARG001 +def keep_all(xs, indices): return xs, indices -def keep_sphere(xs, indices, state, target_size): # noqa: ARG001 +def keep_sphere(xs, indices, state): dists = np.linalg.norm(xs - state.trustregion.center, axis=1) keep = dists <= state.trustregion.radius return xs[keep], indices[keep] -def drop_collinear_pounders(xs, indices, state, target_size): # noqa: ARG001 +def drop_collinear_pounders(xs, indices, state): """Drop collinear points using pounders filtering.""" if xs.shape[0] <= xs.shape[1] + 1: filtered_xs, filtered_indices = xs, indices diff --git a/src/estimagic/optimization/tranquilo/fit_models.py b/src/estimagic/optimization/tranquilo/fit_models.py index 060f6275f..401a989b0 100644 --- a/src/estimagic/optimization/tranquilo/fit_models.py +++ b/src/estimagic/optimization/tranquilo/fit_models.py @@ -1,11 +1,10 @@ -import inspect -import warnings from functools import partial import numpy as np from numba import njit from scipy.linalg import qr_multiply +from estimagic.optimization.tranquilo.get_component import get_component from estimagic.optimization.tranquilo.models import ( ModelInfo, VectorModel, @@ -38,8 +37,6 @@ def get_fitter(fitter, user_options=None, model_info=None): callable: The partialled fit method that only depends on x and y. """ - if user_options is None: - user_options = {} if model_info is None: model_info = ModelInfo() @@ -49,57 +46,30 @@ def get_fitter(fitter, user_options=None, model_info=None): "powell": fit_powell, } - if isinstance(fitter, str) and fitter in built_in_fitters: - _fitter = built_in_fitters[fitter] - _fitter_name = fitter - elif callable(fitter): - _fitter = fitter - _fitter_name = getattr(fitter, "__name__", "your fitter") - else: - raise ValueError( - f"Invalid fitter: {fitter}. Must be one of {list(built_in_fitters)} or a " - "callable." - ) - default_options = { "l2_penalty_linear": 0, "l2_penalty_square": 0.1, + "model_info": model_info, } - all_options = {**default_options, **user_options} - - args = set(inspect.signature(_fitter).parameters) - - if not {"x", "y", "model_info"}.issubset(args): - raise ValueError( - "fit method needs to take 'x', 'y' and 'model_info' as the first three " - "arguments." - ) - - not_options = {"x", "y", "model_info"} - if isinstance(_fitter, partial): - partialed_in = set(_fitter.args).union(set(_fitter.keywords)) - not_options = not_options | partialed_in + mandatory_arguments = ["x", "y", "model_info"] - valid_options = args - not_options - - reduced = {key: val for key, val in all_options.items() if key in valid_options} - - ignored = { - key: val for key, val in user_options.items() if key not in valid_options - } - - if ignored: - warnings.warn( - "The following options were ignored because they are not compatible with " - f"{_fitter_name}:\n\n {ignored}" - ) + _raw_fitter = get_component( + name_or_func=fitter, + component_name="fitter", + func_dict=built_in_fitters, + default_options=default_options, + user_options=user_options, + mandatory_signature=mandatory_arguments, + ) - out = partial( - _fitter_template, fitter=_fitter, model_info=model_info, options=reduced + fitter = partial( + _fitter_template, + fitter=_raw_fitter, + model_info=model_info, ) - return out + return fitter def _fitter_template( @@ -108,7 +78,6 @@ def _fitter_template( weights=None, fitter=None, model_info=None, - options=None, ): """Fit a model to data. @@ -122,7 +91,6 @@ def _fitter_template( ``x``, second ``y`` and third ``model_info``. model_info (ModelInfo): Information that describes the functional form of the model. - options (dict): Options for the fit method. Returns: VectorModel or ScalarModel: Results container. @@ -138,7 +106,7 @@ def _fitter_template( y = y * _root_weights x = x * _root_weights - coef = fitter(x, y, model_info, **options) + coef = fitter(x, y) # results processing intercepts, linear_terms, square_terms = np.split(coef, (1, n_params + 1), axis=1) diff --git a/src/estimagic/optimization/tranquilo/get_component.py b/src/estimagic/optimization/tranquilo/get_component.py new file mode 100644 index 000000000..d57922d42 --- /dev/null +++ b/src/estimagic/optimization/tranquilo/get_component.py @@ -0,0 +1,232 @@ +import functools +import inspect +import warnings +from functools import partial + +from estimagic.utilities import propose_alternatives + + +def get_component( + name_or_func, + component_name, + func_dict=None, + default_options=None, + user_options=None, + redundant_option_handling="ignore", + redundant_argument_handling="ignore", + mandatory_signature=None, +): + """Process a function that represents an interchangeable component of tranquilo. + + The function is either a built in function or a user provided function. In all + cases we run some checks that the signature of the function is correct and then + partial all static options into the function. + + Args: + name_or_func (str or callable): Name of a function or function. + component_name (str): Name of the component. Used in error messages. Examples + would be "subsolver" or "model". + func_dict (dict): Dict with function names as keys and functions as values. + default_options (dict): Default options for the function. Those will be + partialled into the function unless overridden by user_options. + user_options (dict): Options for the function. Those will be partialled into + the function. + redundant_option_handling (str): How to handle redundant options. Can be + "warn", "raise" or "ignore". Default "ignore". + redundant_argument_handling (str): How to handle redundant arguments passed + to the processed function at runtime. Can be "warn", "raise" or "ignore". + Default "ignore". + mandatory_signature (list): List or tuple of arguments that must be in the + signature of all functions in `func_dict`. These can be options or + arguments. Otherwise, a ValueError is raised. + + Returns: + callable: The processed function. + + """ + + _func, _name = _get_function_and_name( + name_or_func=name_or_func, + component_name=component_name, + func_dict=func_dict, + ) + + _all_arguments = list(inspect.signature(_func).parameters) + + _valid_options = _get_valid_options( + default_options=default_options, + user_options=user_options, + signature=_all_arguments, + name=_name, + component_name=component_name, + redundant_option_handling=redundant_option_handling, + ) + + _fail_if_mandatory_argument_is_missing( + mandatory_arguments=mandatory_signature, + signature=_all_arguments, + name=_name, + component_name=component_name, + ) + + _partialled = partial(_func, **_valid_options) + + if redundant_argument_handling == "raise": + out = _partialled + else: + out = _add_redundant_argument_handling( + func=_partialled, + signature=_all_arguments, + warn=redundant_argument_handling == "warn", + ) + + return out + + +def _get_function_and_name(name_or_func, component_name, func_dict): + """Get the function and its name. + + Args: + name_or_func (str or callable): Name of a function or function. + component_name (str): Name of the component. Used in error messages. Examples + would be "subsolver" or "model". + func_dict (dict): Dict with function names as keys and functions as values. + + Returns: + tuple: The function and its name. + + """ + func_dict = {} if func_dict is None else func_dict + if isinstance(name_or_func, str): + if name_or_func in func_dict: + _func = func_dict[name_or_func] + _name = name_or_func + else: + _proposal = propose_alternatives(name_or_func, list(func_dict)) + msg = ( + f"If {component_name} is a string, it must be one of the built in " + f"{component_name}s. Did you mean: {_proposal}?" + ) + raise ValueError(msg) + elif callable(name_or_func): + _func = name_or_func + _name = _func.__name__ + else: + raise TypeError("name_or_func must be a string or a callable.") + + return _func, _name + + +def _get_valid_options( + default_options, + user_options, + signature, + name, + component_name, + redundant_option_handling, +): + """Get the options that are valid for the function. + + Args: + default_options (dict): Default options for the function. + user_options (dict): Options for the function. + signature (list): List of arguments that are present in the signature. + name (str): Name of the function. + component_name (str): Name of the component. Used in error messages. Examples + would be "subsolver" or "model". + redundant_option_handling (str): How to handle redundant options. Can be + + Returns: + dict: Valid options. + + """ + _default_options = {} if default_options is None else default_options + _user_options = {} if user_options is None else user_options + + _options = {**_default_options, **_user_options} + + _valid_options = {k: v for k, v in _options.items() if k in signature} + _redundant_options = {k: v for k, v in _options.items() if k not in signature} + + if redundant_option_handling == "warn" and _redundant_options: + msg = ( + f"The following options are not supported by the {component_name} {name} " + f"and will be ignored: {list(_redundant_options)}." + ) + warnings.warn(msg) + + elif redundant_option_handling == "raise" and _redundant_options: + msg = ( + f"The following options are not supported by the {component_name} {name}: " + f"{list(_redundant_options)}." + ) + raise ValueError(msg) + + return _valid_options + + +def _fail_if_mandatory_argument_is_missing( + mandatory_arguments, signature, name, component_name +): + """Check if any mandatory arguments are missing in the signature of the function. + + Args: + mandatory_arguments (list): List of mandatory arguments. + signature (list): List of arguments that are present in the signature. + name (str): Name of the function. + component_name (str): Name of the component. Used in error messages. Examples + would be "subsolver" or "model". + + Returns: + None + + Raises: + ValueError: If any mandatory arguments are missing in the signature of the + function. + + """ + mandatory_arguments = [] if mandatory_arguments is None else mandatory_arguments + + _missing = [arg for arg in mandatory_arguments if arg not in signature] + + if _missing: + msg = ( + f"The following mandatory arguments are missing in the signature of the " + f"{component_name} {name}: {_missing}." + ) + raise ValueError(msg) + + return None + + +def _add_redundant_argument_handling(func, signature, warn): + """Allow func to be called with arguments that are not in the signature. + + Args: + func (callable): The function to be wrapped. + signature (list): List of arguments that are supported by func. + warn (bool): Whether to warn about redundant arguments. + + Returns: + callable: The wrapped function. + + """ + + @functools.wraps(func) + def _wrapper_add_redundant_argument_handling(*args, **kwargs): + _kwargs = {**dict(zip(signature[: len(args)], args)), **kwargs} + + _redundant = {k: v for k, v in _kwargs.items() if k not in signature} + _valid = {k: v for k, v in _kwargs.items() if k in signature} + + if warn and _redundant: + msg = ( + f"The following arguments are not supported by the function " + f"{func.__name__} and will be ignored: {_redundant}." + ) + warnings.warn(msg) + + out = func(**_valid) + return out + + return _wrapper_add_redundant_argument_handling diff --git a/src/estimagic/optimization/tranquilo/sample_points.py b/src/estimagic/optimization/tranquilo/sample_points.py index 659c3c9bd..481e3b80c 100644 --- a/src/estimagic/optimization/tranquilo/sample_points.py +++ b/src/estimagic/optimization/tranquilo/sample_points.py @@ -1,5 +1,3 @@ -import inspect -import warnings from functools import partial import numpy as np @@ -7,6 +5,7 @@ from scipy.special import gammainc, logsumexp import estimagic as em +from estimagic.optimization.tranquilo.get_component import get_component from estimagic.optimization.tranquilo.options import Bounds @@ -33,7 +32,6 @@ def get_sampler( existing_fvals, model_info and and returns a new sample. """ - user_options = {} if user_options is None else user_options built_in_samplers = { "box": _box_sampler, @@ -46,19 +44,11 @@ def get_sampler( "optimal_sphere": partial(_optimal_hull_sampler, order=2), } - if isinstance(sampler, str) and sampler in built_in_samplers: - _sampler = built_in_samplers[sampler] - _sampler_name = sampler - elif callable(sampler): - _sampler = sampler - _sampler_name = getattr(sampler, "__name__", "your sampler") - else: - raise ValueError( - f"Invalid sampler: {sampler}. Must be one of {list(built_in_samplers)} " - "or a callable." - ) - - if "hull_sampler" in _sampler_name and "order" not in user_options: + if ( + isinstance(sampler, str) + and "hull_sampler" in sampler + and "order" not in user_options + ): msg = ( "The hull_sampler and optimal_hull_sampler require the argument 'order' to " "be prespecfied in the user_options dictionary. Order is a positive " @@ -67,48 +57,27 @@ def get_sampler( ) raise ValueError(msg) - args = set(inspect.signature(_sampler).parameters) + default_options = { + "bounds": bounds, + "model_info": model_info, + "radius_factors": radius_factors, + } - mandatory_args = { + mandatory_args = [ "bounds", "trustregion", "n_points", "existing_xs", "rng", - } - - optional_kwargs = { - "model_info": model_info, - "radius_factors": radius_factors, - } - - optional_kwargs = {k: v for k, v in optional_kwargs.items() if k in args} - - problematic = mandatory_args - args - if problematic: - raise ValueError( - f"The following mandatory arguments are missing in {_sampler_name}: " - f"{problematic}" - ) - - valid_options = args - mandatory_args - - reduced = {key: val for key, val in user_options.items() if key in valid_options} - ignored = { - key: val for key, val in user_options.items() if key not in valid_options - } - - if ignored: - warnings.warn( - "The following options were ignored because they are not compatible " - f"with {_sampler_name}:\n\n {ignored}" - ) - - out = partial( - _sampler, - bounds=bounds, - **optional_kwargs, - **reduced, + ] + + out = get_component( + name_or_func=sampler, + component_name="sampler", + func_dict=built_in_samplers, + default_options=default_options, + user_options=user_options, + mandatory_signature=mandatory_args, ) return out diff --git a/src/estimagic/optimization/tranquilo/tranquilo.py b/src/estimagic/optimization/tranquilo/tranquilo.py index e8eb523d4..467dbfcd4 100644 --- a/src/estimagic/optimization/tranquilo/tranquilo.py +++ b/src/estimagic/optimization/tranquilo/tranquilo.py @@ -43,6 +43,7 @@ def _tranquilo( random_seed=925408, sampler=None, sample_filter=None, + filter_options=None, fitter=None, subsolver=None, sample_size=None, @@ -206,7 +207,7 @@ def _tranquilo( radius_factors=radius_factors, user_options=sampler_options, ) - filter_points = get_sample_filter(sample_filter) + filter_points = get_sample_filter(sample_filter, user_options=filter_options) aggregate_vector_model = get_aggregator( aggregator=aggregator, diff --git a/tests/optimization/tranquilo/test_filter_points.py b/tests/optimization/tranquilo/test_filter_points.py index 457638c4c..eac96e7f2 100644 --- a/tests/optimization/tranquilo/test_filter_points.py +++ b/tests/optimization/tranquilo/test_filter_points.py @@ -216,9 +216,7 @@ def test_drop_collinear_pounders(test_case, request): test_case ) - filtered_xs, filtered_indices = drop_collinear_pounders( - old_xs, old_indices, state, target_size=None - ) + filtered_xs, filtered_indices = drop_collinear_pounders(old_xs, old_indices, state) assert_equal(filtered_indices, expected_indices) aaae(filtered_xs, expected_xs) diff --git a/tests/optimization/tranquilo/test_get_component.py b/tests/optimization/tranquilo/test_get_component.py new file mode 100644 index 000000000..6df4284e3 --- /dev/null +++ b/tests/optimization/tranquilo/test_get_component.py @@ -0,0 +1,150 @@ +import pytest +from estimagic.optimization.tranquilo.get_component import ( + _add_redundant_argument_handling, + _fail_if_mandatory_argument_is_missing, + _get_function_and_name, + _get_valid_options, + get_component, +) + + +@pytest.fixture() +def func_dict(): + out = { + "f": lambda x: x, + "g": lambda x, y: x + y, + } + return out + + +def test_get_component(func_dict): + got = get_component( + name_or_func="g", + component_name="component", + func_dict=func_dict, + default_options={"x": 1}, + user_options={"y": 2}, + redundant_option_handling="ignore", + redundant_argument_handling="ignore", + mandatory_signature=["x"], + ) + + assert got() == 3 + assert got(bla=15) == 3 + + +def test_get_function_and_name_valid_string(func_dict): + _func, _name = _get_function_and_name( + name_or_func="f", + component_name="component", + func_dict=func_dict, + ) + assert _func == func_dict["f"] + assert _name == "f" + + +def test_get_function_and_name_invalid_string(): + with pytest.raises(ValueError, match="If component is a string, it must be one of"): + _get_function_and_name( + name_or_func="h", + component_name="component", + func_dict={"f": lambda x: x, "g": lambda x, y: x + y}, + ) + + +def test_get_function_and_name_valid_function(): + def _f(x): + return x + + _func, _name = _get_function_and_name( + name_or_func=_f, + component_name="component", + func_dict=None, + ) + assert _func == _f + assert _name == "_f" + + +def test_get_function_and_string_wrong_type(): + with pytest.raises(TypeError, match="name_or_func must be a string or a callable."): + _get_function_and_name( + name_or_func=1, + component_name="component", + func_dict=None, + ) + + +def test_get_valid_options_ignore(): + got = _get_valid_options( + default_options={"a": 1, "b": 2}, + user_options={"a": 3, "c": 4}, + signature=["a", "c"], + name="bla", + component_name="component", + redundant_option_handling="ignore", + ) + expected = {"a": 3, "c": 4} + + assert got == expected + + +def test_get_valid_options_raise(): + with pytest.raises(ValueError, match="The following options are not supported"): + _get_valid_options( + default_options={"a": 1, "b": 2}, + user_options={"a": 3, "c": 4}, + signature=["a", "c"], + name="bla", + component_name="component", + redundant_option_handling="raise", + ) + + +def test_get_valid_options_warn(): + with pytest.warns(UserWarning, match="The following options are not supported"): + _get_valid_options( + default_options={"a": 1, "b": 2}, + user_options={"a": 3, "c": 4}, + signature=["a", "c"], + name="bla", + component_name="component", + redundant_option_handling="warn", + ) + + +def test_fail_if_mandatory_argument_is_missing(): + with pytest.raises( + ValueError, match="The following mandatory arguments are missing" + ): + _fail_if_mandatory_argument_is_missing( + mandatory_arguments=["a", "c"], + signature=["a", "b"], + name="bla", + component_name="component", + ) + + +def test_add_redundant_argument_handling_ignore(): + def f(a, b): + return a + b + + _f = _add_redundant_argument_handling( + func=f, + signature=["a", "b"], + warn=False, + ) + + assert _f(1, b=2, c=3) == 3 + + +def test_add_redundant_argument_handling_warn(): + def f(a, b): + return a + b + + _f = _add_redundant_argument_handling( + func=f, + signature=["a", "b"], + warn=True, + ) + with pytest.warns(UserWarning, match="The following arguments are not supported"): + _f(1, b=2, c=3)