[transformer] refactor mqa repeat (#2497)
Mddct authored Apr 22, 2024
1 parent c415f6c commit fbe75dd
Showing 1 changed file with 12 additions and 51 deletions.
63 changes: 12 additions & 51 deletions wenet/transformer/attention.py
@@ -201,6 +201,18 @@ def _update_kv_and_cache(
         # non-trivial to calculate `next_cache_start` here.
         # new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
         new_cache = (k, v)
+        # for multi query or multi group attention
+        if self.h_kv != self.h and self.h_kv != 1:
+            k = torch.repeat_interleave(
+                k,
+                self.h // self.h_kv,
+                dim=-3,
+            )
+            v = torch.repeat_interleave(
+                v,
+                self.h // self.h_kv,
+                dim=-3,
+            )
         return k, v, new_cache

     def forward(
@@ -245,19 +257,6 @@ def forward(
         q, k, v = self.forward_qkv(query, key, value)
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)

-        # for multi query or multi group attention
-        if self.h_kv != self.h:
-            k = torch.repeat_interleave(
-                k,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-            v = torch.repeat_interleave(
-                v,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-
         if not self.use_sdpa:
             scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
             return self.forward_attention(v, scores, mask), new_cache
@@ -364,19 +363,6 @@ def forward(
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)

-        # for multi query or multi groups attention
-        if self.h_kv != self.h:
-            k = torch.repeat_interleave(
-                k,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-            v = torch.repeat_interleave(
-                v,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-
         n_batch_pos = pos_emb.size(0)
         p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
         p = p.transpose(1, 2)  # (batch, head, time1, d_k)
@@ -459,19 +445,6 @@ def forward(
         q, k, v = self.forward_qkv(query, key, value)
         new_cache = (k, v) if not self.training else cache

-        # for multi query or multi groups attention
-        if self.h_kv != self.h:
-            k = torch.repeat_interleave(
-                k,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-            v = torch.repeat_interleave(
-                v,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-
         B = query.size(0)
         Beams = 1
         if B != k.size(0):
@@ -645,18 +618,6 @@ def forward(
         # see above
         k, v, new_cache = self._update_kv_and_cache(k, v, cache)

-        if self.h_kv != self.h:
-            k = torch.repeat_interleave(
-                k,
-                self.h // self.h_kv,
-                dim=1,
-            )
-            v = torch.repeat_interleave(
-                v,
-                self.h // self.h_kv,
-                dim=1,
-            )
-
         if not self.use_sdpa:
             scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
             return self.forward_attention(v, scores, mask), new_cache
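
For readers unfamiliar with multi-query / grouped-query attention, the block this commit adds to `_update_kv_and_cache` duplicates each key/value head along the head axis until it matches the number of query heads, so the per-forward copies above become redundant. Below is a minimal, self-contained sketch of that expansion; the head counts and tensor sizes are made up for illustration and are not taken from any wenet config.

# Illustrative sketch of the KV-head expansion now centralized in
# _update_kv_and_cache. Shapes and head counts are example values only.
import torch

h, h_kv, d_k = 8, 2, 64                    # 8 query heads sharing 2 KV heads (grouped-query attention)
batch, time2 = 3, 50

k = torch.randn(batch, h_kv, time2, d_k)   # (batch, head_kv, time2, d_k)
v = torch.randn(batch, h_kv, time2, d_k)

if h_kv != h and h_kv != 1:
    # Repeat each KV head h // h_kv times along the head axis (dim=-3)
    # so k and v line up one-to-one with the h query heads.
    k = torch.repeat_interleave(k, h // h_kv, dim=-3)
    v = torch.repeat_interleave(v, h // h_kv, dim=-3)

print(k.shape)   # torch.Size([3, 8, 50, 64])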

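The new condition also skips the repeat when h_kv == 1 (pure multi-query attention). A plausible reading, not stated in the commit itself, is that a single shared KV head broadcasts through the attention-score matmul without being materialized h times; the toy check below illustrates that behavior with illustrative shapes.

# Toy check that a single shared KV head (h_kv == 1) broadcasts through the
# attention-score matmul, so no repeat_interleave is needed in that case.
import math
import torch

batch, h, time1, time2, d_k = 2, 8, 10, 16, 64
q = torch.randn(batch, h, time1, d_k)   # one slice per query head
k = torch.randn(batch, 1, time2, d_k)   # single shared KV head (MQA)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
print(scores.shape)   # torch.Size([2, 8, 10, 16]), broadcast over the head axis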