Test for longer-than-65,536 inputs for all model types (#513)
sdatkinson authored Dec 11, 2024
1 parent c543b56 commit b5de1f4
Showing 6 changed files with 87 additions and 77 deletions.
32 changes: 0 additions & 32 deletions tests/test_nam/test_models/_convolutional.py

This file was deleted.

49 changes: 43 additions & 6 deletions tests/test_nam/test_models/base.py
@@ -2,12 +2,15 @@
 # Created Date: Saturday June 4th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)

-import abc
-from pathlib import Path
-from tempfile import TemporaryDirectory
+import abc as _abc
+from pathlib import Path as _Path
+from tempfile import TemporaryDirectory as _TemporaryDirectory

+import pytest as _pytest
+import torch as _torch


-class Base(abc.ABC):
+class Base(_abc.ABC):
     @classmethod
     def setup_class(cls, C, args=None, kwargs=None):
         cls._C = C
@@ -20,8 +23,42 @@ def test_init(self, args=None, kwargs=None):

     def test_export(self, args=None, kwargs=None):
         model = self._construct(args=args, kwargs=kwargs)
-        with TemporaryDirectory() as tmpdir:
-            model.export(Path(tmpdir))
+        with _TemporaryDirectory() as tmpdir:
+            model.export(_Path(tmpdir))

+    @_pytest.mark.parametrize(
+        "device",
+        (
+            _pytest.param(
+                "cuda",
+                marks=_pytest.mark.skipif(
+                    not _torch.cuda.is_available(), reason="CUDA-specific test"
+                ),
+            ),
+            _pytest.param(
+                "mps",
+                marks=_pytest.mark.skipif(
+                    not _torch.backends.mps.is_available(), reason="MPS-specific test"
+                ),
+            ),
+        ),
+    )
+    def test_process_input_longer_than_65536_on(self, device: str):
+        """
+        Processing inputs longer than 65,536 samples using various accelerator
+        backends can cause problems.
+        See:
+        * https://github.com/sdatkinson/neural-amp-modeler/issues/505
+        * https://github.com/sdatkinson/neural-amp-modeler/issues/512
+        (Funny that both have the same length limit--65,536...)
+        Assert that precautions are taken.
+        """
+        x = _torch.zeros((65_536 + 1,)).to(device)
+        model = self._construct().to(device)
+        model(x)
+
     def _construct(self, C=None, args=None, kwargs=None):
         C = self._C if C is None else C
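The new test only checks that a forward pass on a 65,537-sample input succeeds on CUDA and MPS; the commit itself does not show how the models take that precaution. For illustration, here is a minimal sketch of one common precaution, chunked processing, assuming a hypothetical per-call limit of 65,536 samples and a model whose output length equals its input length. The helper process_in_chunks and the limit constant are illustrative and are not part of the nam package.

import torch

_MAX_SAMPLES_PER_CALL = 65_536  # Assumed backend limit, for illustration only.


def process_in_chunks(model: torch.nn.Module, x: torch.Tensor, context: int = 0) -> torch.Tensor:
    """Run a long 1D signal through `model` in pieces no larger than the
    assumed per-call limit, carrying `context` samples of history into each
    piece so the stitched output matches a single full-length call.

    Assumes the model maps an N-sample input to an N-sample output and that
    `context` is smaller than the per-call limit.
    """
    n = x.shape[-1]
    if n <= _MAX_SAMPLES_PER_CALL:
        return model(x)
    outputs = []
    start = 0
    while start < n:
        left = max(start - context, 0)  # Include history for the model to warm up on.
        end = min(left + _MAX_SAMPLES_PER_CALL, n)
        y = model(x[..., left:end])
        outputs.append(y[..., start - left:])  # Drop the history-only samples.
        start = end
    return torch.cat(outputs, dim=-1)


if __name__ == "__main__":
    # Identity stands in for a real model so the stitched output can be checked exactly.
    x = torch.zeros((65_536 + 1,))
    y = process_in_chunks(torch.nn.Identity(), x, context=64)
    assert y.shape == x.shape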
4 changes: 2 additions & 2 deletions tests/test_nam/test_models/test_conv_net.py
@@ -6,10 +6,10 @@

 from nam.models import conv_net

-from ._convolutional import Convolutional as _Convolutional
+from .base import Base as _Base


-class TestConvNet(_Convolutional):
+class TestConvNet(_Base):
     @classmethod
     def setup_class(cls):
         channels = 3
4 changes: 2 additions & 2 deletions tests/test_nam/test_models/test_linear.py
@@ -6,10 +6,10 @@

 from nam.models import linear as _linear

-from ._convolutional import Convolutional as _Convolutional
+from .base import Base as _Base


-class TestLinear(_Convolutional):
+class TestLinear(_Base):
     @classmethod
     def setup_class(cls):
         C = _linear.Linear
44 changes: 23 additions & 21 deletions tests/test_nam/test_models/test_recurrent.py
@@ -2,29 +2,26 @@
 # Created Date: Sunday July 17th 2022
 # Author: Steven Atkinson (steven@atkinson.mn)

-from pathlib import Path
-from tempfile import TemporaryDirectory
+import pytest as _pytest
+import torch as _torch

-import pytest
-import torch
+from nam.models import recurrent as _recurrent

-from nam.models import recurrent
+from .base import Base as _Base

-from .base import Base
+_metadata_loudness_x_mocked = 0.1 * _torch.randn((11,))  # Shorter for speed

-_metadata_loudness_x_mocked = 0.1 * torch.randn((11,))  # Shorter for speed


-class TestLSTM(Base):
+class TestLSTM(_Base):
     @classmethod
     def setup_class(cls):
-        class LSTMWithMocks(recurrent.LSTM):
+        class LSTMWithMocks(_recurrent.LSTM):
             def __init__(self, *args, **kwargs):
                 super().__init__(*args, **kwargs)
                 self._get_initial_state_burn_in = 7

             @classmethod
-            def _metadata_loudness_x(cls) -> torch.Tensor:
+            def _metadata_loudness_x(cls) -> _torch.Tensor:
                 return _metadata_loudness_x_mocked

         num_layers = 2
@@ -37,15 +34,20 @@ def _metadata_loudness_x(cls) -> torch.Tensor:
         cls._num_layers = num_layers
         cls._hidden_size = hidden_size

-    def test_get_initial_state_cpu(self):
-        return self._t_initial_state("cpu")

-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU test")
-    def test_get_initial_state_gpu(self):
-        self._t_initial_state("cuda")

-    def _t_initial_state(self, device):
+    @_pytest.mark.parametrize(
+        "device",
+        (
+            "cpu",
+            _pytest.param(
+                "cuda",
+                marks=_pytest.mark.skipif(
+                    not _torch.cuda.is_available(), reason="GPU test"
+                ),
+            ),
+        ),
+    )
+    def test_get_initial_state_on(self, device: str):
         model = self._construct().to(device)
         h, c = model._get_initial_state()
-        assert isinstance(h, torch.Tensor)
-        assert isinstance(c, torch.Tensor)
+        assert isinstance(h, _torch.Tensor)
+        assert isinstance(c, _torch.Tensor)
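Both this file and base.py switch from separate per-device test methods to a single test parametrized over device names, with accelerator-only cases wrapped in _pytest.param carrying a skipif mark so they are reported as skipped rather than failed on machines without that backend. A standalone sketch of the pattern, with a made-up test body, looks like this:

import pytest
import torch


@pytest.mark.parametrize(
    "device",
    (
        "cpu",  # Always runs.
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA-specific test"
            ),
        ),
    ),
)
def test_sum_on(device: str):
    # The CUDA case is skipped cleanly when no GPU is present.
    assert torch.ones((3,), device=device).sum().item() == 3.0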
31 changes: 17 additions & 14 deletions tests/test_nam/test_models/test_wavenet.py
@@ -2,19 +2,22 @@
 # Created Date: Friday May 5th 2023
 # Author: Steven Atkinson (steven@atkinson.mn)

-import pytest
-import torch
+import pytest as _pytest
+import torch as _torch

-from nam.models.wavenet import WaveNet
-from nam.train.core import Architecture, get_wavenet_config
+from nam.models.wavenet import WaveNet as _WaveNet
+from nam.train.core import (
+    Architecture as _Architecture,
+    get_wavenet_config as _get_wavenet_config,
+)

-from ._convolutional import Convolutional as _Convolutional
+from .base import Base as _Base


-class TestWaveNet(_Convolutional):
+class TestWaveNet(_Base):
     @classmethod
     def setup_class(cls):
-        C = WaveNet
+        C = _WaveNet
         args = ()
         kwargs = {
             "layers_configs": [
@@ -31,22 +34,22 @@ def setup_class(cls):
         super().setup_class(C, args, kwargs)

     def test_import_weights(self):
-        config = get_wavenet_config(Architecture.FEATHER)
-        model_1 = WaveNet.init_from_config(config)
-        model_2 = WaveNet.init_from_config(config)
+        config = _get_wavenet_config(_Architecture.FEATHER)
+        model_1 = _WaveNet.init_from_config(config)
+        model_2 = _WaveNet.init_from_config(config)

         batch_size = 2
-        x = torch.randn(batch_size, model_1.receptive_field + 23)
+        x = _torch.randn(batch_size, model_1.receptive_field + 23)

         y1 = model_1(x)
         y2_before = model_2(x)

         model_2.import_weights(model_1._export_weights())
         y2_after = model_2(x)

-        assert not torch.allclose(y2_before, y1)
-        assert torch.allclose(y2_after, y1)
+        assert not _torch.allclose(y2_before, y1)
+        assert _torch.allclose(y2_after, y1)


 if __name__ == "__main__":
-    pytest.main()
+    _pytest.main()

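To exercise just the new device-dependent tests from a checkout of the repository, an invocation along these lines should work (a hypothetical local run; the -k expression simply matches the new test names, and cases whose backend is unavailable are skipped by their marks):

import pytest

# Adjust the path to your checkout of neural-amp-modeler.
pytest.main(
    [
        "tests/test_nam/test_models",
        "-k",
        "longer_than_65536 or get_initial_state",
        "-v",
    ]
)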