None type error with local attention #45

Closed
benderama3 opened this issue Nov 8, 2020 · 3 comments
Comments

benderama3 commented Nov 8, 2020

I'm trying to run the example from the readme with local attention instead of linear attention. I changed the attention_type and added an extra local_context argument to the TransformerEncoderBuilder.from_kwargs call:

import torch
from fast_transformers.builders import TransformerEncoderBuilder

# Create the builder for our transformers
builder = TransformerEncoderBuilder.from_kwargs(
    n_layers=8,
    n_heads=8,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=1024,
    local_context=8, #ADDED THIS LINE
)

# Build a transformer with softmax attention
builder.attention_type = "full"
softmax_model = builder.get()

# Build a transformer with local attention
builder.attention_type = "local" #CHANGED THIS LINE
linear_model = builder.get()

# Construct the dummy input
X = torch.rand(10, 1000, 8*64)

# Prepare everything for CUDA
X = X.cuda()
softmax_model.cuda()
softmax_model.eval()
linear_model.cuda()
linear_model.eval()

# Warmup the GPU
with torch.no_grad():
    softmax_model(X)
    linear_model(X)
torch.cuda.synchronize()

# Measure the execution time
softmax_start = torch.cuda.Event(enable_timing=True)
softmax_end = torch.cuda.Event(enable_timing=True)
linear_start = torch.cuda.Event(enable_timing=True)
linear_end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    softmax_start.record()
    y = softmax_model(X)
    softmax_end.record()
    torch.cuda.synchronize()
    print("Softmax: ", softmax_start.elapsed_time(softmax_end), "ms")
    # Softmax: 144 ms (on a GTX1080Ti)

with torch.no_grad():
    linear_start.record()
    y = linear_model(X)
    linear_end.record()
    torch.cuda.synchronize()
    print("Linear: ", linear_start.elapsed_time(linear_end), "ms")
    # Linear: 68 ms (on a GTX1080Ti)

The example throws an error:

Traceback (most recent call last):
  File "_.py", line 35, in <module>
    linear_model(X)
  File "/home/jjhon/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jjhon/.local/lib/python3.6/site-packages/fast_transformers/transformers.py", line 139, in forward
    x = layer(x, attn_mask=attn_mask, length_mask=length_mask)
  File "/home/jjhon/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jjhon/.local/lib/python3.6/site-packages/fast_transformers/transformers.py", line 81, in forward
    key_lengths=length_mask
  File "/home/jjhon/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jjhon/.local/lib/python3.6/site-packages/fast_transformers/attention/attention_layer.py", line 109, in forward
    key_lengths
  File "/home/jjhon/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jjhon/.local/lib/python3.6/site-packages/fast_transformers/attention/local_attention.py", line 82, in forward
    self.local_context
  File "/home/jjhon/.local/lib/python3.6/site-packages/fast_transformers/local_product/__init__.py", line 49, in forward
    local_context
TypeError: 'NoneType' object is not callable

It does work with the other attention modules.
Am I doing something wrong?
Is the local_context argument supposed to be an integer?

Thank you.

EDIT: it looks like it only fails with CUDA (PyTorch 1.6 with CUDA 10.1); it works on the CPU.

EDIT2: fixed by passing the --no-cache-dir argument when installing with pip (to force a recompile).
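
For anyone who wants to confirm the same symptom before reinstalling, here is a minimal repro sketch (not from the original post, just an illustration reusing the same builder arguments): the CPU path should run, and only the CUDA path should raise the TypeError when the compiled local_product CUDA kernel is missing.

import torch
from fast_transformers.builders import TransformerEncoderBuilder

# Same builder settings as above, just a smaller model and input for a quick check
builder = TransformerEncoderBuilder.from_kwargs(
    n_layers=1,
    n_heads=8,
    query_dimensions=64,
    value_dimensions=64,
    feed_forward_dimensions=1024,
    local_context=8,  # integer window size for local attention
)
builder.attention_type = "local"
model = builder.get().eval()

X = torch.rand(2, 128, 8 * 64)
with torch.no_grad():
    model(X)  # CPU path: works even when the CUDA kernels were not built
    try:
        model.cuda()(X.cuda())  # CUDA path: fails here if the extension is missing
    except TypeError as e:
        print("Local attention on CUDA failed:", e)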

@15805383399

I got a similar problem.
I found that there is no causal_product_gpu in the directory, and it can't be imported.
I tried --no-cache-dir to recompile, but it didn't work.

  File "/media/lyk/高速/Project/github/compound-word-transformer/workspace/uncond/cp-linear/main-cp.py", line 712, in <module>
    train()
  File "/media/lyk/高速/Project/github/compound-word-transformer/workspace/uncond/cp-linear/main-cp.py", line 587, in train
    losses = net.train_step(batch_x, batch_y, batch_mask)
  File "/media/lyk/高速/Project/github/compound-word-transformer/workspace/uncond/cp-linear/main-cp.py", line 302, in train_step
    h, y_type = self.forward_hidden(x)
  File "/media/lyk/高速/Project/github/compound-word-transformer/workspace/uncond/cp-linear/main-cp.py", line 368, in forward_hidden
    h = self.transformer_encoder(pos_emb, attn_mask)  # y: b x s x d_model
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/pytorch_fast_transformers-0.4.0-py3.8-linux-x86_64.egg/fast_transformers/transformers.py", line 138, in forward
    x = layer(x, attn_mask=attn_mask, length_mask=length_mask)
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/pytorch_fast_transformers-0.4.0-py3.8-linux-x86_64.egg/fast_transformers/transformers.py", line 77, in forward
    x = x + self.dropout(self.attention(
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/pytorch_fast_transformers-0.4.0-py3.8-linux-x86_64.egg/fast_transformers/attention/attention_layer.py", line 105, in forward
    new_values = self.inner_attention(
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/pytorch_fast_transformers-0.4.0-py3.8-linux-x86_64.egg/fast_transformers/attention/causal_linear_attention.py", line 98, in forward
    V = causal_linear(
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/pytorch_fast_transformers-0.4.0-py3.8-linux-x86_64.egg/fast_transformers/attention/causal_linear_attention.py", line 23, in causal_linear
    V_new = causal_dot_product(Q, K, V)
  File "/home/lyk/anaconda3/lib/python3.8/site-packages/pytorch_fast_transformers-0.4.0-py3.8-linux-x86_64.egg/fast_transformers/causal_product/__init__.py", line 44, in forward
    CausalDotProduct.dot[device.type](
TypeError: 'NoneType' object is not callable

Process finished with exit code 1

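As a side note, the dispatch shown in the traceback suggests a quick diagnostic (sketched from the lines above, not from any documented API): CausalDotProduct keeps a per-device kernel table, and the "cuda" entry being None is exactly what produces 'NoneType' object is not callable.

import torch
from fast_transformers.causal_product import CausalDotProduct

# Per-device kernel table used as CausalDotProduct.dot[device.type](...) in the traceback
print(CausalDotProduct.dot)
if torch.cuda.is_available() and CausalDotProduct.dot["cuda"] is None:
    print("The causal_product CUDA kernel was not compiled; reinstall/rebuild the package.")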

jingshuangliu22 commented Oct 24, 2022

> I got a similar problem. I found that there is no causal_product_gpu in the directory, and it can't be imported. I tried --no-cache-dir to recompile, but it didn't work.

Hi, I ran into the exact same issue; may I ask if you have solved it? @15805383399


@jingshuangliu22

> Hi, I ran into the exact same issue; may I ask if you have solved it? @15805383399

I resolved this one with the methods mentioned in #96 😁
