
Encountering RuntimeError During backward() in FlashDepthWiseConv1d with Specific Padding Settings #30

kawabata-tomoko opened this issue Nov 11, 2024 · 1 comment


kawabata-tomoko commented Nov 11, 2024

Here is a brief example, adapted from the README.md:

import torch
import torch.nn as nn
from flashfftconv import FlashDepthWiseConv1d

B = 4
L = 26000
d = 512
k = 3
padding = k - 1  # note: not "same" padding, which would be (k - 1) // 2
dtype = torch.bfloat16
device = "cuda:4"

# set up the PyTorch equivalent to get the weights
# in_channels = out_channels, and the kernel size must be odd
x = torch.randn((B, d, L), device=device, dtype=dtype)
conv1d_torch = nn.Conv1d(
    in_channels=d,
    out_channels=d,
    kernel_size=k,
    groups=d,
    padding=padding,
    dtype=dtype,
    device=device,
)

flash_conv1d = FlashDepthWiseConv1d(
    channels=d,
    kernel_size=k,
    padding=padding,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    dtype=dtype,  # this should be the dtype of the weights
).to(device=device)

out_torch = conv1d_torch(x)  # x is B, d, L
out_flash = flash_conv1d(x)  # x can be a different dtype than the weights

# out_torch and out_flash should be the same!
out_flash.sum().backward()  # raises the RuntimeError below
out_torch.sum().backward()  # works fine

When I ran this sample program, I encountered the following error message:

RuntimeError                              Traceback (most recent call last)
Cell In[16], line 1
----> 1 out_flash.sum().backward()

File ~/miniconda3/lib/python3.9/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/miniconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/miniconda3/lib/python3.9/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

File ~/miniconda3/lib/python3.9/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
    295     raise RuntimeError(
    296         "Implementing both 'backward' and 'vjp' for a custom "
    297         "Function is not allowed. You should only implement one "
    298         "of them."
    299     )
    300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)

File ~/miniconda3/lib/python3.9/site-packages/flashfftconv-0.0.0-py3.9.egg/flashfftconv/depthwise_1d.py:20, in conv1dFunc.backward(ctx, dout)
     18 input, weight, bias = ctx.saved_tensors
     19 dout  = dout.contiguous()
---> 20 du, dk, dbias = conv1d_backward(dout, input, weight, bias, ctx.padding, ctx.is_bhl)
     21 return du, dk, dbias, None, None

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2048, 26000] but got: [2048, 26002].
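
For reference, the mismatched 26002 seems to come from the standard 1-D convolution output-length formula, so my reading is that the backward kernel assumes the gradient has the input length L rather than the padded output length (the sketch below is just my arithmetic on the error message, not code from the library):

# Output length of a stride-1, dilation-1 1-D convolution:
#   L_out = L + 2 * padding - k + 1
# With L = 26000, k = 3, padding = k - 1 = 2:
#   L_out = 26000 + 4 - 3 + 1 = 26002
# and B * d = 4 * 512 = 2048, which together match the "[2048, 26002]"
# reported in the RuntimeError above. With padding = (k - 1) // 2 = 1,
# L_out == L and the shapes agree, matching the observation below.
B, d, L, k = 4, 512, 26000, 3
padding = k - 1
L_out = L + 2 * padding - k + 1
print(B * d, L_out)  # 2048 26002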

Interestingly, this code works when the padding is set to (k - 1) // 2, regardless of whether the dtype is float16, float32, or bfloat16. Here is another example, adapted from test_conv1d.py with the padding changed:

import torch
import torch.nn as nn
from flashfftconv import FlashDepthWiseConv1d

torch.cuda.empty_cache()  # empty cache between runs
torch.manual_seed(42)

device = 'cuda:4'
dtype = (torch.float16, torch.float16)
in_dtype = dtype[0]
w_dtype = dtype[1]
k = 5
d = 768
l = 8192
b = 4
padding = 1  # the original test uses (k - 1) // 2, which would be 2 here

torch.set_default_device(device)
torch.set_default_dtype(w_dtype)

conv1d_torch = nn.Conv1d(
    in_channels=d,
    out_channels=d,
    kernel_size=k,
    groups=d,
    padding=padding,
).to(device).to(w_dtype)

conv1d_cuda = FlashDepthWiseConv1d(
    channels=d,
    kernel_size=k,
    padding=padding,
    weights=conv1d_torch.weight,
    bias=conv1d_torch.bias,
    is_bhl=True,
    dtype=w_dtype,
).to(device)

x = torch.randn([b, d, l], device=device, dtype=in_dtype)
x_wdtype = x.clone().to(w_dtype)
x_cuda = x.clone().detach().requires_grad_(True)
dout = torch.randn([b, d, l], device=device, dtype=in_dtype)
dout_wdtype = dout.clone().to(w_dtype)

x.requires_grad = True
x_wdtype.requires_grad = True

y_torch = conv1d_torch(x_wdtype)
y_cuda = conv1d_cuda(x_cuda)

y_torch.backward(dout_wdtype, retain_graph=True)
y_cuda.backward(dout, retain_graph=True)

This failed with a similar shape mismatch: 8190 is again the convolution output length (8192 + 2*1 - 5 + 1), while dout keeps the input length 8192.

...
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([4, 768, 8192]) and output[0] has a shape of torch.Size([4, 768, 8190]).

I believe there may be an error in the implementation of the backward method in flashfftconv/depthwise_1d.py: it seems to assume the output (and its gradient) has the same length as the input, which only holds when padding = (k - 1) // 2. Could you please provide any suggestions or references for a possible fix?
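
In the meantime, a minimal fallback sketch I am considering (purely hypothetical, assuming the backward issue only appears when padding != (k - 1) // 2) is to dispatch to nn.Conv1d for other paddings:

import torch.nn as nn
from flashfftconv import FlashDepthWiseConv1d

class DepthwiseConv1dWithFallback(nn.Module):
    # Hypothetical wrapper: only use FlashDepthWiseConv1d for "same"
    # padding, where backward currently seems to work; otherwise fall
    # back to the (slower) PyTorch depthwise conv so training still runs.
    def __init__(self, channels, kernel_size, padding, dtype, device):
        super().__init__()
        self.torch_conv = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            groups=channels,
            padding=padding,
            dtype=dtype,
            device=device,
        )
        self.use_flash = padding == (kernel_size - 1) // 2
        if self.use_flash:
            self.flash_conv = FlashDepthWiseConv1d(
                channels=channels,
                kernel_size=kernel_size,
                padding=padding,
                weights=self.torch_conv.weight,
                bias=self.torch_conv.bias,
                dtype=dtype,
            ).to(device=device)

    def forward(self, x):
        return self.flash_conv(x) if self.use_flash else self.torch_conv(x)

With the settings from the first example (d = 512, k = 3, padding = k - 1), this would take the nn.Conv1d path until the kernel is fixed.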
P.S. Tested on an NVIDIA A800 80GB, Driver Version 525.85.12, CUDA Version 12.0, Python 3.9.19, torch==2.3.1, g++ (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0.


DanFu09 commented Nov 11, 2024

Looks like a bug - feel free to look through the outputs and file a PR to fix it if you have the chance. We are (slowly) working to rewrite this library in a more modern framework like ThunderKittens.
