Skip to content

Commit

Permalink
Set/fix typehints in visualize/profiles.py / visualize/reference_poin…
Browse files Browse the repository at this point in the history
…ts.py (#1223)

And remove redundant / contradictory types from docstrings
  • Loading branch information
dweindl committed Nov 30, 2023
1 parent 9fed090 commit 6ec4fc9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 61 deletions.
98 changes: 45 additions & 53 deletions pypesto/visualize/profiles.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Tuple, Union
from typing import Optional, Sequence, Union
from warnings import warn

import matplotlib.pyplot as plt
Expand All @@ -15,50 +15,50 @@ def profiles(
results: Union[Result, Sequence[Result]],
ax=None,
profile_indices: Sequence[int] = None,
size: Sequence[float] = (18.5, 6.5),
size: tuple[float, float] = (18.5, 6.5),
reference: Union[ReferencePoint, Sequence[ReferencePoint]] = None,
colors=None,
legends: Sequence[str] = None,
x_labels: Sequence[str] = None,
profile_list_ids: Union[int, Sequence[int]] = 0,
ratio_min: float = 0.0,
show_bounds: bool = False,
):
) -> plt.Axes:
"""
Plot classical 1D profile plot.
Using the posterior, e.g. Gaussian like profile.
Parameters
----------
results: list or pypesto.Result
results:
List of or single `pypesto.Result` after profiling.
ax: list of matplotlib.Axes, optional
ax:
List of axes objects to use.
profile_indices: list of integer values
profile_indices:
List of integer values specifying which profiles should be plotted.
size: tuple, optional
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
reference: list, optional
reference:
List of reference points for optimization results, containing at
least a function value fval.
colors: list, or RGBA, optional
colors:
List of colors, or single color.
legends: list or str, optional
legends:
Labels for line plots, one label per result object.
x_labels: list of str
x_labels:
Labels for parameter value axes (e.g. parameter names).
profile_list_ids: int or list of ints, optional
Index or list of indices of the profile lists to be used for profiling.
profile_list_ids:
Index or list of indices of the profile lists to visualize.
ratio_min:
Minimum ratio below which to cut off.
show_bounds:
Whether to show, and extend the plot to, the lower and upper bounds.
Returns
-------
ax: matplotlib.Axes
ax:
The plot axes.
"""
# parse input
Expand Down Expand Up @@ -122,38 +122,33 @@ def profiles(


def profiles_lowlevel(
fvals,
ax=None,
size: Tuple[float, float] = (18.5, 6.5),
fvals: Union[float, Sequence[float]],
ax: Optional[Sequence[plt.Axes]] = None,
size: tuple[float, float] = (18.5, 6.5),
color=None,
legend_text: str = None,
x_labels=None,
show_bounds: bool = False,
lb_full=None,
ub_full=None,
):
lb_full: Sequence[float] = None,
ub_full: Sequence[float] = None,
) -> list[plt.Axes]:
"""
Lowlevel routine for profile plotting.
Working with a list of arrays only, opening different axes objects in case.
Parameters
----------
fvals: numeric list or array
fvals:
Values to plot.
ax: list of matplotlib.Axes, optional
ax:
List of axes object to use.
size: tuple, optional
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
size: tuple, optional
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
color: RGBA, optional
Color for profiles in plot.
legend_text: str
Label for line plots.
legend_text: List[str]
legend_text:
Label for line plots.
show_bounds:
Whether to show, and extend the plot to, the lower and upper bounds.
Expand All @@ -164,8 +159,7 @@ def profiles_lowlevel(
Returns
-------
ax: matplotlib.Axes
The plot axes.
The plot axes.
"""
# axes
if ax is None:
Expand All @@ -179,7 +173,7 @@ def profiles_lowlevel(
create_new_ax = False

# count number of necessary axes
if isinstance(fvals, list):
if isinstance(fvals, Sequence):
n_fvals = len(fvals)
else:
n_fvals = 1
Expand Down Expand Up @@ -269,30 +263,30 @@ def profiles_lowlevel(


def profile_lowlevel(
fvals,
ax=None,
size: Tuple[float, float] = (18.5, 6.5),
fvals: Sequence[float],
ax: Optional[plt.Axes] = None,
size: tuple[float, float] = (18.5, 6.5),
color=None,
legend_text: str = None,
show_bounds: bool = False,
lb: float = None,
ub: float = None,
):
) -> plt.Axes:
"""
Lowlevel routine for plotting one profile, working with a numpy array only.
Parameters
----------
fvals: numeric list or array
fvals:
Values to plot.
ax: matplotlib.Axes, optional
ax:
Axes object to use.
size: tuple, optional
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
color: RGBA, optional
Color for profiles in plot.
legend_text: str
legend_text:
Label for line plots.
show_bounds:
Whether to show, and extend the plot to, the lower and upper bounds.
Expand All @@ -303,8 +297,7 @@ def profile_lowlevel(
Returns
-------
ax: matplotlib.Axes
The plot axes.
The plot axes.
"""
# parse input
fvals = np.asarray(fvals)
Expand Down Expand Up @@ -372,28 +365,27 @@ def handle_inputs(
profile_indices: Sequence[int],
profile_list: int,
ratio_min: float,
):
) -> list[np.array]:
"""
Retrieve the values of the profiles to be plotted.
Parameters
----------
result: pypesto.Result
result:
Profile result obtained by 'profile.py'.
profile_indices: list of integer values
List of integer values specifying which profiles should be plotted.
profile_list: int, optional
profile_indices:
Sequence of integer values specifying which profiles should be plotted.
profile_list:
Index of the profile list to be used for profiling.
ratio_min: int, optional
ratio_min:
Exclude values where profile likelihood ratio is smaller than
ratio_min.
Returns
-------
fvals: numeric list
Including values that need to be plotted.
List of parameter values and ratios that need to be plotted.
"""
# extract ratio values values from result
# extract ratio values from result
fvals = []
for i_par in range(0, len(result.profile_result.list[profile_list])):
if (
Expand Down Expand Up @@ -437,8 +429,7 @@ def process_result_list_profiles(
List of or single `pypesto.Result` after profiling.
profile_list_ids: int or list of ints, optional
Index or list of indices of the profile lists to be used for profiling.
colors: list of RGBA colors
colors for
colors: list of RGBA colors for plotting.
legends: list of str
Legends for plotting
Expand Down Expand Up @@ -506,6 +497,7 @@ def process_profile_indices(
else:
for ind in profile_indices:
if ind not in plottable_indices:
profile_indices = list(profile_indices)
profile_indices.remove(ind)
warn(
'Requested to plot profile for parameter index %i, '
Expand Down
20 changes: 12 additions & 8 deletions pypesto/visualize/reference_points.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Optional, Sequence, Union

import numpy as np

Expand All @@ -14,7 +14,7 @@ class ReferencePoint(dict):
Attributes
----------
x: ndarray
x:
Reference parameters.
fval: float
Function value, fun(x), for reference parameters.
Expand All @@ -28,7 +28,12 @@ class ReferencePoint(dict):
"""

def __init__(
self, reference=None, x=None, fval=None, color=None, legend=None
self,
reference: Union[None, dict, tuple, "ReferencePoint"] = None,
x: Optional[Sequence] = None,
fval: Optional[float] = None,
color=None,
legend: Optional[str] = None,
):
super().__init__()

Expand Down Expand Up @@ -104,19 +109,18 @@ def __getattr__(self, key):
__delattr__ = dict.__delitem__


def assign_colors(ref):
def assign_colors(ref: Sequence[ReferencePoint]) -> Sequence[ReferencePoint]:
"""
Assign colors to reference points, depending on user settings.
Parameters
----------
ref: list of ReferencePoint
ref:
Reference points, which need to get their color property filled
Returns
-------
ref: list of ReferencePoint
Reference points, which got their color property filled
Reference points, which got their color property filled
"""
# loop over reference points
auto_color_count = 0
Expand All @@ -141,7 +145,7 @@ def assign_colors(ref):

def create_references(
references=None, x=None, fval=None, color=None, legend=None
) -> List[ReferencePoint]:
) -> list[ReferencePoint]:
"""
Create a list of reference point objects from user inputs.
Expand Down

0 comments on commit 6ec4fc9

Please sign in to comment.