ruff: tests docstring (#1588)
SkafteNicki authored Mar 6, 2023
1 parent d01f1fd commit 163cb61
Showing 71 changed files with 267 additions and 121 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -110,9 +110,6 @@ unfixable = ["F401"]
"__init__.py" = ["D100"]
"tests/**" = [
"D102", # Missing docstring in public method
-"D103", # Missing docstring in public function
-"D401", # todo # First line of docstring should be in imperative mood: ...
-"D415", # todo # First line should end with a period, question mark, or exclamation point"
]

[tool.ruff.pydocstyle]
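With `D103`, `D401`, and `D415` no longer ignored for `tests/**`, ruff now requires public test functions to carry a docstring written in the imperative mood and ending with punctuation. A minimal sketch of a compliant test (hypothetical name and assertion, not part of this commit):

```python
# Hypothetical example, not from the commit: a docstring that satisfies
# D103 (docstring present), D401 (imperative mood) and D415 (ends with a period).
def test_addition_is_commutative():
    """Test that addition returns the same result regardless of operand order."""
    assert 1 + 2 == 2 + 1
```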
2 changes: 2 additions & 0 deletions tests/integrations/conftest.py
@@ -24,8 +24,10 @@

@pytest.fixture(scope="session")
def datadir():
"""Global data dir for location of datasets."""
return Path(_PATH_DATASETS)


def pytest_configure(config):
"""Local configuration of pytest."""
config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")
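A hypothetical consumer of the session-scoped `datadir` fixture above could look like this (test name and assertion are illustrative only):

```python
from pathlib import Path


# Illustrative test: pytest injects the fixture by argument name.
def test_datadir_is_a_path(datadir):
    """Test that the shared dataset directory is exposed as a pathlib.Path."""
    assert isinstance(datadir, Path)
```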
1 change: 1 addition & 0 deletions tests/integrations/helpers.py
@@ -20,6 +20,7 @@

@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
"""Context manager to make sure that no warning is raised for a given call."""
with pytest.warns(None) as record:
yield

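For context, a usage sketch of `no_warning_call`; the import path mirrors the file location shown above, and the wrapped call is a stand-in:

```python
from tests.integrations.helpers import no_warning_call


def quiet_operation():
    """Stand-in for any call that is expected to emit no warnings."""
    return sum(range(10))


def test_quiet_operation_emits_no_warning():
    """Test that the helper passes when the wrapped call stays silent."""
    with no_warning_call(UserWarning):
        quiet_operation()
```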
2 changes: 2 additions & 0 deletions tests/integrations/test_lightning.py
@@ -32,6 +32,8 @@ def update(self, value):


def test_metric_lightning(tmpdir):
"""Test that including a metric inside a lightning module calculates a simple sum correctly."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
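The pattern exercised by this test, a metric registered on a `LightningModule` and updated during training, roughly follows the sketch below; class, layer, and attribute names are illustrative and not taken from the hidden part of the diff:

```python
import torch
from pytorch_lightning import LightningModule
from torchmetrics import SumMetric


class ModuleWithMetric(LightningModule):
    """Minimal sketch of a LightningModule that accumulates a running sum metric."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.sum_metric = SumMetric()

    def training_step(self, batch, batch_idx):
        # Update the metric with a per-batch scalar; compute() later returns
        # the sum accumulated across all batches seen so far.
        self.sum_metric.update(batch.sum())
        return self.layer(batch).sum()
```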
15 changes: 9 additions & 6 deletions tests/unittests/audio/test_pesq.py
@@ -42,7 +42,8 @@
)


-def pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
+def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
"""Comparison function."""
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
target = target.detach().cpu().numpy()
@@ -54,15 +55,15 @@ def pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
return torch.tensor(mss)


-def average_metric(preds, target, metric_func):
+def _average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


-pesq_original_batch_8k_nb = partial(pesq_original_batch, fs=8000, mode="nb")
-pesq_original_batch_16k_nb = partial(pesq_original_batch, fs=16000, mode="nb")
-pesq_original_batch_16k_wb = partial(pesq_original_batch, fs=16000, mode="wb")
+pesq_original_batch_8k_nb = partial(_pesq_original_batch, fs=8000, mode="nb")
+pesq_original_batch_16k_nb = partial(_pesq_original_batch, fs=16000, mode="nb")
+pesq_original_batch_16k_wb = partial(_pesq_original_batch, fs=16000, mode="wb")


@pytest.mark.parametrize(
@@ -88,7 +89,7 @@ def test_pesq(self, preds, target, ref_metric, fs, mode, num_processes, ddp):
preds,
target,
PerceptualEvaluationSpeechQuality,
-reference_metric=partial(average_metric, metric_func=ref_metric),
+reference_metric=partial(_average_metric, metric_func=ref_metric),
metric_args={"fs": fs, "mode": mode, "n_processes": num_processes},
)

@@ -126,12 +127,14 @@ def test_pesq_half_gpu(self, preds, target, ref_metric, fs, mode):


def test_error_on_different_shape(metric_class=PerceptualEvaluationSpeechQuality):
"""Test that an error is raised on different shapes of input."""
metric = metric_class(16000, "nb")
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))


def test_on_real_audio():
"""Test that metric works as expected on real audio signals."""
rate, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH)
rate, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB)
pesq = perceptual_evaluation_speech_quality(torch.from_numpy(deg), torch.from_numpy(ref), rate, "wb")
6 changes: 5 additions & 1 deletion tests/unittests/audio/test_pit.py
@@ -55,7 +55,7 @@ def naive_implementation_pit_scipy(
metric_func: Callable,
eval_func: str,
) -> Tuple[Tensor, Tensor]:
-"""A naive implementation of `Permutation Invariant Training` based on Scipy.
+"""Naive implementation of `Permutation Invariant Training` based on Scipy.
Args:
preds: predictions, shape[batch, spk, time]
@@ -167,6 +167,7 @@ def test_pit_half_gpu(self, preds, target, ref_metric, metric_func, eval_func):


def test_error_on_different_shape() -> None:
"""Test that error is raised on different shapes of input."""
metric = PermutationInvariantTraining(signal_noise_ratio, "max")
with pytest.raises(
RuntimeError,
@@ -176,18 +177,21 @@ def test_error_on_different_shape() -> None:


def test_error_on_wrong_eval_func() -> None:
"""Test that error is raised on wrong `eval_func` argument."""
metric = PermutationInvariantTraining(signal_noise_ratio, "xxx")
with pytest.raises(ValueError, match='eval_func can only be "max" or "min"'):
metric(torch.randn(3, 3, 10), torch.randn(3, 3, 10))


def test_error_on_wrong_shape() -> None:
"""Test that error is raised on wrong input shape."""
metric = PermutationInvariantTraining(signal_noise_ratio, "max")
with pytest.raises(ValueError, match="Inputs must be of shape *"):
metric(torch.randn(3), torch.randn(3))


def test_consistency_of_two_implementations() -> None:
"""Test that both backend functions for computing metric (depending on torch version) returns the same result."""
from torchmetrics.functional.audio.pit import (
_find_best_perm_by_exhaustive_method,
_find_best_perm_by_linear_sum_assignment,
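As a usage sketch (import locations assumed), the permutation-invariant wrapper exercised by these tests is driven like this:

```python
import torch
from torchmetrics.audio import PermutationInvariantTraining
from torchmetrics.functional.audio import signal_noise_ratio

# [batch, spk, time] inputs, matching the shape convention documented in the
# naive reference implementation above.
preds = torch.randn(3, 2, 1000)
target = torch.randn(3, 2, 1000)

pit = PermutationInvariantTraining(signal_noise_ratio, "max")
best_snr = pit(preds, target)  # SNR under the best speaker permutation, averaged over the batch
```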
14 changes: 8 additions & 6 deletions tests/unittests/audio/test_sdr.py
@@ -43,7 +43,7 @@
)


-def sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool = False) -> Tensor:
+def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool = False) -> Tensor:
# shape: preds [BATCH_SIZE, spk, Time] , target [BATCH_SIZE, spk, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, spk, Time] , target [NUM_BATCHES*BATCH_SIZE, spk, Time]
target = target.detach().cpu().numpy()
@@ -55,13 +55,13 @@ def sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
return torch.tensor(mss)


-def average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
+def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


-original_impl_compute_permutation = partial(sdr_original_batch)
+original_impl_compute_permutation = partial(_sdr_original_batch)


@pytest.mark.skipif( # TODO: figure out why tests leads to cuda errors on latest torch
@@ -88,7 +88,7 @@ def test_sdr(self, preds, target, ref_metric, ddp):
preds,
target,
SignalDistortionRatio,
-reference_metric=partial(average_metric, metric_func=ref_metric),
+reference_metric=partial(_average_metric, metric_func=ref_metric),
metric_args={},
)

@@ -130,14 +130,16 @@ def test_sdr_half_gpu(self, preds, target, ref_metric):


def test_error_on_different_shape(metric_class=SignalDistortionRatio):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))


def test_on_real_audio():
-rate, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH)
-rate, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB)
+"""Test that metric works on real audio signal."""
+_, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH)
+_, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB)
assert torch.allclose(
signal_distortion_ratio(torch.from_numpy(deg), torch.from_numpy(ref)).float(),
torch.tensor(0.2211),
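A small usage sketch of the metric class under test (import path assumed; random tensors stand in for real audio):

```python
import torch
from torchmetrics.audio import SignalDistortionRatio

# [batch, spk, time] inputs, as described in the comparison helper above.
preds = torch.randn(4, 2, 8000)
target = torch.randn(4, 2, 8000)

sdr = SignalDistortionRatio()
print(sdr(preds, target))  # average SDR in dB over the batch
```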
11 changes: 6 additions & 5 deletions tests/unittests/audio/test_si_sdr.py
@@ -39,7 +39,7 @@
speechmetrics_sisdr = speechmetrics.load("sisdr")


-def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool):
+def _speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
if zero_mean:
@@ -57,14 +57,14 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool):
return torch.tensor(mss)


-def average_metric(preds, target, metric_func):
+def _average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


-speechmetrics_si_sdr_zero_mean = partial(speechmetrics_si_sdr, zero_mean=True)
-speechmetrics_si_sdr_no_zero_mean = partial(speechmetrics_si_sdr, zero_mean=False)
+speechmetrics_si_sdr_zero_mean = partial(_speechmetrics_si_sdr, zero_mean=True)
+speechmetrics_si_sdr_no_zero_mean = partial(_speechmetrics_si_sdr, zero_mean=False)


@pytest.mark.parametrize(
@@ -86,7 +86,7 @@ def test_si_sdr(self, preds, target, ref_metric, zero_mean, ddp):
preds,
target,
ScaleInvariantSignalDistortionRatio,
-reference_metric=partial(average_metric, metric_func=ref_metric),
+reference_metric=partial(_average_metric, metric_func=ref_metric),
metric_args={"zero_mean": zero_mean},
)

@@ -123,6 +123,7 @@ def test_si_sdr_half_gpu(self, preds, target, ref_metric, zero_mean):


def test_error_on_different_shape(metric_class=ScaleInvariantSignalDistortionRatio):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
9 changes: 5 additions & 4 deletions tests/unittests/audio/test_si_snr.py
@@ -39,7 +39,7 @@
speechmetrics_sisdr = speechmetrics.load("sisdr")


-def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True):
+def _speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
if zero_mean:
@@ -57,7 +57,7 @@ def speechmetrics_si_sdr(preds: Tensor, target: Tensor, zero_mean: bool = True):
return torch.tensor(mss)


-def average_metric(preds, target, metric_func):
+def _average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()
@@ -66,7 +66,7 @@ def average_metric(preds, target, metric_func):
@pytest.mark.parametrize(
"preds, target, ref_metric",
[
-(inputs.preds, inputs.target, speechmetrics_si_sdr),
+(inputs.preds, inputs.target, _speechmetrics_si_sdr),
],
)
class TestSISNR(MetricTester):
@@ -81,7 +81,7 @@ def test_si_snr(self, preds, target, ref_metric, ddp):
preds,
target,
ScaleInvariantSignalNoiseRatio,
-reference_metric=partial(average_metric, metric_func=ref_metric),
+reference_metric=partial(_average_metric, metric_func=ref_metric),
)

def test_si_snr_functional(self, preds, target, ref_metric):
@@ -114,6 +114,7 @@ def test_si_snr_half_gpu(self, preds, target, ref_metric):


def test_error_on_different_shape(metric_class=ScaleInvariantSignalNoiseRatio):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
11 changes: 6 additions & 5 deletions tests/unittests/audio/test_snr.py
@@ -39,7 +39,7 @@
)


-def bss_eval_images_snr(preds: Tensor, target: Tensor, zero_mean: bool):
+def _bss_eval_images_snr(preds: Tensor, target: Tensor, zero_mean: bool):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
if zero_mean:
@@ -57,14 +57,14 @@ def bss_eval_images_snr(preds: Tensor, target: Tensor, zero_mean: bool):
return torch.tensor(mss)


-def average_metric(preds: Tensor, target: Tensor, metric_func: Callable):
+def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


-mireval_snr_zeromean = partial(bss_eval_images_snr, zero_mean=True)
-mireval_snr_nozeromean = partial(bss_eval_images_snr, zero_mean=False)
+mireval_snr_zeromean = partial(_bss_eval_images_snr, zero_mean=True)
+mireval_snr_nozeromean = partial(_bss_eval_images_snr, zero_mean=False)


@pytest.mark.parametrize(
@@ -86,7 +86,7 @@ def test_snr(self, preds, target, ref_metric, zero_mean, ddp):
preds,
target,
SignalNoiseRatio,
-reference_metric=partial(average_metric, metric_func=ref_metric),
+reference_metric=partial(_average_metric, metric_func=ref_metric),
metric_args={"zero_mean": zero_mean},
)

@@ -123,6 +123,7 @@ def test_snr_half_gpu(self, preds, target, ref_metric, zero_mean):


def test_error_on_different_shape(metric_class=SignalNoiseRatio):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
16 changes: 9 additions & 7 deletions tests/unittests/audio/test_stoi.py
@@ -40,7 +40,7 @@
)


-def stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool):
+def _stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool):
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
target = target.detach().cpu().numpy()
@@ -52,16 +52,16 @@ def stoi_original_batch(preds: Tensor, target: Tensor, fs: int, extended: bool):
return torch.tensor(mss)


-def average_metric(preds, target, metric_func):
+def _average_metric(preds, target, metric_func):
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
return metric_func(preds, target).mean()


-stoi_original_batch_8k_ext = partial(stoi_original_batch, fs=8000, extended=True)
-stoi_original_batch_16k_ext = partial(stoi_original_batch, fs=16000, extended=True)
-stoi_original_batch_8k_noext = partial(stoi_original_batch, fs=8000, extended=False)
-stoi_original_batch_16k_noext = partial(stoi_original_batch, fs=16000, extended=False)
+stoi_original_batch_8k_ext = partial(_stoi_original_batch, fs=8000, extended=True)
+stoi_original_batch_16k_ext = partial(_stoi_original_batch, fs=16000, extended=True)
+stoi_original_batch_8k_noext = partial(_stoi_original_batch, fs=8000, extended=False)
+stoi_original_batch_16k_noext = partial(_stoi_original_batch, fs=16000, extended=False)


@pytest.mark.parametrize(
@@ -85,7 +85,7 @@ def test_stoi(self, preds, target, ref_metric, fs, extended, ddp):
preds,
target,
ShortTimeObjectiveIntelligibility,
-reference_metric=partial(average_metric, metric_func=ref_metric),
+reference_metric=partial(_average_metric, metric_func=ref_metric),
metric_args={"fs": fs, "extended": extended},
)

@@ -122,12 +122,14 @@ def test_stoi_half_gpu(self, preds, target, ref_metric, fs, extended):


def test_error_on_different_shape(metric_class=ShortTimeObjectiveIntelligibility):
"""Test that error is raised on different shapes of input."""
metric = metric_class(16000)
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))


def test_on_real_audio():
"""Test that metric works on real audio signal."""
rate, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH)
rate, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB)
assert torch.allclose(
8 changes: 4 additions & 4 deletions tests/unittests/bases/test_aggregation.py
@@ -8,22 +8,22 @@


def compare_mean(values, weights):
-"""Reference implementation for mean aggregation."""
+"""Baseline implementation for mean aggregation."""
return np.average(values.numpy(), weights=weights)


def compare_sum(values, weights):
-"""Reference implementation for sum aggregation."""
+"""Baseline implementation for sum aggregation."""
return np.sum(values.numpy())


def compare_min(values, weights):
-"""Reference implementation for min aggregation."""
+"""Baseline implementation for min aggregation."""
return np.min(values.numpy())


def compare_max(values, weights):
-"""Reference implementation for max aggregation."""
+"""Baseline implementation for max aggregation."""
return np.max(values.numpy())


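These baselines are meant to be checked against the torchmetrics aggregation classes; a hedged sketch of that comparison for the mean case, assuming the top-level `MeanMetric` export:

```python
import numpy as np
import torch
from torchmetrics import MeanMetric

values = torch.arange(10, dtype=torch.float32)
weights = torch.linspace(0.1, 1.0, 10)

metric = MeanMetric()
metric.update(values, weights)

# Same formula as compare_mean above: weighted average of the values.
expected = np.average(values.numpy(), weights=weights.numpy())
assert torch.allclose(metric.compute(), torch.tensor(expected, dtype=torch.float32))
```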