Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set/fix typehints in visualize/profiles.py / visualize/reference_poin… #1223

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 45 additions & 53 deletions pypesto/visualize/profiles.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Tuple, Union
from typing import Optional, Sequence, Union
from warnings import warn

import matplotlib.pyplot as plt
Expand All @@ -15,50 +15,50 @@ def profiles(
results: Union[Result, Sequence[Result]],
ax=None,
profile_indices: Sequence[int] = None,
size: Sequence[float] = (18.5, 6.5),
size: tuple[float, float] = (18.5, 6.5),
reference: Union[ReferencePoint, Sequence[ReferencePoint]] = None,
colors=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be colors: Sequence[tuple] or something similar?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"or something similar" 😅
I am not completely sure which types are supported.

legends: Sequence[str] = None,
x_labels: Sequence[str] = None,
profile_list_ids: Union[int, Sequence[int]] = 0,
ratio_min: float = 0.0,
show_bounds: bool = False,
):
) -> plt.Axes:
"""
Plot classical 1D profile plot.
Using the posterior, e.g. Gaussian like profile.
Parameters
----------
results: list or pypesto.Result
results:
List of or single `pypesto.Result` after profiling.
ax: list of matplotlib.Axes, optional
ax:
List of axes objects to use.
profile_indices: list of integer values
profile_indices:
List of integer values specifying which profiles should be plotted.
size: tuple, optional
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
reference: list, optional
reference:
List of reference points for optimization results, containing at
least a function value fval.
colors: list, or RGBA, optional
colors:
List of colors, or single color.
legends: list or str, optional
legends:
Labels for line plots, one label per result object.
x_labels: list of str
x_labels:
Labels for parameter value axes (e.g. parameter names).
profile_list_ids: int or list of ints, optional
Index or list of indices of the profile lists to be used for profiling.
profile_list_ids:
Index or list of indices of the profile lists to visualize.
ratio_min:
Minimum ratio below which to cut off.
show_bounds:
Whether to show, and extend the plot to, the lower and upper bounds.
Returns
-------
ax: matplotlib.Axes
ax:
The plot axes.
"""
# parse input
Expand Down Expand Up @@ -122,38 +122,33 @@ def profiles(


def profiles_lowlevel(
fvals,
ax=None,
size: Tuple[float, float] = (18.5, 6.5),
fvals: Union[float, Sequence[float]],
ax: Optional[Sequence[plt.Axes]] = None,
size: tuple[float, float] = (18.5, 6.5),
color=None,
legend_text: str = None,
x_labels=None,
show_bounds: bool = False,
lb_full=None,
ub_full=None,
):
lb_full: Sequence[float] = None,
ub_full: Sequence[float] = None,
) -> list[plt.Axes]:
"""
Lowlevel routine for profile plotting.
Working with a list of arrays only, opening different axes objects in case.
Parameters
----------
fvals: numeric list or array
fvals:
Values to plot.
ax: list of matplotlib.Axes, optional
ax:
List of axes object to use.
size: tuple, optional
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
size: tuple, optional
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
color: RGBA, optional
Color for profiles in plot.
legend_text: str
Label for line plots.
legend_text: List[str]
legend_text:
Label for line plots.
show_bounds:
Whether to show, and extend the plot to, the lower and upper bounds.
Expand All @@ -164,8 +159,7 @@ def profiles_lowlevel(
Returns
-------
ax: matplotlib.Axes
The plot axes.
The plot axes.
"""
# axes
if ax is None:
Expand All @@ -179,7 +173,7 @@ def profiles_lowlevel(
create_new_ax = False

# count number of necessary axes
if isinstance(fvals, list):
if isinstance(fvals, Sequence):
n_fvals = len(fvals)
else:
n_fvals = 1
Expand Down Expand Up @@ -269,30 +263,30 @@ def profiles_lowlevel(


def profile_lowlevel(
fvals,
ax=None,
size: Tuple[float, float] = (18.5, 6.5),
fvals: Sequence[float],
ax: Optional[plt.Axes] = None,
size: tuple[float, float] = (18.5, 6.5),
color=None,
legend_text: str = None,
show_bounds: bool = False,
lb: float = None,
ub: float = None,
):
) -> plt.Axes:
"""
Lowlevel routine for plotting one profile, working with a numpy array only.
Parameters
----------
fvals: numeric list or array
fvals:
Values to plot.
ax: matplotlib.Axes, optional
ax:
Axes object to use.
size: tuple, optional
size:
Figure size (width, height) in inches. Is only applied when no ax
object is specified.
color: RGBA, optional
Color for profiles in plot.
legend_text: str
legend_text:
Label for line plots.
show_bounds:
Whether to show, and extend the plot to, the lower and upper bounds.
Expand All @@ -303,8 +297,7 @@ def profile_lowlevel(
Returns
-------
ax: matplotlib.Axes
The plot axes.
The plot axes.
"""
# parse input
fvals = np.asarray(fvals)
Expand Down Expand Up @@ -372,28 +365,27 @@ def handle_inputs(
profile_indices: Sequence[int],
profile_list: int,
ratio_min: float,
):
) -> list[np.array]:
"""
Retrieve the values of the profiles to be plotted.
Parameters
----------
result: pypesto.Result
result:
Profile result obtained by 'profile.py'.
profile_indices: list of integer values
List of integer values specifying which profiles should be plotted.
profile_list: int, optional
profile_indices:
Sequence of integer values specifying which profiles should be plotted.
profile_list:
Index of the profile list to be used for profiling.
ratio_min: int, optional
ratio_min:
Exclude values where profile likelihood ratio is smaller than
ratio_min.
Returns
-------
fvals: numeric list
Including values that need to be plotted.
List of parameter values and ratios that need to be plotted.
"""
# extract ratio values values from result
# extract ratio values from result
fvals = []
for i_par in range(0, len(result.profile_result.list[profile_list])):
if (
Expand Down Expand Up @@ -437,8 +429,7 @@ def process_result_list_profiles(
List of or single `pypesto.Result` after profiling.
profile_list_ids: int or list of ints, optional
Index or list of indices of the profile lists to be used for profiling.
colors: list of RGBA colors
colors for
colors: list of RGBA colors for plotting.
legends: list of str
Legends for plotting
Expand Down Expand Up @@ -506,6 +497,7 @@ def process_profile_indices(
else:
for ind in profile_indices:
if ind not in plottable_indices:
profile_indices = list(profile_indices)
profile_indices.remove(ind)
warn(
'Requested to plot profile for parameter index %i, '
Expand Down
20 changes: 12 additions & 8 deletions pypesto/visualize/reference_points.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import Optional, Sequence, Union

import numpy as np

Expand All @@ -14,7 +14,7 @@ class ReferencePoint(dict):
Attributes
----------
x: ndarray
x:
Reference parameters.
fval: float
Function value, fun(x), for reference parameters.
Expand All @@ -28,7 +28,12 @@ class ReferencePoint(dict):
"""

def __init__(
self, reference=None, x=None, fval=None, color=None, legend=None
self,
reference: Union[None, dict, tuple, "ReferencePoint"] = None,
x: Optional[Sequence] = None,
fval: Optional[float] = None,
color=None,
legend: Optional[str] = None,
):
super().__init__()

Expand Down Expand Up @@ -104,19 +109,18 @@ def __getattr__(self, key):
__delattr__ = dict.__delitem__


def assign_colors(ref):
def assign_colors(ref: Sequence[ReferencePoint]) -> Sequence[ReferencePoint]:
"""
Assign colors to reference points, depending on user settings.
Parameters
----------
ref: list of ReferencePoint
ref:
Reference points, which need to get their color property filled
Returns
-------
ref: list of ReferencePoint
Reference points, which got their color property filled
Reference points, which got their color property filled
"""
# loop over reference points
auto_color_count = 0
Expand All @@ -141,7 +145,7 @@ def assign_colors(ref):

def create_references(
references=None, x=None, fval=None, color=None, legend=None
) -> List[ReferencePoint]:
) -> list[ReferencePoint]:
"""
Create a list of reference point objects from user inputs.
Expand Down