Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tranquilo cleanup #443

Merged
merged 8 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 18 additions & 29 deletions src/estimagic/optimization/tranquilo/aggregate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from estimagic.optimization.tranquilo.models import ScalarModel


def get_aggregator(aggregator, functype, model_info):
def get_aggregator(aggregator, functype, model_type):
"""Get a function that aggregates a VectorModel into a ScalarModel.

Args:
aggregator (str or callable): Name of an aggregator or aggregator function.
The function must take as argument:
- vector_model (VectorModel): A fitted vector model.
functype (str): One of "scalar", "least_squares" and "likelihood".
model_info (ModelInfo): Information that describes the functional form of
the model.
model_type (str): Type of the model that is fitted. The following are supported:
- "linear": Only linear effects and intercept.
- "quadratic": Fully quadratic model.

Returns:
callable: The partialled aggregator that only depends on vector_model.
Expand All @@ -37,11 +38,11 @@ def get_aggregator(aggregator, functype, model_info):
_using_built_in_aggregator = False
else:
raise ValueError(
"Invalid aggregator: {aggregator}. Must be one of "
"Invalid aggregator: {aggregator}. Must be one of "
f"{list(built_in_aggregators)} or a callable."
)

# determine if aggregator is compatible with functype and model_info
# determine if aggregator is compatible with functype and model_type
aggregator_compatible_with_functype = {
"scalar": ("identity", "sum"),
"least_squares": ("least_squares_linear",),
Expand All @@ -51,17 +52,9 @@ def get_aggregator(aggregator, functype, model_info):
),
}

aggregator_compatible_with_model_info = {
# keys are names of aggregators and values are functions of model_info that
# return False in case of incompatibility
"identity": lambda x: True, # noqa: ARG005
"sum": _is_second_order_model,
"information_equality_linear": lambda model_info: not _is_second_order_model(
model_info
),
"least_squares_linear": lambda model_info: not _is_second_order_model(
model_info
),
aggregator_compatible_with_model_type = {
"linear": {"information_equality_linear", "least_squares_linear"},
"quadratic": {"identity", "sum"},
}

if _using_built_in_aggregator:
Expand All @@ -71,12 +64,12 @@ def get_aggregator(aggregator, functype, model_info):
f"Aggregator {_aggregator_name} is not compatible with functype "
f"{functype}. It would not produce a quadratic main model."
)
if not aggregator_compatible_with_model_info[_aggregator_name](model_info):
if _aggregator_name not in aggregator_compatible_with_model_type[model_type]:
raise ValueError(
f"ModelInfo {model_info} is not compatible with aggregator "
f"{_aggregator_name}. Depending on the aggregator this may be because "
"it would not produce a quadratic main model or that the aggregator "
"requires a different residual model for theoretical reasons."
f"Aggregator {_aggregator_name} is not compatible with model_type "
f"{model_type}. This is because the combination would not produce a "
"quadratic main model or that the aggregator requires a different "
"residual model."
)

# create aggregator
Expand Down Expand Up @@ -114,7 +107,7 @@ def aggregator_identity(vector_model):
Assumptions
-----------
1. functype: scalar
2. ModelInfo: has squares or interactions
2. model_type: quadratic

"""
intercept = float(vector_model.intercepts)
Expand All @@ -137,7 +130,7 @@ def aggregator_sum(vector_model):
Assumptions
-----------
1. functype: likelihood
2. ModelInfo: has squares or interactions
2. model_type: quadratic

"""
vm_intercepts = vector_model.intercepts
Expand All @@ -157,7 +150,7 @@ def aggregator_least_squares_linear(vector_model):
Assumptions
-----------
1. functype: least_squares
2. ModelInfo: has intercept but no squares and no interaction
2. model_type: linear

References
----------
Expand Down Expand Up @@ -185,7 +178,7 @@ def aggregator_information_equality_linear(vector_model):
Assumptions
-----------
1. functype: likelihood
2. ModelInfo: has no squares and no interaction
2. model_type: linear

"""
vm_linear_terms = vector_model.linear_terms
Expand All @@ -198,7 +191,3 @@ def aggregator_information_equality_linear(vector_model):
square_terms = -fisher_information / 2

return intercept, linear_terms, square_terms


def _is_second_order_model(model_info):
return model_info.has_squares or model_info.has_interactions
27 changes: 27 additions & 0 deletions src/estimagic/optimization/tranquilo/bounds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dataclasses import dataclass, replace

import numpy as np


@dataclass
timmens marked this conversation as resolved.
Show resolved Hide resolved
class Bounds:
"""Parameter bounds."""

lower: np.ndarray
upper: np.ndarray

def __post_init__(self):
self.has_any = _any_finite(self.lower, self.upper)

# make it behave like a NamedTuple
def _replace(self, **kwargs):
return replace(self, **kwargs)


def _any_finite(lb, ub):
out = False
if lb is not None and np.isfinite(lb).any():
out = True
if ub is not None and np.isfinite(ub).any():
out = True
return out
2 changes: 1 addition & 1 deletion src/estimagic/optimization/tranquilo/estimate_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from estimagic.optimization.tranquilo.get_component import get_component
from estimagic.optimization.tranquilo.new_history import History
from estimagic.optimization.tranquilo.options import Region
from estimagic.optimization.tranquilo.region import Region


def get_variance_estimator(fitter, user_options):
Expand Down
Loading