I'm trying to adopt this implementation of FFT convolutions inside my model, and initial testing yields great results across the board. Unfortunately, at large batch sizes (>140,000), it crashes with
CUDA Runtime Error at: /flash-fft-conv/csrc/flashfftconv/monarch_cuda/monarch_cuda_interface_fwd_bf16.cu:386
an illegal memory access was encountered
Here is an MWE:
import torch
from flashfftconv import FlashFFTConv
signal = torch.rand((140000, 4, 4096), dtype=torch.bfloat16, device="cuda")
kernel = torch.rand((4, 4096), dtype=torch.float32, device="cuda")
conv = FlashFFTConv(4096).to("cuda")
res = conv(signal, kernel)
res += 1
This crashes with batch size 140,000 but works with 130,000.
Do you have an idea where this could come from?
Ah, this is because it’s breaking the actual indexing of GPU memory in the kernel - it probably breaks once batch * heads * seqlen crosses 131k * 4 * 4k = 2^31 elements.
It should be fixable by changing some of the types of the ints used to store pointers into HBM; I’ll try to have a go at it this week.
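For context, here is a minimal CUDA sketch of that failure mode (the kernel and variable names are illustrative only, not the actual code in monarch_cuda_interface_fwd_bf16.cu): when the element offset into HBM is computed with 32-bit ints, the product batch * heads * seqlen overflows once it passes 2^31 - 1, whereas promoting to a 64-bit type before the multiply keeps the offset valid.
// Illustrative only; not the actual indexing code in flash-fft-conv.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void offset_demo(int batch, int heads, int seqlen) {
    // 32-bit product: 140000 * 4 * 4096 = 2,293,760,000 exceeds
    // INT_MAX (2^31 - 1), so this signed multiply overflows and the
    // resulting offset is garbage.
    int offset32 = batch * heads * seqlen;
    // Promoting the operands to long long keeps the whole
    // multiplication in 64 bits, so the offset stays correct.
    long long offset64 = (long long)batch * (long long)heads * (long long)seqlen;
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("32-bit offset: %d\n64-bit offset: %lld\n", offset32, offset64);
    }
}

int main() {
    offset_demo<<<1, 1>>>(140000, 4, 4096);
    cudaDeviceSynchronize();
    return 0;
}
If that is indeed the cause, widening those intermediate offsets (e.g. to long long or size_t) inside the kernels should be enough; the Python-facing API wouldn’t need to change.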