From 6ec4fc912e77a7cdd5f80112d91bbd5680ebebd3 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 30 Nov 2023 15:35:41 +0100 Subject: [PATCH] Set/fix typehints in visualize/profiles.py / visualize/reference_points.py (#1223) And remove redundant / contradictory types from docstrings --- pypesto/visualize/profiles.py | 98 ++++++++++++--------------- pypesto/visualize/reference_points.py | 20 +++--- 2 files changed, 57 insertions(+), 61 deletions(-) diff --git a/pypesto/visualize/profiles.py b/pypesto/visualize/profiles.py index de12743a5..f2077ceef 100644 --- a/pypesto/visualize/profiles.py +++ b/pypesto/visualize/profiles.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple, Union +from typing import Optional, Sequence, Union from warnings import warn import matplotlib.pyplot as plt @@ -15,7 +15,7 @@ def profiles( results: Union[Result, Sequence[Result]], ax=None, profile_indices: Sequence[int] = None, - size: Sequence[float] = (18.5, 6.5), + size: tuple[float, float] = (18.5, 6.5), reference: Union[ReferencePoint, Sequence[ReferencePoint]] = None, colors=None, legends: Sequence[str] = None, @@ -23,7 +23,7 @@ def profiles( profile_list_ids: Union[int, Sequence[int]] = 0, ratio_min: float = 0.0, show_bounds: bool = False, -): +) -> plt.Axes: """ Plot classical 1D profile plot. @@ -31,26 +31,26 @@ def profiles( Parameters ---------- - results: list or pypesto.Result + results: List of or single `pypesto.Result` after profiling. - ax: list of matplotlib.Axes, optional + ax: List of axes objects to use. - profile_indices: list of integer values + profile_indices: List of integer values specifying which profiles should be plotted. - size: tuple, optional + size: Figure size (width, height) in inches. Is only applied when no ax object is specified. - reference: list, optional + reference: List of reference points for optimization results, containing at least a function value fval. - colors: list, or RGBA, optional + colors: List of colors, or single color. - legends: list or str, optional + legends: Labels for line plots, one label per result object. - x_labels: list of str + x_labels: Labels for parameter value axes (e.g. parameter names). - profile_list_ids: int or list of ints, optional - Index or list of indices of the profile lists to be used for profiling. + profile_list_ids: + Index or list of indices of the profile lists to visualize. ratio_min: Minimum ratio below which to cut off. show_bounds: @@ -58,7 +58,7 @@ def profiles( Returns ------- - ax: matplotlib.Axes + ax: The plot axes. """ # parse input @@ -122,16 +122,16 @@ def profiles( def profiles_lowlevel( - fvals, - ax=None, - size: Tuple[float, float] = (18.5, 6.5), + fvals: Union[float, Sequence[float]], + ax: Optional[Sequence[plt.Axes]] = None, + size: tuple[float, float] = (18.5, 6.5), color=None, legend_text: str = None, x_labels=None, show_bounds: bool = False, - lb_full=None, - ub_full=None, -): + lb_full: Sequence[float] = None, + ub_full: Sequence[float] = None, +) -> list[plt.Axes]: """ Lowlevel routine for profile plotting. @@ -139,21 +139,16 @@ def profiles_lowlevel( Parameters ---------- - fvals: numeric list or array + fvals: Values to plot. - ax: list of matplotlib.Axes, optional + ax: List of axes object to use. - size: tuple, optional - Figure size (width, height) in inches. Is only applied when no ax - object is specified. - size: tuple, optional + size: Figure size (width, height) in inches. Is only applied when no ax object is specified. color: RGBA, optional Color for profiles in plot. - legend_text: str - Label for line plots. - legend_text: List[str] + legend_text: Label for line plots. show_bounds: Whether to show, and extend the plot to, the lower and upper bounds. @@ -164,8 +159,7 @@ def profiles_lowlevel( Returns ------- - ax: matplotlib.Axes - The plot axes. + The plot axes. """ # axes if ax is None: @@ -179,7 +173,7 @@ def profiles_lowlevel( create_new_ax = False # count number of necessary axes - if isinstance(fvals, list): + if isinstance(fvals, Sequence): n_fvals = len(fvals) else: n_fvals = 1 @@ -269,30 +263,30 @@ def profiles_lowlevel( def profile_lowlevel( - fvals, - ax=None, - size: Tuple[float, float] = (18.5, 6.5), + fvals: Sequence[float], + ax: Optional[plt.Axes] = None, + size: tuple[float, float] = (18.5, 6.5), color=None, legend_text: str = None, show_bounds: bool = False, lb: float = None, ub: float = None, -): +) -> plt.Axes: """ Lowlevel routine for plotting one profile, working with a numpy array only. Parameters ---------- - fvals: numeric list or array + fvals: Values to plot. - ax: matplotlib.Axes, optional + ax: Axes object to use. - size: tuple, optional + size: Figure size (width, height) in inches. Is only applied when no ax object is specified. color: RGBA, optional Color for profiles in plot. - legend_text: str + legend_text: Label for line plots. show_bounds: Whether to show, and extend the plot to, the lower and upper bounds. @@ -303,8 +297,7 @@ def profile_lowlevel( Returns ------- - ax: matplotlib.Axes - The plot axes. + The plot axes. """ # parse input fvals = np.asarray(fvals) @@ -372,28 +365,27 @@ def handle_inputs( profile_indices: Sequence[int], profile_list: int, ratio_min: float, -): +) -> list[np.array]: """ Retrieve the values of the profiles to be plotted. Parameters ---------- - result: pypesto.Result + result: Profile result obtained by 'profile.py'. - profile_indices: list of integer values - List of integer values specifying which profiles should be plotted. - profile_list: int, optional + profile_indices: + Sequence of integer values specifying which profiles should be plotted. + profile_list: Index of the profile list to be used for profiling. - ratio_min: int, optional + ratio_min: Exclude values where profile likelihood ratio is smaller than ratio_min. Returns ------- - fvals: numeric list - Including values that need to be plotted. + List of parameter values and ratios that need to be plotted. """ - # extract ratio values values from result + # extract ratio values from result fvals = [] for i_par in range(0, len(result.profile_result.list[profile_list])): if ( @@ -437,8 +429,7 @@ def process_result_list_profiles( List of or single `pypesto.Result` after profiling. profile_list_ids: int or list of ints, optional Index or list of indices of the profile lists to be used for profiling. - colors: list of RGBA colors - colors for + colors: list of RGBA colors for plotting. legends: list of str Legends for plotting @@ -506,6 +497,7 @@ def process_profile_indices( else: for ind in profile_indices: if ind not in plottable_indices: + profile_indices = list(profile_indices) profile_indices.remove(ind) warn( 'Requested to plot profile for parameter index %i, ' diff --git a/pypesto/visualize/reference_points.py b/pypesto/visualize/reference_points.py index 13d8a03d9..c7aae3e5c 100644 --- a/pypesto/visualize/reference_points.py +++ b/pypesto/visualize/reference_points.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Optional, Sequence, Union import numpy as np @@ -14,7 +14,7 @@ class ReferencePoint(dict): Attributes ---------- - x: ndarray + x: Reference parameters. fval: float Function value, fun(x), for reference parameters. @@ -28,7 +28,12 @@ class ReferencePoint(dict): """ def __init__( - self, reference=None, x=None, fval=None, color=None, legend=None + self, + reference: Union[None, dict, tuple, "ReferencePoint"] = None, + x: Optional[Sequence] = None, + fval: Optional[float] = None, + color=None, + legend: Optional[str] = None, ): super().__init__() @@ -104,19 +109,18 @@ def __getattr__(self, key): __delattr__ = dict.__delitem__ -def assign_colors(ref): +def assign_colors(ref: Sequence[ReferencePoint]) -> Sequence[ReferencePoint]: """ Assign colors to reference points, depending on user settings. Parameters ---------- - ref: list of ReferencePoint + ref: Reference points, which need to get their color property filled Returns ------- - ref: list of ReferencePoint - Reference points, which got their color property filled + Reference points, which got their color property filled """ # loop over reference points auto_color_count = 0 @@ -141,7 +145,7 @@ def assign_colors(ref): def create_references( references=None, x=None, fval=None, color=None, legend=None -) -> List[ReferencePoint]: +) -> list[ReferencePoint]: """ Create a list of reference point objects from user inputs.