discard modification on gau
Ben-Louis committed Aug 29, 2023
1 parent 33f89cb commit 17fe05a
Showing 2 changed files with 35 additions and 347 deletions.
34 changes: 32 additions & 2 deletions mmpose/models/utils/rtmcc_block.py
@@ -8,8 +8,6 @@
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 
-from .transformer import ScaleNorm
-
 
 def rope(x, dim):
     """Applies Rotary Position Embedding to input tensor.
@@ -79,6 +77,38 @@ def forward(self, x):
         return x * self.scale
 
 
+class ScaleNorm(nn.Module):
+    """Scale Norm.
+
+    Args:
+        dim (int): The dimension of the scale vector.
+        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
+
+    Reference:
+        `Transformers without Tears: Improving the Normalization
+        of Self-Attention <https://arxiv.org/abs/1910.05895>`_
+    """
+
+    def __init__(self, dim, eps=1e-5):
+        super().__init__()
+        self.scale = dim**-0.5
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1))
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: The tensor after applying scale norm.
+        """
+
+        norm = torch.norm(x, dim=2, keepdim=True) * self.scale
+        return x / norm.clamp(min=self.eps) * self.g
+
+
 class RTMCCBlock(nn.Module):
     """Gated Attention Unit (GAU) in RTMBlock.