Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More detailed defaults for problem.get_full_vector #1393

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions pypesto/problem/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@ def unfix_parameters(
self.normalize()

def get_full_vector(
self, x: Union[np.ndarray, None], x_fixed_vals: Iterable[float] = None
self,
x: Union[np.ndarray, None],
x_fixed_vals: Iterable[float] = None,
x_is_grad: bool = False,
) -> Union[np.ndarray, None]:
"""
Map vector from dim to dim_full. Usually used for x, grad.
Expand All @@ -342,9 +345,9 @@ def get_full_vector(
x: array_like, shape=(dim,)
The vector in dimension dim.
x_fixed_vals: array_like, ndim=1, optional
The values to be used for the fixed indices. If None, then nans are
inserted. Usually, None will be used for grad and
problem.x_fixed_vals for x.
The values to be used for the fixed indices. If None and x_is_grad=False, problem.x_fixed_vals is used; for x_is_grad=True, nans are inserted.
x_is_grad: bool
If true, x is treated as gradients.
"""
if x is None:
return None
Expand All @@ -362,6 +365,9 @@ def get_full_vector(
x_full[..., self.x_free_indices] = x
if x_fixed_vals is not None:
x_full[..., self.x_fixed_indices] = x_fixed_vals
return x_full
if not x_is_grad:
x_full[..., self.x_fixed_indices] = self.x_fixed_vals
return x_full

def get_full_matrix(
Expand Down
6 changes: 3 additions & 3 deletions pypesto/result/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ def update_to_full(self, problem: Problem) -> None:
problem which contains info about how to convert to full vectors
or matrices
"""
self.x = problem.get_full_vector(self.x, problem.x_fixed_vals)
self.grad = problem.get_full_vector(self.grad)
self.x = problem.get_full_vector(self.x)
self.grad = problem.get_full_vector(self.grad, x_is_grad=True)
self.hess = problem.get_full_matrix(self.hess)
self.x0 = problem.get_full_vector(self.x0, problem.x_fixed_vals)
self.x0 = problem.get_full_vector(self.x0)
self.free_indices = np.array(problem.x_free_indices)


Expand Down
4 changes: 2 additions & 2 deletions pypesto/visualize/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ def handle_inputs(
ub = result.problem.get_reduced_vector(ub, parameter_indices)
x_labels = [x_labels[int(i)] for i in parameter_indices]
else:
lb = result.problem.get_full_vector(lb)
ub = result.problem.get_full_vector(ub)
lb = result.problem.lb_full
ub = result.problem.ub_full

if inner_xs is not None and plot_inner_parameters:
lb = np.concatenate([lb, inner_lb])
Expand Down
4 changes: 1 addition & 3 deletions test/base/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ def check_reconstruct_history(

def check_history_consistency(self, start: pypesto.OptimizerResult):
def xfull(x_trace):
return self.problem.get_full_vector(
x_trace, self.problem.x_fixed_vals
)
return self.problem.get_full_vector(x_trace)

if isinstance(start.history, (CsvHistory, Hdf5History)):
# get index of optimal parameter
Expand Down
3 changes: 1 addition & 2 deletions test/base/test_x_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def test_optimize():
# fixed values written into parameter vector
assert optimizer_result.x[1] == 1

lb_full = problem.get_full_vector(problem.lb)
assert len(lb_full) == 5
assert len(problem.lb_full) == 5


def create_problem():
Expand Down
Loading