
[WIP] Release code of MixFormer (CVPR2022, Oral) #1820

Open
wants to merge 9 commits into base: develop

Conversation

chensnathan

MixFormer: Mixing Features across Windows and Dimensions

Pre-trained models will be added in the next few days.

@paddle-bot-old

paddle-bot-old bot commented Apr 8, 2022

Thanks for your contribution!

@Seperendity

Seperendity commented May 20, 2022

Hello, I'd like to ask: at line 229 of ppcls/arch/backbone/model_zoo/mixformer.py, how is `v = v * x_cnn2v` computed? The last two dimensions of the two operands are (1, C // self.num_heads) and (N, C // self.num_heads) respectively; for a matrix product, wouldn't the column dimension of one fail to match the row dimension of the other?

@chensnathan
Author

@Seperendity Hi, this is an elementwise multiplication achieved through broadcasting.
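For readers unfamiliar with broadcasting, here is a minimal NumPy sketch (with made-up small sizes; `head_dim` stands in for C // self.num_heads) showing how a (1, head_dim) tensor multiplies an (N, head_dim) tensor elementwise, the weight row being repeated along the token axis:

```python
import numpy as np

# Hypothetical small sizes for illustration only.
N, head_dim = 4, 8  # head_dim = C // num_heads

x_cnn2v = np.random.rand(1, head_dim)  # shape (1, head_dim)
v = np.random.rand(N, head_dim)        # shape (N, head_dim)

# Elementwise product: x_cnn2v is broadcast along axis 0 to (N, head_dim).
out = v * x_cnn2v
assert out.shape == (N, head_dim)
```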

@Seperendity

@chensnathan Thank you very much for the explanation! I understand now that broadcasting is used. But I'm still unsure why, after broadcasting, the weights end up multiplied along the num_windows and token-count dimensions; from the paper I had assumed the weights are applied to the channel dimension. In the code: `x_cnn2v = torch.sigmoid(channel_interaction).reshape([-1, 1, self.num_heads, 1, C // self.num_heads])` and `v = v.reshape([x_cnn2v.shape[0], -1, self.num_heads, N, C // self.num_heads])`. What is the reason for multiplying this way? Intuitively it does not seem to assign the learned weights to the channel dimension. I would greatly appreciate your explanation.

@chensnathan
Author

chensnathan commented May 22, 2022

@Seperendity Hi, this is done to match the shape of v. An example may help: suppose v has shape [B, C, H, W] and x_cnn2v has shape [B, C, 1, 1]; then `v = v * x_cnn2v` is a simple channel attention. However, at line 223 of the code, because window-based self-attention comes next, v has shape [B*(H/win)*(W/win), win*win, num_heads, C/num_heads] while x_cnn2v still has shape [B, C, 1, 1], so channel attention cannot be applied directly. There are different ways to implement it:

  1. You could reshape v back to [B, C, H, W], apply channel attention there, then reshape it back to [B*(H/win)*(W/win), win*win, num_heads, C/num_heads] before entering the following self-attention.
  2. What I chose instead is to reshape v from [B*(H/win)*(W/win), win*win, num_heads, C/num_heads] to [B, (H/win)*(W/win), win*win, num_heads, C/num_heads] and x_cnn2v to [B, 1, 1, num_heads, C/num_heads], multiply them, then reshape back to [B*(H/win)*(W/win), win*win, num_heads, C/num_heads] before entering the following self-attention.
     The two are essentially the same.
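The equivalence of the two options above can be checked with a small NumPy sketch. The sizes are illustrative, and the window un-partition assumes a row-major window tiling; this is an assumption for the sketch, not an exact reproduction of the repository code:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: B batches, an H x W feature map split into win x win
# windows, num_heads heads of head_dim channels each (C = num_heads * head_dim).
B, H, W, win, num_heads, head_dim = 2, 4, 4, 2, 2, 3
C = num_heads * head_dim
n_win = (H // win) * (W // win)
N = win * win

# v after window partition: [B*n_win, N, num_heads, head_dim]
v = rng.random((B * n_win, N, num_heads, head_dim))
# Channel-attention weights from the conv branch: [B, C, 1, 1]
w = rng.random((B, C, 1, 1))

# Option 2 (the shape trick from the comment): reshape both so they broadcast.
v5 = v.reshape(B, n_win, N, num_heads, head_dim)
w5 = w.reshape(B, 1, 1, num_heads, head_dim)
out2 = (v5 * w5).reshape(B * n_win, N, num_heads, head_dim)

# Option 1: undo the window partition (row-major tiling assumed), apply plain
# channel attention on [B, C, H, W], then re-partition into windows.
v_img = (v.reshape(B, H // win, W // win, win, win, C)
          .transpose(0, 5, 1, 3, 2, 4)
          .reshape(B, C, H, W))
out1_img = v_img * w  # [B, C, H, W] * [B, C, 1, 1] -> channel attention
out1 = (out1_img.reshape(B, C, H // win, win, W // win, win)
                .transpose(0, 2, 4, 3, 5, 1)
                .reshape(B * n_win, N, num_heads, head_dim))

assert np.allclose(out1, out2)  # both routes weight channels identically
```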

@Seperendity

@chensnathan I see what you mean now. Thank you very much for the patient explanation! Very interesting work.

@cxz1276316542

MixFormer: Mixing Features across Windows and Dimensions

Pre-trained models will be added in the next few days.

Hello, have the pre-trained models been released yet? Where can they be downloaded?

This was referenced Sep 23, 2022
docs/en/models/MixFormer_en.md
docs/zh_CN/models/ImageNet1k/MixFormer.md
ppcls/arch/backbone/__init__.py
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.

6 participants