From 78452c6b2746439d28ea7dd17f889e3ccebb3a57 Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Thu, 28 Mar 2024 13:50:25 +0100 Subject: [PATCH 1/6] add checks for time vars --- annubes/task.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/annubes/task.py b/annubes/task.py index 23a0122..6676427 100644 --- a/annubes/task.py +++ b/annubes/task.py @@ -94,6 +94,9 @@ class Task(TaskSettingsMixin): max_sequential: int | None = None def __post_init__(self): + # Check time variables + self._check_time_vars() + self._task_settings = vars(self).copy() if not 0 <= self.catch_prob <= 1: @@ -107,13 +110,6 @@ def __post_init__(self): if not self.shuffle_trials: self._session = OrderedDict(self._session) - if not self.dt > 0: - msg = "`dt` must be greater than 0." - raise ValueError(msg) - if not self.tau > 0: - msg = "`tau` must be greater than 0." - raise ValueError(msg) - # Derived and other attributes self._modalities = set(dict.fromkeys(char for string in self._session for char in string)) self._n_inputs = len(self._modalities) + 1 # includes start cue @@ -265,6 +261,33 @@ def plot_trials(self, n_plots: int = 1) -> go.Figure: fig.update_layout(height=1300, width=900, title_text="Trials") return fig + def _check_int_positive(self, name: str, value: Any, strict: bool) -> None: # noqa: ANN401 + if not isinstance(value, int): + msg = f"`{name}` must be an integer." + raise TypeError(msg) + if strict: + if not value > 0: + msg = f"`{name}` must be greater than 0." + raise ValueError(msg) + elif not value >= 0: + msg = f"`{name}` must be greater than or equal to 0." + raise ValueError(msg) + + def _check_time_vars(self) -> None: + strictly_positive = { + "stim_time": (self.stim_time, True), + "dt": (self.dt, True), + "tau": (self.tau, True), + "fix_time": (self.fix_time, False), + } + for name, value in strictly_positive.items(): + self._check_int_positive(name, value[0], value[1]) + if isinstance(self.iti, tuple): + for iti in self.iti: + self._check_int_positive("iti", iti, False) + else: + self._check_int_positive("iti", self.iti, False) + def _build_trials_seq(self) -> NDArray[np.str_]: """Generate a sequence of modalities.""" # Extract keys and probabilities from the dictionary From 20088d0ecc688a030df5b62ad2e2736b7bb2bf22 Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Thu, 28 Mar 2024 16:29:23 +0100 Subject: [PATCH 2/6] add checks for all remaining input parameters, for all public methods --- annubes/task.py | 78 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/annubes/task.py b/annubes/task.py index 6676427..27fbb8f 100644 --- a/annubes/task.py +++ b/annubes/task.py @@ -94,15 +94,30 @@ class Task(TaskSettingsMixin): max_sequential: int | None = None def __post_init__(self): - # Check time variables + # Check input parameters + ## Check str + self._check_str("name", self.name) + ## Check dict + self._check_session("session", self.session) + ## Check float + for intensity in self.stim_intensities: + self._check_float_positive("stim_intensities", intensity) + self._check_float_positive("catch_prob", self.catch_prob, prob=True) + self._check_float_positive("fix_intensity", self.fix_intensity) + for intensity in self.output_behavior: + self._check_float_positive("output_behavior", intensity) + self._check_float_positive("noise_std", self.noise_std) + ## Check int self._check_time_vars() + if self.max_sequential is not None: + self._check_int_positive("max_sequential", self.max_sequential, strict=True) + self._check_int_positive("n_outputs", self.n_outputs, strict=True) + ## Check bool + self._check_bool("shuffle_trials", self.shuffle_trials) + self._check_bool("scaling", self.scaling) self._task_settings = vars(self).copy() - if not 0 <= self.catch_prob <= 1: - msg = "`catch_prob` must be between 0 and 1." - raise ValueError(msg) - sum_session_vals = sum(self.session.values()) self._session = {} for i in self.session: @@ -134,6 +149,18 @@ def generate_trials( `generate_trials()` method's call ("ntrials", "random_state"), and the generated data ("modality_seq", "time", "phases", "inputs", "outputs"). """ + # Check input parameters + if isinstance(ntrials, tuple): + if len(ntrials) != 2: # noqa: PLR2004 + msg = "`ntrials` must be an integer or a tuple of two integers." + raise ValueError(msg) + self._check_int_positive("ntrials", ntrials[0], strict=True) + self._check_int_positive("ntrials", ntrials[1], strict=True) + else: + self._check_int_positive("ntrials", ntrials, strict=True) + if random_seed is not None: + self._check_int_positive("random_seed", random_seed, strict=False) + # Set random state if random_seed is None: rng = np.random.default_rng(random_seed) @@ -174,6 +201,9 @@ def plot_trials(self, n_plots: int = 1) -> go.Figure: Returns: go.Figure: Plotly figure of trial results. """ + # Check input parameters + self._check_int_positive("n_plots", n_plots, strict=True) + if (p := n_plots) > (t := self._ntrials): msg = f"Number of plots requested ({p}) exceeds number of trials ({t}). Will plot all trials." warnings.warn(msg, stacklevel=2) @@ -261,6 +291,33 @@ def plot_trials(self, n_plots: int = 1) -> go.Figure: fig.update_layout(height=1300, width=900, title_text="Trials") return fig + def _check_str(self, name: str, value: Any) -> None: # noqa: ANN401 + if not isinstance(value, str): + msg = f"`{name}` must be a string" + raise TypeError(msg) + + def _check_session(self, name: str, value: Any) -> None: # noqa: ANN401 + if not isinstance(value, dict): + msg = f"`{name}` must be a dictionary." + raise TypeError(msg) + if not all(isinstance(k, str) for k in value): + msg = f"Keys of `{name}` must be strings." + raise TypeError(msg) + if not all(isinstance(v, (float | int)) for v in value.values()): + msg = f"Values of `{name}` must be floats or integers." + raise TypeError(msg) + + def _check_float_positive(self, name: str, value: Any, prob: bool = False) -> None: # noqa: ANN401 + if not isinstance(value, float | int): + msg = f"`{name}` must be a float or integer." + raise TypeError(msg) + if not value >= 0: + msg = f"`{name}` must be greater than or equal to 0." + raise ValueError(msg) + if prob and not 0 <= value <= 1: + msg = f"`{name}` must be between 0 and 1." + raise ValueError(msg) + def _check_int_positive(self, name: str, value: Any, strict: bool) -> None: # noqa: ANN401 if not isinstance(value, int): msg = f"`{name}` must be an integer." @@ -273,6 +330,11 @@ def _check_int_positive(self, name: str, value: Any, strict: bool) -> None: # n msg = f"`{name}` must be greater than or equal to 0." raise ValueError(msg) + def _check_bool(self, name: str, value: Any) -> None: # noqa: ANN401 + if not isinstance(value, bool): + msg = f"`{name}` must be a boolean." + raise TypeError(msg) + def _check_time_vars(self) -> None: strictly_positive = { "stim_time": (self.stim_time, True), @@ -281,12 +343,12 @@ def _check_time_vars(self) -> None: "fix_time": (self.fix_time, False), } for name, value in strictly_positive.items(): - self._check_int_positive(name, value[0], value[1]) + self._check_int_positive(name, value[0], strict=value[1]) if isinstance(self.iti, tuple): for iti in self.iti: - self._check_int_positive("iti", iti, False) + self._check_int_positive("iti", iti, strict=False) else: - self._check_int_positive("iti", self.iti, False) + self._check_int_positive("iti", self.iti, strict=False) def _build_trials_seq(self) -> NDArray[np.str_]: """Generate a sequence of modalities.""" From fb8436a06195cd33e158e6f16bfb2c58cf6ef7b7 Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Thu, 28 Mar 2024 16:35:00 +0100 Subject: [PATCH 3/6] fix tests --- annubes/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/annubes/task.py b/annubes/task.py index 27fbb8f..13d2f19 100644 --- a/annubes/task.py +++ b/annubes/task.py @@ -319,7 +319,7 @@ def _check_float_positive(self, name: str, value: Any, prob: bool = False) -> No raise ValueError(msg) def _check_int_positive(self, name: str, value: Any, strict: bool) -> None: # noqa: ANN401 - if not isinstance(value, int): + if not isinstance(value, int | np.int_): msg = f"`{name}` must be an integer." raise TypeError(msg) if strict: From 3c0c02e17a7ac2c68eb8928693a3d4ade9e3bbda Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Fri, 29 Mar 2024 14:46:18 +0100 Subject: [PATCH 4/6] add tests for checks --- tests/test_task.py | 157 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) diff --git a/tests/test_task.py b/tests/test_task.py index 93dd636..c402808 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from typing import Any import numpy as np import plotly.graph_objects as go @@ -17,6 +18,139 @@ def task(): ) +@pytest.mark.parametrize( + "name", + [ + NAME, + pytest.param(5, marks=pytest.mark.xfail(raises=TypeError)), + ], +) +def test_post_init_check_str(name: Any): # noqa: ANN401 + Task(name=name) + + +@pytest.mark.parametrize( + "session", + [ + {"v": 0.5, "a": 0.5}, + pytest.param(5, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param({1: 0.5, "a": 0.5}, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param({"v": "a", "a": 0.5}, marks=pytest.mark.xfail(raises=TypeError)), + ], +) +def test_post_init_check_session(session: Any): # noqa: ANN401 + Task(name=NAME, session=session) + + +@pytest.mark.parametrize( + ("stim_intensities", "catch_prob", "fix_intensity", "output_behavior", "noise_std"), + [ + ([0.8, 0.9, 1], 0.5, 0, [0, 1], 0.01), + pytest.param([0.8, "a"], 0.5, 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param([0.8, -1], 0.5, 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param([0.8, 0.9, 1], "a", 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param([0.8, 0.9, 1], 5, 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param([0.8, 0.9, 1], 0.5, "a", [0, 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param([0.8, 0.9, 1], 0.5, -1, [0, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param([0.8, 0.9, 1], 0.5, 0, ["0", 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param([0.8, 0.9, 1], 0.5, 0, [-1, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param([0.8, 0.9, 1], 0.5, 0, [0, 1], "a", marks=pytest.mark.xfail(raises=TypeError)), + pytest.param([0.8, 0.9, 1], 0.5, 0, [0, 1], -1, marks=pytest.mark.xfail(raises=ValueError)), + ], +) +def test_post_init_check_float_positive( + stim_intensities: Any, # noqa: ANN401 + catch_prob: Any, # noqa: ANN401 + fix_intensity: Any, # noqa: ANN401 + output_behavior: Any, # noqa: ANN401 + noise_std: Any, # noqa: ANN401 +): + Task( + name=NAME, + stim_intensities=stim_intensities, + catch_prob=catch_prob, + fix_intensity=fix_intensity, + output_behavior=output_behavior, + noise_std=noise_std, + ) + + +@pytest.mark.parametrize( + ("stim_time", "dt", "tau", "fix_time", "iti"), + [ + (1000, 20, 100, 100, 0), + pytest.param("a", 20, 100, 100, 0, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(0, 20, 100, 100, 0, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(1000, "a", 100, 100, 0, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(1000, 0, 100, 100, 0, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(1000, 20, "a", 100, 0, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(1000, 20, 0, 100, 0, marks=pytest.mark.xfail(raises=ValueError)), + (1000, 20, 100, 0, 0), + pytest.param(1000, 20, 100, "a", 0, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(1000, 20, 100, -1, 0, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(1000, 20, 100, 100, "a", marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(1000, 20, 100, 100, -1, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(1000, 20, 100, 100, ("a", 0), marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(1000, 20, 100, 100, (-1, 0), marks=pytest.mark.xfail(raises=ValueError)), + ], +) +def test_post_init_check_time_vars( + stim_time: Any, # noqa: ANN401 + dt: Any, # noqa: ANN401 + tau: Any, # noqa: ANN401 + fix_time: Any, # noqa: ANN401 + iti: Any, # noqa: ANN401 +): + Task( + name=NAME, + stim_time=stim_time, + dt=dt, + tau=tau, + fix_time=fix_time, + iti=iti, + ) + + +@pytest.mark.parametrize( + ("max_sequential", "n_outputs"), + [ + (None, 2), + pytest.param("a", 2, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(0, 2, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(None, "a", marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(None, 0, marks=pytest.mark.xfail(raises=ValueError)), + ], +) +def test_post_init_check_other_int_positive( + max_sequential: Any, # noqa: ANN401 + n_outputs: Any, # noqa: ANN401 +): + Task( + name=NAME, + max_sequential=max_sequential, + n_outputs=n_outputs, + ) + + +@pytest.mark.parametrize( + ("shuffle_trials", "scaling"), + [ + (True, True), + pytest.param("a", True, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(False, "a", marks=pytest.mark.xfail(raises=TypeError)), + ], +) +def test_post_init_check_bool( + shuffle_trials: Any, # noqa: ANN401 + scaling: Any, # noqa: ANN401 +): + Task( + name=NAME, + shuffle_trials=shuffle_trials, + scaling=scaling, + ) + + @pytest.mark.parametrize( ("session", "shuffle_trials", "expected_dict", "expected_type"), [ @@ -200,6 +334,29 @@ def test_minmaxscaler(): assert all(task._outputs[n_trial].min() >= 0 and task._outputs[n_trial].max() <= 1 for n_trial in trial_indices) +@pytest.mark.parametrize( + ("ntrials", "random_seed"), + [ + (20, None), + pytest.param("a", None, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(0, None, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(0.5, None, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param((30, 40, 50), None, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(("40", 50), None, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param((40, "50"), None, marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(20, "a", marks=pytest.mark.xfail(raises=TypeError)), + pytest.param(20, -1, marks=pytest.mark.xfail(raises=ValueError)), + pytest.param(20, 0.5, marks=pytest.mark.xfail(raises=TypeError)), + ], +) +def test_generate_trials_check( + task: Task, + ntrials: Any, # noqa: ANN401 + random_seed: Any, # noqa: ANN401 +): + _ = task.generate_trials(ntrials=ntrials, random_seed=random_seed) + + @pytest.mark.parametrize( "ntrials", [NTRIALS, (100, 200)], From 6d8c1eaa88d4fefc1ada5af4844b32201e435a7d Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Fri, 29 Mar 2024 14:50:48 +0100 Subject: [PATCH 5/6] ignore ANN401 in tests folder --- pyproject.toml | 3 ++- tests/test_task.py | 36 ++++++++++++++++++------------------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6962707..679c223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,6 +143,7 @@ isort.known-first-party = ["annubes"] "S101", # Use of `assert` detected "ANN201", # Missing return type "D103", # Missing function docstring - "SLF001", # private member access + "SLF001", # Private member access + "ANN401", # Function arguments annotated with too generic `Any` type ] "docs/*" = ["ALL"] diff --git a/tests/test_task.py b/tests/test_task.py index c402808..ec5ac02 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -25,7 +25,7 @@ def task(): pytest.param(5, marks=pytest.mark.xfail(raises=TypeError)), ], ) -def test_post_init_check_str(name: Any): # noqa: ANN401 +def test_post_init_check_str(name: Any): Task(name=name) @@ -38,7 +38,7 @@ def test_post_init_check_str(name: Any): # noqa: ANN401 pytest.param({"v": "a", "a": 0.5}, marks=pytest.mark.xfail(raises=TypeError)), ], ) -def test_post_init_check_session(session: Any): # noqa: ANN401 +def test_post_init_check_session(session: Any): Task(name=NAME, session=session) @@ -59,11 +59,11 @@ def test_post_init_check_session(session: Any): # noqa: ANN401 ], ) def test_post_init_check_float_positive( - stim_intensities: Any, # noqa: ANN401 - catch_prob: Any, # noqa: ANN401 - fix_intensity: Any, # noqa: ANN401 - output_behavior: Any, # noqa: ANN401 - noise_std: Any, # noqa: ANN401 + stim_intensities: Any, + catch_prob: Any, + fix_intensity: Any, + output_behavior: Any, + noise_std: Any, ): Task( name=NAME, @@ -95,11 +95,11 @@ def test_post_init_check_float_positive( ], ) def test_post_init_check_time_vars( - stim_time: Any, # noqa: ANN401 - dt: Any, # noqa: ANN401 - tau: Any, # noqa: ANN401 - fix_time: Any, # noqa: ANN401 - iti: Any, # noqa: ANN401 + stim_time: Any, + dt: Any, + tau: Any, + fix_time: Any, + iti: Any, ): Task( name=NAME, @@ -122,8 +122,8 @@ def test_post_init_check_time_vars( ], ) def test_post_init_check_other_int_positive( - max_sequential: Any, # noqa: ANN401 - n_outputs: Any, # noqa: ANN401 + max_sequential: Any, + n_outputs: Any, ): Task( name=NAME, @@ -141,8 +141,8 @@ def test_post_init_check_other_int_positive( ], ) def test_post_init_check_bool( - shuffle_trials: Any, # noqa: ANN401 - scaling: Any, # noqa: ANN401 + shuffle_trials: Any, + scaling: Any, ): Task( name=NAME, @@ -351,8 +351,8 @@ def test_minmaxscaler(): ) def test_generate_trials_check( task: Task, - ntrials: Any, # noqa: ANN401 - random_seed: Any, # noqa: ANN401 + ntrials: Any, + random_seed: Any, ): _ = task.generate_trials(ntrials=ntrials, random_seed=random_seed) From 20dcd1899c51b403e1d359443593b80e6b839d4c Mon Sep 17 00:00:00 2001 From: gcroci2 Date: Mon, 8 Apr 2024 17:00:54 +0200 Subject: [PATCH 6/6] change order for checking float prob --- annubes/task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/annubes/task.py b/annubes/task.py index eb63ded..8361618 100644 --- a/annubes/task.py +++ b/annubes/task.py @@ -311,12 +311,12 @@ def _check_float_positive(self, name: str, value: Any, prob: bool = False) -> No if not isinstance(value, float | int): msg = f"`{name}` must be a float or integer." raise TypeError(msg) - if not value >= 0: - msg = f"`{name}` must be greater than or equal to 0." - raise ValueError(msg) if prob and not 0 <= value <= 1: msg = f"`{name}` must be between 0 and 1." raise ValueError(msg) + if not value >= 0: + msg = f"`{name}` must be greater than or equal to 0." + raise ValueError(msg) def _check_int_positive(self, name: str, value: Any, strict: bool) -> None: # noqa: ANN401 if not isinstance(value, int | np.int_):