Skip to content

Commit

Permalink
fix knot initialization (#1313)
Browse files Browse the repository at this point in the history
  • Loading branch information
Doresic committed Feb 29, 2024
1 parent 0c9ef74 commit 312f43a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 25 deletions.
5 changes: 1 addition & 4 deletions pypesto/hierarchical/semiquantitative/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,11 +1086,8 @@ def save_inner_parameters_to_inner_problem(
group
)

lower_trian = np.tril(np.ones((len(s), len(s))))
xi = np.dot(lower_trian, s)

for idx in range(len(inner_spline_parameters)):
inner_spline_parameters[idx].value = xi[idx]
inner_spline_parameters[idx].value = s[idx]

sigma = group_dict[INNER_NOISE_PARS]

Expand Down
36 changes: 15 additions & 21 deletions pypesto/visualize/spline_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,12 @@ def plot_splines_from_inner_result(
group_idx = list(inner_problem.groups.keys()).index(group)

# For each group get the inner parameters and simulation
xs = inner_problem.get_xs_for_group(group)

s = result[SCIPY_X]

inner_parameters = np.array([x.value for x in xs])
# Utility matrix for the spline knot calculation
lower_trian = np.tril(np.ones((len(s), len(s))))
spline_knots = np.dot(lower_trian, s)

measurements = inner_problem.groups[group][DATAPOINTS]
simulation = inner_problem.groups[group][CURRENT_SIMULATION]

Expand All @@ -201,30 +202,30 @@ def plot_splines_from_inner_result(
) = SemiquantInnerSolver._rescale_spline_bases(
self=None,
sim_all=simulation,
N=len(inner_parameters),
N=len(spline_knots),
K=len(simulation),
)
mapped_simulations = get_spline_mapped_simulations(
s, simulation, len(inner_parameters), delta_c, spline_bases, n
s, simulation, len(spline_knots), delta_c, spline_bases, n
)

axs[group_idx].plot(
simulation, measurements, 'bs', label='Measurements'
)
axs[group_idx].plot(
spline_bases, inner_parameters, 'g.', label='Spline knots'
spline_bases, spline_knots, 'g.', label='Spline knots'
)
axs[group_idx].plot(
spline_bases,
inner_parameters,
spline_knots,
linestyle='-',
color='g',
label='Spline function',
)
if inner_solver.options[REGULARIZE_SPLINE]:
alpha_opt, beta_opt = _calculate_optimal_regularization(
s=s,
N=len(inner_parameters),
N=len(spline_knots),
c=spline_bases,
)
axs[group_idx].plot(
Expand Down Expand Up @@ -392,10 +393,7 @@ def _add_spline_mapped_simulations_to_model_fit(
][0]

# Get the inner parameters and simulation.
xs = inner_problem.get_xs_for_group(group)
s = inner_result[SCIPY_X]

inner_parameters = np.array([x.value for x in xs])
simulation = inner_problem.groups[group][CURRENT_SIMULATION]

# For the simulation, get the spline bases
Expand All @@ -406,12 +404,12 @@ def _add_spline_mapped_simulations_to_model_fit(
) = SemiquantInnerSolver._rescale_spline_bases(
self=None,
sim_all=simulation,
N=len(inner_parameters),
N=len(s),
K=len(simulation),
)
# and the spline-mapped simulations.
mapped_simulations = get_spline_mapped_simulations(
s, simulation, len(inner_parameters), delta_c, spline_bases, n
s, simulation, len(s), delta_c, spline_bases, n
)

# Plot the spline-mapped simulations to the ax with same color
Expand Down Expand Up @@ -536,29 +534,25 @@ def _obtain_regularization_for_start(
# for each result and group, plot the inner solution
for result, group in zip(inner_results, inner_problem.groups):
# For each group get the inner parameters and simulation
xs = inner_problem.get_xs_for_group(group)

s = result[SCIPY_X]

inner_parameters = np.array([x.value for x in xs])
simulation = inner_problem.groups[group][CURRENT_SIMULATION]

# For the simulation, get the spline bases
(
delta_c,
_,
spline_bases,
n,
_,
) = SemiquantInnerSolver._rescale_spline_bases(
self=None,
sim_all=simulation,
N=len(inner_parameters),
N=len(s),
K=len(simulation),
)

if inner_solver.options[REGULARIZE_SPLINE]:
reg_term = _calculate_regularization_for_group(
s=s,
N=len(inner_parameters),
N=len(s),
c=spline_bases,
regularization_factor=inner_solver.options[
'regularization_factor'
Expand Down

0 comments on commit 312f43a

Please sign in to comment.