fix rope pos embdining (#2463)
* fix rope pos embdining

* fix dropout

* fix comment
Mddct authored Apr 10, 2024
1 parent 1da8b0b commit b9c5d8b
Showing 1 changed file with 27 additions and 6 deletions.
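The key storage change: the complex table returned by precompute_freqs_cis is now registered through torch.view_as_real and turned back into a complex view with torch.view_as_complex when positions are looked up. Below is a minimal, self-contained sketch of that round-trip (torch.polar stands in for precompute_freqs_cis; none of this is part of the commit itself):

    import torch

    # stand-in for precompute_freqs_cis: a small table of complex64 unit phasors
    freqs = torch.polar(torch.ones(16, 4), torch.rand(16, 4))
    stored = torch.view_as_real(freqs)         # (16, 4, 2) float32, plain real buffer
    recovered = torch.view_as_complex(stored)  # back to (16, 4) complex64, same values
    assert torch.equal(recovered, freqs)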
wenet/transformer/embedding.py: 33 changes (27 additions & 6 deletions)
@@ -208,9 +208,9 @@ def __init__(self,
                  rope_theta=10000.0):
         super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len)
         delattr(self, 'pe')
-
-        pe = precompute_freqs_cis(head_dim, max_len * 2, rope_theta)
-        self.register_buffer("pe", pe.unsqueeze(0))
+        self.max_len = max_len * 2
+        pe = precompute_freqs_cis(head_dim, self.max_len, rope_theta)
+        self.register_buffer("pe", torch.view_as_real(pe.unsqueeze(0)))
         self.dropout_rate = dropout_rate
 
     def forward(
@@ -219,13 +219,34 @@ def forward(
             offset: Union[int,
                           torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        pos_emb = self.position_encoding(offset, x.size(1), False)
+        pos_emb = self.position_encoding(offset, x.size(1), True)
+        pos_emb = pos_emb.unsqueeze(1)  # [1, 1, seq, head_dim//2]
         # NOTE(Mddct): some model don't scale
         # TODO(Mddct): fix
         x = x * self.xscale
-        # NOTE(Mddct) dropout don't suuport complex float for pos_emb
-        return self.dropout(x), self.dropout_complex(pos_emb)
+        return self.dropout(x), pos_emb
+
+    def position_encoding(self,
+                          offset: Union[int, torch.Tensor],
+                          size: int,
+                          apply_dropout: bool = True) -> torch.Tensor:
+
+        pe = torch.view_as_complex(self.pe)
+        if isinstance(offset, int):
+            assert offset + size <= self.max_len
+            pos_emb = pe[:, offset:offset + size]
+        else:
+            assert torch.max(offset) + size <= self.max_len
+            index = offset.unsqueeze(1) + torch.arange(0, size).to(
+                offset.device)  # B X T
+            flag = index > 0
+            # remove negative offset
+            index = index * flag
+            pos_emb = F.embedding(index, pe[0])  # B X T X head_dim//2
+        if apply_dropout:
+            # NOTE(Mddct) dropout don't suuport complex float for pos_emb
+            pos_emb = self.dropout_complex(pos_emb)
+        return pos_emb
 
     def dropout_complex(self, x):
         mask = torch.nn.functional.dropout(
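For a batched tensor offset, the new position_encoding builds a (batch, time) index matrix from per-utterance start positions, zeroes out negative entries, and gathers rows of the table with F.embedding. A self-contained sketch of that indexing, using a real-valued stand-in table and made-up shapes:

    import torch
    import torch.nn.functional as F

    max_len, half_dim = 16, 4
    table = torch.randn(max_len, half_dim)               # stand-in for pe[0]
    offset = torch.tensor([0, 3, -2])                    # per-utterance start positions, shape (B,)
    size = 5                                             # current chunk length T
    index = offset.unsqueeze(1) + torch.arange(0, size)  # (B, T) absolute positions
    index = index * (index > 0)                          # clamp negative positions to 0, as in the diff
    pos_emb = F.embedding(index, table)                  # (B, T, half_dim) gathered rows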

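The dropout_complex helper (its body is cut off in the view above) exists because, per the NOTE kept in the diff, torch.nn.functional.dropout does not accept complex inputs. A sketch of the masking idea only, not necessarily the file's exact implementation: draw a real-valued dropout mask and multiply it onto the complex tensor.

    import torch
    import torch.nn.functional as F

    def dropout_complex_sketch(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
        # real Bernoulli mask (already rescaled by 1/(1-p)) drawn on the shape of x.real,
        # then broadcast onto the complex tensor by ordinary multiplication
        mask = F.dropout(torch.ones_like(x.real), p=p, training=training)
        return x * mask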