Skip to content

Commit

Permalink
More detailed defaults for problem.get_full_vector (#1393)
Browse files Browse the repository at this point in the history
* Made get_full_vector more intuitive

* Corrected he usage of get_full vector. Removed it for lb and ub in favor of lb_full and ub_full.

* Apply suggestions from code review

Co-authored-by: Maren Philipps <55318391+m-philipps@users.noreply.github.com>

---------

Co-authored-by: Maren Philipps <55318391+m-philipps@users.noreply.github.com>
  • Loading branch information
PaulJonasJost and m-philipps committed May 14, 2024
1 parent 18be6e3 commit 9d5c13b
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 14 deletions.
14 changes: 10 additions & 4 deletions pypesto/problem/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@ def unfix_parameters(
self.normalize()

def get_full_vector(
self, x: Union[np.ndarray, None], x_fixed_vals: Iterable[float] = None
self,
x: Union[np.ndarray, None],
x_fixed_vals: Iterable[float] = None,
x_is_grad: bool = False,
) -> Union[np.ndarray, None]:
"""
Map vector from dim to dim_full. Usually used for x, grad.
Expand All @@ -342,9 +345,9 @@ def get_full_vector(
x: array_like, shape=(dim,)
The vector in dimension dim.
x_fixed_vals: array_like, ndim=1, optional
The values to be used for the fixed indices. If None, then nans are
inserted. Usually, None will be used for grad and
problem.x_fixed_vals for x.
The values to be used for the fixed indices. If None and x_is_grad=False, problem.x_fixed_vals is used; for x_is_grad=True, nans are inserted.
x_is_grad: bool
If true, x is treated as gradients.
"""
if x is None:
return None
Expand All @@ -362,6 +365,9 @@ def get_full_vector(
x_full[..., self.x_free_indices] = x
if x_fixed_vals is not None:
x_full[..., self.x_fixed_indices] = x_fixed_vals
return x_full
if not x_is_grad:
x_full[..., self.x_fixed_indices] = self.x_fixed_vals
return x_full

def get_full_matrix(
Expand Down
6 changes: 3 additions & 3 deletions pypesto/result/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ def update_to_full(self, problem: Problem) -> None:
problem which contains info about how to convert to full vectors
or matrices
"""
self.x = problem.get_full_vector(self.x, problem.x_fixed_vals)
self.grad = problem.get_full_vector(self.grad)
self.x = problem.get_full_vector(self.x)
self.grad = problem.get_full_vector(self.grad, x_is_grad=True)
self.hess = problem.get_full_matrix(self.hess)
self.x0 = problem.get_full_vector(self.x0, problem.x_fixed_vals)
self.x0 = problem.get_full_vector(self.x0)
self.free_indices = np.array(problem.x_free_indices)


Expand Down
4 changes: 2 additions & 2 deletions pypesto/visualize/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ def handle_inputs(
ub = result.problem.get_reduced_vector(ub, parameter_indices)
x_labels = [x_labels[int(i)] for i in parameter_indices]
else:
lb = result.problem.get_full_vector(lb)
ub = result.problem.get_full_vector(ub)
lb = result.problem.lb_full
ub = result.problem.ub_full

if inner_xs is not None and plot_inner_parameters:
lb = np.concatenate([lb, inner_lb])
Expand Down
4 changes: 1 addition & 3 deletions test/base/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,7 @@ def check_reconstruct_history(

def check_history_consistency(self, start: pypesto.OptimizerResult):
def xfull(x_trace):
return self.problem.get_full_vector(
x_trace, self.problem.x_fixed_vals
)
return self.problem.get_full_vector(x_trace)

if isinstance(start.history, (CsvHistory, Hdf5History)):
# get index of optimal parameter
Expand Down
3 changes: 1 addition & 2 deletions test/base/test_x_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def test_optimize():
# fixed values written into parameter vector
assert optimizer_result.x[1] == 1

lb_full = problem.get_full_vector(problem.lb)
assert len(lb_full) == 5
assert len(problem.lb_full) == 5


def create_problem():
Expand Down

0 comments on commit 9d5c13b

Please sign in to comment.