Skip to content

Commit

Permalink
just a few minor changes (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
hammannr authored Apr 26, 2024
1 parent 90cd3e0 commit ba877d0
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions alea/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def cost(args):

@_needs_data
def fit(
self, verbose=False, index_fitting=True, max_index_fitting_iter=10, **kwargs
self, verbose=False, disable_index_fitting=False, max_index_fitting_iter=10, **kwargs
) -> Tuple[dict, float]:
"""Fit the model to the data by maximizing the likelihood. Return a dict containing best-fit
values of each parameter, and the value of the likelihood evaluated there. While the
Expand All @@ -317,6 +317,9 @@ def fit(
Args:
verbose (bool): if True, print the Minuit object
disable_index_fitting (bool): if True, disable the index fitting
even if the model has index parameters.
max_index_fitting_iter (int): maximum number of iterations for index fitting
Returns:
dict, float: best-fit values of each parameter,
Expand All @@ -340,18 +343,16 @@ def fit(
for par in fixed_params:
m.fixed[par] = True

# Get the index variables, which could have problem if simply using migrad
index_variables = [
p
for p in self.parameters.parameters.values()
if p.ptype == "index" and p.name not in fixed_params
# Get the index parameters, which could have problem if simply using migrad
index_parameters = [
p for p in self.parameters if p.ptype == "index" and p.name not in fixed_params
]

if (not index_fitting) or (len(index_variables) == 0):
if disable_index_fitting or (len(index_parameters) == 0):
# Call migrad to do the actual minimization
m = self._migrad_fit(m)
else:
m = self._migrad_index_mixing_fit(m, index_variables, max_index_fitting_iter, verbose)
m = self._migrad_index_mixing_fit(m, index_parameters, max_index_fitting_iter, verbose)

self.minuit_object = m
if verbose:
Expand All @@ -363,36 +364,36 @@ def _migrad_fit(self, m):
m.migrad()
return m

def _migrad_index_mixing_fit(self, m, index_variables, max_index_fitting_iter, verbose):
index_anchors = [var.blueice_anchors for var in index_variables]
index_names = [var.name for var in index_variables]
def _migrad_index_mixing_fit(self, m, index_parameters, max_index_fitting_iter, verbose):
index_anchors = [p.blueice_anchors for p in index_parameters]
index_names = [p.name for p in index_parameters]
index_grid = [
{index_names[i]: anchor[i] for i in range(len(anchor))}
for anchor in product(*index_anchors)
]

# We fix the index variables in migrad
# We fix the index parameters in migrad
for par in index_names:
m.fixed[par] = True

# We firstly do optimization on other parameters with index variables
# We firstly do optimization on other parameters with index parameters
# fixed to their initial guesses. Then we grid search over the index
# variables given the optimized parameters. We repeat the optimization
# parameters given the optimized parameters. We repeat the optimization
# and grid search until the optimization converges
for itr in range(max_index_fitting_iter):
m.migrad()

# Find the best-fit index variables
# Find the best-fit index parameters
lls = np.zeros(len(index_grid))
for i in range(len(lls)):
params = m.values.to_dict()
params.update(index_grid[i])
lls[i] = self.ll(**params)
best_fit_params = m.values.to_dict()
best_fit_params.update(index_grid[i])
lls[i] = self.ll(**best_fit_params)
for var in index_names:
m.values[var] = index_grid[np.argmax(lls)][var]

# Calculating Hessian will update the validity of
# the fitting given the new index variables
# the fitting given the new index parameters
m.hesse()
if m.valid:
break
Expand Down

0 comments on commit ba877d0

Please sign in to comment.