-
Notifications
You must be signed in to change notification settings - Fork 154
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BUGFIX] Workaround for PyTorch MPS bug with sequences longer than 65…
…,536 samples (#506) * Add testing to show failure cases for models using convolutions * Fix imports * Remove unused imports * Fix * Fix bug * Remove debug statements * Reason * Skip condition * Fix bug: decorator
- Loading branch information
1 parent
0670778
commit 9a1c72e
Showing
6 changed files
with
131 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# File: _conv_mixin.py | ||
# Created Date: Saturday November 23rd 2024 | ||
# Author: Steven Atkinson (steven@atkinson.mn) | ||
|
||
""" | ||
Mix-in tests for models with a convolution layer | ||
""" | ||
|
||
import pytest as _pytest | ||
import torch as _torch | ||
|
||
from .base import Base as _Base | ||
|
||
|
||
class Convolutional(_Base): | ||
@_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test") | ||
def test_process_input_longer_than_65536(self): | ||
""" | ||
Processing inputs longer than 65,536 samples using the MPS backend can | ||
cause problems. | ||
See: https://github.com/sdatkinson/neural-amp-modeler/issues/505 | ||
Assert that precautions are taken. | ||
""" | ||
|
||
x = _torch.zeros((65_536 + 1,)).to("mps") | ||
|
||
model = self._construct().to("mps") | ||
model(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# File: test_linear.py | ||
# Created Date: Saturday November 23rd 2024 | ||
# Author: Steven Atkinson (steven@atkinson.mn) | ||
|
||
import pytest as _pytest | ||
|
||
from nam.models import linear as _linear | ||
|
||
from ._convolutional import Convolutional as _Convolutional | ||
|
||
|
||
class TestLinear(_Convolutional): | ||
@classmethod | ||
def setup_class(cls): | ||
C = _linear.Linear | ||
args = () | ||
kwargs = {"receptive_field": 2, "sample_rate": 44100} | ||
super().setup_class(C, args, kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters