Equivalence between Patch Merging and Conv. #256

Open

AHHOZP opened this issue Aug 29, 2022 · 3 comments

Comments

@AHHOZP

AHHOZP commented Aug 29, 2022

Hello, after looking at the code in the patch merging part, we found that the operation of slicing the feature map, concatenating the slices, and then passing them through a linear layer to reduce the dimension from 4C to 2C is exactly equivalent to a conv layer with kernel size 2 and stride 2.

The operation concatenates the 4 pixels of each 2x2 patch into 1 pixel with 4 times as many channels.
Every 2x2 patch shares the same weights in your linear layer (self.reduction) as every other patch.
A conv (kernel size=2, stride=2) does the same thing.

The number of parameters of this linear layer is also equal to that of the conv layer:
linear layer params = input channels * output channels = 4C * 2C = 8 * C^2
conv layer params = kernel size * kernel size * input channels * output channels = 2 * 2 * C * (2 * C) = 8 * C^2
So, linear layer params == conv layer params.
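
To make the equivalence concrete, here is a minimal sketch (illustrative shapes and variable names, not taken from the Swin codebase; the norm layer is ignored) that copies the weights of a 4C -> 2C linear layer into a Conv2d with kernel size 2 and stride 2 and checks that both paths produce the same output:

import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 16
x = torch.randn(B, H, W, C)

# Patch-merging path: slice the 2x2 neighborhoods, concatenate, apply the linear layer.
linear = nn.Linear(4 * C, 2 * C, bias=False)
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], -1)                     # [B, H/2, W/2, 4*C]
out_linear = linear(merged)                                  # [B, H/2, W/2, 2*C]

# Conv path: rearrange the same weights into a 2x2, stride-2 conv kernel.
conv = nn.Conv2d(C, 2 * C, kernel_size=2, stride=2, bias=False)
w = linear.weight.view(2 * C, 4, C)                          # concat order: x0, x1, x2, x3
conv_w = torch.zeros(2 * C, C, 2, 2)
conv_w[:, :, 0, 0] = w[:, 0]                                 # x0: even rows, even cols
conv_w[:, :, 1, 0] = w[:, 1]                                 # x1: odd rows, even cols
conv_w[:, :, 0, 1] = w[:, 2]                                 # x2: even rows, odd cols
conv_w[:, :, 1, 1] = w[:, 3]                                 # x3: odd rows, odd cols
with torch.no_grad():
    conv.weight.copy_(conv_w)

out_conv = conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)   # [B, H/2, W/2, 2*C]
print(torch.allclose(out_linear, out_conv, atol=1e-5))       # True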

@yan-mingyuan

yan-mingyuan commented Aug 11, 2023

I agree with you that both implementations are essentially the same.

Standard Implementation Using Slicing Operation

Parameters:

self.norm = nn.LayerNorm(4 * C)  # the original PatchMerging also defines a norm over the concatenated 4*C features
self.reduction = nn.Linear(4*C, 2*C, bias=False)

Forward Pass:

x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C)  # [B, HW/4, 4*C]

x = self.norm(x)
x = self.reduction(x)

The first approach uses slicing, which may look intricate at first, but after optimization it introduces minimal overhead. Nevertheless,

  • the slicing operations make the code more verbose and harder to read.
  • while not strictly necessary, it introduces the concept of patch merging, whereas downsampling with a kernel size of 2 and a stride of 2 is intuitively straightforward and aligns well with common understanding.

Alternative Implementation Using Conv2d

Parameters:

self.reduction = nn.Conv2d(C, C*2, kernel_size=2, stride=2, bias=False)

Forward Pass:

# forward
# [B, H*W, C] -> [B, H, W, C] -> [B, C, H, W]
x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # [B, C, H, W]
x = self.norm(x)                             # note: the norm here sees a channels-first tensor
x = self.reduction(x)                        # [B, 2*C, H/2, W/2]
x = x.permute(0, 2, 3, 1).view(B, -1, C*2)   # [B, H*W/4, 2*C]

The second approach uses Conv2d and looks more elegant; however,

  • the permute operations might introduce additional data-movement overhead.
  • notably, no contiguous call is needed before the view operation (see the check below), which suggests the extra movement overhead may not actually occur.
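
As a quick sanity check (standalone, with assumed shapes), the view after permute does succeed without an explicit contiguous call, because the two merged spatial dimensions keep compatible strides:

import torch

B, C, H, W = 2, 32, 8, 8
y = torch.randn(B, 2 * C, H // 2, W // 2)   # stand-in for the conv output, channels-first
z = y.permute(0, 2, 3, 1)                   # [B, H/2, W/2, 2*C]
print(z.is_contiguous())                    # False: permute changes strides, not data
out = z.view(B, -1, 2 * C)                  # works: the merged H/2 and W/2 dims are stride-compatible
print(out.shape)                            # torch.Size([2, 16, 64])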

@AHHOZP
Author

AHHOZP commented Aug 15, 2023

You are right, but the tensor slicing and concatenation also cost time. I don't know how much, or which costs more compared to contiguous.
Another problem is the layernorm: in the code for 'Using Conv2d', you will find the norm is different from the original implementation (see the sketch below).
Also, batchnorm is usually used for CNNs instead of layernorm, so I guess this may be another reason why they use patch merging instead of a conv.
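
A minimal sketch of that normalization mismatch (standalone, illustrative shapes; the exact norm used in the Conv2d variant is left unspecified there): in the original PatchMerging, LayerNorm sees the concatenated 4*C features of each merged token, while any norm applied before the 2x2 reduction only sees C features per pixel, so the statistics differ even with matched reduction weights.

import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 16
x = torch.randn(B, H, W, C)

# Original PatchMerging: LayerNorm over the concatenated 4*C features of each merged token.
x0, x1 = x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :]
x2, x3 = x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], -1)      # [B, H/2, W/2, 4*C]
normed_a = nn.LayerNorm(4 * C)(merged)        # stats computed over 4*C values per token

# Any norm placed before the 2x2 reduction (as in the Conv2d variant) can only
# see C features per pixel, illustrated here with LayerNorm(C) on the channels-last view.
normed_b = nn.LayerNorm(C)(x)                 # stats computed over C values per pixel
# The two compute different statistics, so the variants stop being exactly equivalent.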

@yan-mingyuan

I agree with your perspective. Whether we use tensor slicing or concatenation, both operations access non-contiguous memory regions, which inevitably introduces non-locality that can impact performance.

Shifting our focus to normalization: while BatchNorm2d is common in computer vision, LayerNorm, typical of transformer-based NLP models, is used here, probably because of the architecture's transformer lineage. This choice is largely independent of the decision between the Conv2d variant and patch merging, so it shouldn't heavily affect the behavior of the normalization layer.
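
For completeness, a small standalone illustration (hypothetical shapes) of the axis difference between the two normalization choices:

import torch
import torch.nn as nn

B, L, C, H, W = 2, 16, 32, 4, 4
tokens = torch.randn(B, L, C)             # transformer-style token sequence
feature_map = torch.randn(B, C, H, W)     # CNN-style feature map

ln = nn.LayerNorm(C)      # normalizes each token's C features independently
bn = nn.BatchNorm2d(C)    # normalizes each channel over the batch and spatial dims

print(ln(tokens).shape)        # torch.Size([2, 16, 32])
print(bn(feature_map).shape)   # torch.Size([2, 32, 4, 4])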
