Flops of Swin Transformer #142

Open
AiHaiHai opened this issue Aug 29, 2024 · 2 comments
Labels
question Further information is requested

Comments

@AiHaiHai

The reported FLOPs of swin_t are 4.5G, but I get 3.13G here. Am I using it the wrong way?

import torchvision.models as models
from ptflops import get_model_complexity_info

net = models.swin_t(num_classes=1000)
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))
Computational complexity:       3.13 GMac
Number of parameters:           28.29 M 

@sovrasov sovrasov added the question Further information is requested label Aug 30, 2024
@sovrasov
Owner

Transformer support is incomplete in the torch backend. To fix that, you can switch to the aten backend:

import torchvision.models as models
from ptflops import get_model_complexity_info

net = models.swin_t(num_classes=1000)
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=False, backend='aten')
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))
Computational complexity:       4.5 GMac
Number of parameters:           28.29 M 
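For a sense of scale, here is a back-of-envelope count of the MACs inside Swin-T's window-attention blocks at a 224x224 input. It assumes the standard Swin-T configuration (patch size 4, embedding dim 96, depths 2/2/6/2, window size 7) and counts only the qkv/output projections and the two per-window matrix products (Q @ K^T and attn @ V). The function names are hypothetical, nothing here calls ptflops, and it does not tell you which ops the torch backend actually skipped.

```python
# Back-of-envelope MAC count for the window-attention parts of Swin-T.
# Assumed (standard Swin-T) hyperparameters, not read from torchvision:
PATCH = 4              # patch embedding stride: 224 -> 56x56 tokens
EMBED_DIM = 96         # channels in stage 1
DEPTHS = (2, 2, 6, 2)  # transformer blocks per stage
WINDOW = 7             # window side length (49 tokens per window)


def attention_block_macs(tokens: int, dim: int, window: int) -> tuple[int, int]:
    """Return (linear_macs, matmul_macs) for one attention block.

    linear_macs: qkv projection (dim -> 3*dim) plus output projection (dim -> dim).
    matmul_macs: Q @ K^T and attn @ V inside every window, summed over heads
    (per window each product costs window^2 * window^2 * dim MACs, and there
    are tokens / window^2 windows).
    """
    linear = tokens * (3 * dim * dim + dim * dim)
    matmul = 2 * tokens * (window * window) * dim
    return linear, matmul


def swin_t_attention_macs(img_size: int = 224) -> tuple[int, int]:
    """Sum attention-block MACs over all four stages."""
    side = img_size // PATCH
    dim = EMBED_DIM
    total_linear = total_matmul = 0
    for depth in DEPTHS:
        linear, matmul = attention_block_macs(side * side, dim, WINDOW)
        total_linear += depth * linear
        total_matmul += depth * matmul
        side //= 2  # patch merging halves the spatial resolution...
        dim *= 2    # ...and doubles the channel count
    return total_linear, total_matmul


if __name__ == "__main__":
    linear, matmul = swin_t_attention_macs()
    print(f"qkv + proj linears:     {linear / 1e9:.2f} GMac")
    print(f"attention matmuls:      {matmul / 1e9:.2f} GMac")
    print(f"gap reported in thread: {4.5 - 3.13:.2f} GMac")
```

The projections alone come to roughly 1.39 GMac and the window matmuls to about 0.14 GMac. That is the same order of magnitude as the 1.37 GMac gap between the two backends, which is consistent with (but does not prove) attention layers being under-counted by the torch backend.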

@AiHaiHai
Author

AiHaiHai commented Aug 31, 2024

Thanks a lot, it works now.
