refactor: add checks for input types in Task #64

Merged 7 commits on Apr 8, 2024
107 changes: 96 additions & 11 deletions annubes/task.py
@@ -94,11 +94,29 @@ class Task(TaskSettingsMixin):
    max_sequential: int | None = None

    def __post_init__(self):
        self._task_settings = vars(self).copy()
        # Check input parameters
        ## Check str
        self._check_str("name", self.name)
        ## Check dict
        self._check_session("session", self.session)
        ## Check float
        for intensity in self.stim_intensities:
            self._check_float_positive("stim_intensities", intensity)
        self._check_float_positive("catch_prob", self.catch_prob, prob=True)
        self._check_float_positive("fix_intensity", self.fix_intensity)
        for intensity in self.output_behavior:
            self._check_float_positive("output_behavior", intensity)
        self._check_float_positive("noise_std", self.noise_std)
        ## Check int
        self._check_time_vars()
        if self.max_sequential is not None:
self._check_int_positive("max_sequential", self.max_sequential, strict=True)
self._check_int_positive("n_outputs", self.n_outputs, strict=True)
## Check bool
self._check_bool("shuffle_trials", self.shuffle_trials)
self._check_bool("scaling", self.scaling)

if not 0 <= self.catch_prob <= 1:
msg = "`catch_prob` must be between 0 and 1."
raise ValueError(msg)
self._task_settings = vars(self).copy()

sum_session_vals = sum(self.session.values())
self._session = {}
Expand All @@ -107,13 +125,6 @@ def __post_init__(self):
if not self.shuffle_trials:
self._session = OrderedDict(self._session)

if not self.dt > 0:
msg = "`dt` must be greater than 0."
raise ValueError(msg)
if not self.tau > 0:
msg = "`tau` must be greater than 0."
raise ValueError(msg)

# Derived and other attributes
self._modalities = set(dict.fromkeys(char for string in self._session for char in string))
self._n_inputs = len(self._modalities) + 1 # includes start cue
@@ -138,6 +149,18 @@ def generate_trials(
            `generate_trials()` method's call ("ntrials", "random_state"), and the generated data ("modality_seq",
            "time", "phases", "inputs", "outputs").
        """
        # Check input parameters
        if isinstance(ntrials, tuple):
            if len(ntrials) != 2:  # noqa: PLR2004
                msg = "`ntrials` must be an integer or a tuple of two integers."
                raise ValueError(msg)
            self._check_int_positive("ntrials", ntrials[0], strict=True)
            self._check_int_positive("ntrials", ntrials[1], strict=True)
        else:
            self._check_int_positive("ntrials", ntrials, strict=True)
        if random_seed is not None:
            self._check_int_positive("random_seed", random_seed, strict=False)

        # Set random state
        if random_seed is None:
            rng = np.random.default_rng(random_seed)
@@ -178,6 +201,9 @@ def plot_trials(self, n_plots: int = 1) -> go.Figure:
        Returns:
            go.Figure: Plotly figure of trial results.
        """
        # Check input parameters
        self._check_int_positive("n_plots", n_plots, strict=True)

        if (p := n_plots) > (t := self._ntrials):
            msg = f"Number of plots requested ({p}) exceeds number of trials ({t}). Will plot all trials."
            warnings.warn(msg, stacklevel=2)
Expand Down Expand Up @@ -265,6 +291,65 @@ def plot_trials(self, n_plots: int = 1) -> go.Figure:
fig.update_layout(height=1300, width=900, title_text="Trials")
return fig

def _check_str(self, name: str, value: Any) -> None: # noqa: ANN401
if not isinstance(value, str):
msg = f"`{name}` must be a string"
raise TypeError(msg)

    def _check_session(self, name: str, value: Any) -> None:  # noqa: ANN401
        if not isinstance(value, dict):
            msg = f"`{name}` must be a dictionary."
            raise TypeError(msg)
        if not all(isinstance(k, str) for k in value):
            msg = f"Keys of `{name}` must be strings."
            raise TypeError(msg)
        if not all(isinstance(v, (float | int)) for v in value.values()):
            msg = f"Values of `{name}` must be floats or integers."
            raise TypeError(msg)

    def _check_float_positive(self, name: str, value: Any, prob: bool = False) -> None:  # noqa: ANN401
        if not isinstance(value, float | int):
            msg = f"`{name}` must be a float or integer."
            raise TypeError(msg)
        if not value >= 0:
            msg = f"`{name}` must be greater than or equal to 0."
            raise ValueError(msg)
        if prob and not 0 <= value <= 1:
            msg = f"`{name}` must be between 0 and 1."
            raise ValueError(msg)

    def _check_int_positive(self, name: str, value: Any, strict: bool) -> None:  # noqa: ANN401
        if not isinstance(value, int | np.int_):
            msg = f"`{name}` must be an integer."
            raise TypeError(msg)
        if strict:
            if not value > 0:
                msg = f"`{name}` must be greater than 0."
                raise ValueError(msg)
        elif not value >= 0:
            msg = f"`{name}` must be greater than or equal to 0."
            raise ValueError(msg)

    def _check_bool(self, name: str, value: Any) -> None:  # noqa: ANN401
        if not isinstance(value, bool):
            msg = f"`{name}` must be a boolean."
            raise TypeError(msg)

    def _check_time_vars(self) -> None:
        strictly_positive = {
            "stim_time": (self.stim_time, True),
            "dt": (self.dt, True),
            "tau": (self.tau, True),
            "fix_time": (self.fix_time, False),
        }
        for name, value in strictly_positive.items():
            self._check_int_positive(name, value[0], strict=value[1])
        if isinstance(self.iti, tuple):
            for iti in self.iti:
                self._check_int_positive("iti", iti, strict=False)
        else:
            self._check_int_positive("iti", self.iti, strict=False)

    def _build_trials_seq(self) -> NDArray[np.str_]:
        """Generate a sequence of modalities."""
        # Extract keys and probabilities from the dictionary
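For reference, a minimal usage sketch (not part of the diff) of how the new checks surface to a caller. The import path and the assumption that the remaining default field values are valid are mine; the expected errors mirror the xfail cases in the new tests below.

```python
from annubes.task import Task

# A valid configuration passes all of the new __post_init__ checks.
task = Task(name="example", session={"v": 0.5, "a": 0.5}, catch_prob=0.5)

# Invalid inputs now fail fast with a clear error instead of surfacing later.
try:
    Task(name=5)  # not a string
except TypeError as err:
    print(err)  # `name` must be a string

try:
    Task(name="example", catch_prob=5)  # outside [0, 1]
except ValueError as err:
    print(err)  # `catch_prob` must be between 0 and 1
```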
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -143,6 +143,7 @@ isort.known-first-party = ["annubes"]
    "S101",  # Use of `assert` detected
    "ANN201",  # Missing return type
    "D103",  # Missing function docstring
    "SLF001",  # private member access
    "SLF001",  # Private member access
    "ANN401",  # Function arguments annotated with too generic `Any` type
Collaborator:
I'm not a huge fan of this. I get that it's not pretty to turn it off many times over on individual functions, but I do think it's a good practice to avoid using Any in most cases.

Alternatively, at the top of the test file in question, you can declare `AnyType = Any  # noqa: ANN401, RUF100` as a type alias and then use `AnyType` wherever you don't want the linter to check, without having to `noqa` each instance.
(The RUF100 ignore isn't strictly necessary right now, but that is due to a bug in Ruff, which I assume will be fixed at some point, and then we won't have to fix it post hoc.)
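A minimal sketch of that suggestion; the `from annubes.task import Task` import and the example test are assumptions based on the test diff below, not part of the comment itself:

```python
# Hypothetical top of tests/test_task.py: declare the alias once,
# then use it for annotations instead of `Any` throughout the module.
from typing import Any

import pytest

from annubes.task import Task

AnyType = Any  # noqa: ANN401, RUF100


@pytest.mark.parametrize("name", ["example", pytest.param(5, marks=pytest.mark.xfail(raises=TypeError))])
def test_post_init_check_str(name: AnyType):  # no per-function noqa needed
    Task(name=name)
```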

Contributor (author):
In the tests, you do want to test `Any` types. In the codebase I agree that wouldn't be good. Very likely we would need to state the exception at the beginning of each test file, so in my opinion it's better to just set it in the toml directly for all test files. This won't disable the error in the codebase, of course.

]
"docs/*" = ["ALL"]
157 changes: 157 additions & 0 deletions tests/test_task.py
@@ -1,4 +1,5 @@
from collections import OrderedDict
from typing import Any

import numpy as np
import plotly.graph_objects as go
@@ -17,6 +18,139 @@ def task():
    )


@pytest.mark.parametrize(
    "name",
    [
        NAME,
        pytest.param(5, marks=pytest.mark.xfail(raises=TypeError)),
    ],
)
def test_post_init_check_str(name: Any):
    Task(name=name)


@pytest.mark.parametrize(
    "session",
    [
        {"v": 0.5, "a": 0.5},
        pytest.param(5, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param({1: 0.5, "a": 0.5}, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param({"v": "a", "a": 0.5}, marks=pytest.mark.xfail(raises=TypeError)),
    ],
)
def test_post_init_check_session(session: Any):
    Task(name=NAME, session=session)


@pytest.mark.parametrize(
    ("stim_intensities", "catch_prob", "fix_intensity", "output_behavior", "noise_std"),
    [
        ([0.8, 0.9, 1], 0.5, 0, [0, 1], 0.01),
        pytest.param([0.8, "a"], 0.5, 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param([0.8, -1], 0.5, 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param([0.8, 0.9, 1], "a", 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param([0.8, 0.9, 1], 5, 0, [0, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param([0.8, 0.9, 1], 0.5, "a", [0, 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param([0.8, 0.9, 1], 0.5, -1, [0, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param([0.8, 0.9, 1], 0.5, 0, ["0", 1], 0.01, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param([0.8, 0.9, 1], 0.5, 0, [-1, 1], 0.01, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param([0.8, 0.9, 1], 0.5, 0, [0, 1], "a", marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param([0.8, 0.9, 1], 0.5, 0, [0, 1], -1, marks=pytest.mark.xfail(raises=ValueError)),
    ],
)
def test_post_init_check_float_positive(
    stim_intensities: Any,
    catch_prob: Any,
    fix_intensity: Any,
    output_behavior: Any,
    noise_std: Any,
):
    Task(
        name=NAME,
        stim_intensities=stim_intensities,
        catch_prob=catch_prob,
        fix_intensity=fix_intensity,
        output_behavior=output_behavior,
        noise_std=noise_std,
    )


@pytest.mark.parametrize(
    ("stim_time", "dt", "tau", "fix_time", "iti"),
    [
        (1000, 20, 100, 100, 0),
        pytest.param("a", 20, 100, 100, 0, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(0, 20, 100, 100, 0, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(1000, "a", 100, 100, 0, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(1000, 0, 100, 100, 0, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(1000, 20, "a", 100, 0, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(1000, 20, 0, 100, 0, marks=pytest.mark.xfail(raises=ValueError)),
        (1000, 20, 100, 0, 0),
        pytest.param(1000, 20, 100, "a", 0, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(1000, 20, 100, -1, 0, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(1000, 20, 100, 100, "a", marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(1000, 20, 100, 100, -1, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(1000, 20, 100, 100, ("a", 0), marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(1000, 20, 100, 100, (-1, 0), marks=pytest.mark.xfail(raises=ValueError)),
    ],
)
def test_post_init_check_time_vars(
    stim_time: Any,
    dt: Any,
    tau: Any,
    fix_time: Any,
    iti: Any,
):
    Task(
        name=NAME,
        stim_time=stim_time,
        dt=dt,
        tau=tau,
        fix_time=fix_time,
        iti=iti,
    )


@pytest.mark.parametrize(
    ("max_sequential", "n_outputs"),
    [
        (None, 2),
        pytest.param("a", 2, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(0, 2, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(None, "a", marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(None, 0, marks=pytest.mark.xfail(raises=ValueError)),
    ],
)
def test_post_init_check_other_int_positive(
    max_sequential: Any,
    n_outputs: Any,
):
    Task(
        name=NAME,
        max_sequential=max_sequential,
        n_outputs=n_outputs,
    )


@pytest.mark.parametrize(
    ("shuffle_trials", "scaling"),
    [
        (True, True),
        pytest.param("a", True, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(False, "a", marks=pytest.mark.xfail(raises=TypeError)),
    ],
)
def test_post_init_check_bool(
    shuffle_trials: Any,
    scaling: Any,
):
    Task(
        name=NAME,
        shuffle_trials=shuffle_trials,
        scaling=scaling,
    )


@pytest.mark.parametrize(
    ("session", "shuffle_trials", "expected_dict", "expected_type"),
    [
@@ -200,6 +334,29 @@ def test_minmaxscaler():
    assert all(task._outputs[n_trial].min() >= 0 and task._outputs[n_trial].max() <= 1 for n_trial in trial_indices)


@pytest.mark.parametrize(
    ("ntrials", "random_seed"),
    [
        (20, None),
        pytest.param("a", None, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(0, None, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(0.5, None, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param((30, 40, 50), None, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(("40", 50), None, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param((40, "50"), None, marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(20, "a", marks=pytest.mark.xfail(raises=TypeError)),
        pytest.param(20, -1, marks=pytest.mark.xfail(raises=ValueError)),
        pytest.param(20, 0.5, marks=pytest.mark.xfail(raises=TypeError)),
    ],
)
def test_generate_trials_check(
    task: Task,
    ntrials: Any,
    random_seed: Any,
):
    _ = task.generate_trials(ntrials=ntrials, random_seed=random_seed)


@pytest.mark.parametrize(
    "ntrials",
    [NTRIALS, (100, 200)],