No module named triton on 2080TI #120

Open
JamesHOEEEE opened this issue May 23, 2024 · 3 comments

@JamesHOEEEE

Hi all

I was trying to run amg_example.py on a 2080 Ti too. I know the Triton kernel was written specifically for the A100, so according to the README I need to set the environment variable SEGMENT_ANYTHING_FAST_USE_FLASH_4=0.

Here is my code:

import os
os.environ['SEGMENT_ANYTHING_FAST_USE_FLASH_4'] = '0'

but I still get the "No module named triton" error.

Did I do something wrong? Or do you have any suggestions?

Thank you
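
For reference, the flag appears to be read when flash_4.py is imported, so it needs to be in the environment before anything from segment_anything_fast is imported. A minimal sketch, assuming the package exposes sam_model_registry and SamAutomaticMaskGenerator the way amg_example.py uses them:

import os

# Set the flag before any segment_anything_fast import so flash_4.py sees it
# when it reads the environment (assumption: the flag is checked at import time).
os.environ['SEGMENT_ANYTHING_FAST_USE_FLASH_4'] = '0'

from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator

Even with the flag set, a module-level import triton in flash_4.py can still fail if Triton is not installed, which would match the error above.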

@cpuhrsch
Contributor

Hi @JamesHOEEEE,

You can try manually disabling the kernel by commenting out these lines here:

import math
sm_scale = 1. / math.sqrt(q_.size(-1))
# Check if second last dimension is multiple of 256
q_size_2_padded = (((q_.size(-2) + 256 - 1) // 256) * 256) - q_.size(-2)

def kernel_guards(q_, k_, v_):
    return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL

# vit_b and vit_l
# TODO: This kernel currently does not produce correct results for batch size 1 for this case
if q_.size(0) > 1 and q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
    rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
    o = torch.ops.customflash.custom_flash_aligned(
        q_, k_, v_, rel_h_w, sm_scale)
    if o.numel() > 0:
        return o

# vit_h
if q_size_2_padded == 0 and q_.size(-1) == 80 and kernel_guards(q_, k_, v_):
    # Only support multiples of 64, so need to pad
    q = torch.nn.functional.pad(q_, (0, 128 - 80, 0, 0), "constant", 0)
    k = torch.nn.functional.pad(k_, (0, 128 - 80, 0, 0), "constant", 0)
    v = torch.nn.functional.pad(v_, (0, 128 - 80, 0, 0), "constant", 0)
    rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
    o = torch.ops.customflash.custom_flash_aligned(
        q, k, v, rel_h_w, sm_scale)
    if o.numel() > 0:
        return o[:, :, :, :80]

Thanks,
Christian
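
An equivalent tweak (a sketch only, not code from the repo) is to leave the two blocks in place and simply force the guard to fail, so both branches are skipped and the function falls through to the non-custom attention path:

def kernel_guards(q_, k_, v_):
    # Always return False so the custom Triton kernel is never dispatched;
    # this has the same effect as commenting out the two blocks above.
    return False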

@JamesHOEEEE
Author

Hi @cpuhrsch

Thanks for your reply, it seems to be working fine now.

@FlyingAnt2018

(quoting @cpuhrsch's reply above)

Hi, bro. I am working on Windows 11, and commenting out these lines has no effect: it still reports "No module named triton". How should the "import triton" code in flash_4.py be dealt with?
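
One common pattern for that (a sketch only, not necessarily how the maintainers want flash_4.py changed) is to guard the module-level import so the file still loads on machines without Triton:

try:
    import triton
    HAS_TRITON = True
except ImportError:
    # Triton is unavailable (common on Windows); never use the custom kernel.
    triton = None
    HAS_TRITON = False

Anything else in the file that references triton (for example @triton.jit-decorated kernels, if present) would also need to be skipped when HAS_TRITON is False, e.g. by folding HAS_TRITON into the USE_CUSTOM_KERNEL check that kernel_guards uses above.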
