[BUGFIX] Workaround for PyTorch MPS bug with sequences longer than 65,536 samples (#506)

* Add testing to show failure cases for models using convolutions

* Fix imports

* Remove unused imports

* Fix

* Fix bug

* Remove debug statements

* Reason

* Skip condition

* Fix bug: decorator
sdatkinson authored Nov 24, 2024
1 parent 0670778 commit 9a1c72e
Showing 6 changed files with 131 additions and 52 deletions.
45 changes: 43 additions & 2 deletions nam/models/_base.py
@@ -170,6 +170,10 @@ def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:


class BaseNet(_Base):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._mps_65536_fallback = False

def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
@@ -179,16 +183,53 @@ def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
x = torch.cat(
(torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
)
y = self._forward(x, **kwargs)
if x.shape[1] < self.receptive_field:
raise ValueError(
f"Input has {x.shape[1]} samples, which is too few for this model with "
f"receptive field {self.receptive_field}!"
)
y = self._forward_mps_safe(x, **kwargs)
if scalar:
y = y[0]
return y

def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
return self(x)

def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Wrap `._forward()` to protect against MPS-unsupported input lengths
beyond 65,536 samples.
Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
then.
"""
if not self._mps_65536_fallback:
try:
return self._forward(x, **kwargs)
except NotImplementedError as e:
if "Output channels > 65536 not supported at the MPS device." in str(e):
self._mps_65536_fallback = True
return self._forward_mps_safe(x, **kwargs)
else:
raise e
else:
# Stitch together the output one piece at a time to avoid the MPS error
stride = 65_536 - (self.receptive_field - 1)
# We need to make sure that the last segment is big enough that we have the required history for the receptive field.
out_list = []
for i in range(0, x.shape[1], stride):
j = min(i+65_536, x.shape[1])
xi = x[:, i:j]
out_list.append(self._forward(xi, **kwargs))
# Bit hacky, but correct.
if j == x.shape[1]:
break
return torch.cat(out_list, dim=1)


@abc.abstractmethod
def _forward(self, x: torch.Tensor) -> torch.Tensor:
def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
The true forward method.
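The stitching logic above relies on a simple overlap rule: chunk starts advance by 65,536 minus (receptive_field - 1) samples, so each chunk carries the history its first output sample needs, and the per-chunk outputs concatenate without gaps or duplicates. Below is a minimal, standalone sketch of that arithmetic; the `segment_bounds` helper and the example `receptive_field` value are illustrative only (not part of the library), and it assumes `_forward` returns `len(x) - (receptive_field - 1)` samples per chunk.

# Sketch only: the segmentation arithmetic behind _forward_mps_safe's fallback.
# Chunks are at most 65,536 samples long; consecutive chunks overlap by
# (receptive_field - 1) samples so every output sample sees its full history.

MAX_LEN = 65_536


def segment_bounds(n_samples: int, receptive_field: int):
    """Yield (start, end) index pairs covering an input of n_samples samples."""
    stride = MAX_LEN - (receptive_field - 1)
    for i in range(0, n_samples, stride):
        j = min(i + MAX_LEN, n_samples)
        yield i, j
        if j == n_samples:
            break


if __name__ == "__main__":
    # e.g. roughly 3 seconds at 48 kHz with a WaveNet-sized receptive field
    print(list(segment_bounds(150_000, receptive_field=8_192)))
    # [(0, 65536), (57345, 122881), (114690, 150000)]

With those bounds, the three chunks yield 57,345, 57,345, and 27,119 output samples respectively, totaling 150,000 - 8,191 samples, exactly what a single full-length call would produce.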
46 changes: 9 additions & 37 deletions nam/models/linear.py
@@ -6,9 +6,6 @@
Linear model
"""

import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
@@ -30,38 +27,6 @@ def pad_start_default(self) -> bool:
def receptive_field(self) -> int:
return self._net.weight.shape[2]

def export(self, outdir: Path):
training = self.training
self.eval()
with open(Path(outdir, "config.json"), "w") as fp:
json.dump(
{
"version": __version__,
"architecture": self.__class__.__name__,
"config": {
"receptive_field": self.receptive_field,
"bias": self._bias,
},
},
fp,
indent=4,
)

params = [self._net.weight.flatten()]
if self._bias:
params.append(self._net.bias.flatten())
params = torch.cat(params).detach().cpu().numpy()
# Hope I don't regret using np.save...
np.save(Path(outdir, "weights.npy"), params)

# And an input/output to verify correct computation:
x, y = self._export_input_output()
np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())

# And resume training state
self.train(training)

def export_cpp_header(self):
raise NotImplementedError()

@@ -73,7 +38,14 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor:
return self._net(x[:, None])[:, 0]

def _export_config(self):
raise NotImplementedError()
return {
"receptive_field": self.receptive_field,
"bias": self._bias,
}

def _export_weights(self) -> np.ndarray:
raise NotImplementedError()
params_list = [self._net.weight.flatten()]
if self._bias:
params_list.append(self._net.bias.flatten())
params = torch.cat(params_list).detach().cpu().numpy()
return params
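The flat array written by the export has a simple layout: the kernel weights first, followed by the bias value if present. Here is a hedged sketch of reading it back, assuming the Linear model wraps a single-channel `nn.Conv1d`; the `load_linear_weights` helper is illustrative and not part of the package.

import numpy as np
import torch


def load_linear_weights(path: str, receptive_field: int, bias: bool):
    # weights.npy layout (per _export_weights above): flattened kernel, then bias.
    flat = np.load(path)
    kernel = torch.from_numpy(flat[:receptive_field]).reshape(1, 1, receptive_field)
    bias_term = torch.from_numpy(flat[receptive_field:]) if bias else None
    return kernel, bias_term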
30 changes: 30 additions & 0 deletions tests/test_nam/test_models/_convolutional.py
@@ -0,0 +1,30 @@
# File: _convolutional.py
# Created Date: Saturday November 23rd 2024
# Author: Steven Atkinson (steven@atkinson.mn)

"""
Mix-in tests for models with a convolution layer
"""

import pytest as _pytest
import torch as _torch

from .base import Base as _Base


class Convolutional(_Base):
@_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test")
def test_process_input_longer_than_65536(self):
"""
Processing inputs longer than 65,536 samples using the MPS backend can
cause problems.
See: https://github.com/sdatkinson/neural-amp-modeler/issues/505
Assert that precautions are taken.
"""

x = _torch.zeros((65_536 + 1,)).to("mps")

model = self._construct().to("mps")
model(x)
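This mix-in test only exercises the bug on Apple-silicon hardware, hence the skipif guard. One way to run just this regression check from Python, assuming the working directory is the repository root and MPS is available; the -k expression is illustrative.

import pytest

# Select only the 65,536-sample MPS regression tests across the model suites.
pytest.main(["tests/test_nam/test_models", "-k", "65536"])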
15 changes: 6 additions & 9 deletions tests/test_nam/test_models/test_conv_net.py
@@ -2,17 +2,14 @@
# Created Date: Friday May 6th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
import pytest as _pytest

from nam.models import conv_net

from .base import Base
from ._convolutional import Convolutional as _Convolutional


class TestConvNet(Base):
class TestConvNet(_Convolutional):
@classmethod
def setup_class(cls):
channels = 3
@@ -23,18 +20,18 @@ def setup_class(cls):
{"batchnorm": False, "activation": "Tanh"},
)

@pytest.mark.parametrize(
@_pytest.mark.parametrize(
("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
)
def test_init(self, batchnorm, activation):
super().test_init(kwargs={"batchnorm": batchnorm, "activation": activation})

@pytest.mark.parametrize(
@_pytest.mark.parametrize(
("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
)
def test_export(self, batchnorm, activation):
super().test_export(kwargs={"batchnorm": batchnorm, "activation": activation})


if __name__ == "__main__":
pytest.main()
_pytest.main()
18 changes: 18 additions & 0 deletions tests/test_nam/test_models/test_linear.py
@@ -0,0 +1,18 @@
# File: test_linear.py
# Created Date: Saturday November 23rd 2024
# Author: Steven Atkinson (steven@atkinson.mn)

import pytest as _pytest

from nam.models import linear as _linear

from ._convolutional import Convolutional as _Convolutional


class TestLinear(_Convolutional):
@classmethod
def setup_class(cls):
C = _linear.Linear
args = ()
kwargs = {"receptive_field": 2, "sample_rate": 44100}
super().setup_class(C, args, kwargs)
29 changes: 25 additions & 4 deletions tests/test_nam/test_models/test_wavenet.py
@@ -8,11 +8,28 @@
from nam.models.wavenet import WaveNet
from nam.train.core import Architecture, get_wavenet_config

from ._convolutional import Convolutional as _Convolutional


class TestWaveNet(_Convolutional):
@classmethod
def setup_class(cls):
C = WaveNet
args = ()
kwargs = {
"layers_configs": [
{
"input_size": 1,
"condition_size": 1,
"head_size": 1,
"channels": 1,
"kernel_size": 1,
"dilations": [1]
}
]
}
super().setup_class(C, args, kwargs)

# from .base import Base


class TestWaveNet(object):
def test_import_weights(self):
config = get_wavenet_config(Architecture.FEATHER)
model_1 = WaveNet.init_from_config(config)
@@ -29,3 +46,7 @@ def test_import_weights(self):

assert not torch.allclose(y2_before, y1)
assert torch.allclose(y2_after, y1)


if __name__ == "__main__":
pytest.main()
