Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformer] refactor mqa repeat #2497

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 12 additions & 51 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,18 @@ def _update_kv_and_cache(
# non-trivial to calculate `next_cache_start` here.
# new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
new_cache = (k, v)
# for multi query or multi group attention
if self.h_kv != self.h and self.h_kv != 1:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)
return k, v, new_cache

def forward(
Expand Down Expand Up @@ -245,19 +257,6 @@ def forward(
q, k, v = self.forward_qkv(query, key, value)
k, v, new_cache = self._update_kv_and_cache(k, v, cache)

# for multi query or multi group attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
Expand Down Expand Up @@ -364,19 +363,6 @@ def forward(
q = q.transpose(1, 2) # (batch, time1, head, d_k)
k, v, new_cache = self._update_kv_and_cache(k, v, cache)

# for multi query or multi groups attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
Expand Down Expand Up @@ -459,19 +445,6 @@ def forward(
q, k, v = self.forward_qkv(query, key, value)
new_cache = (k, v) if not self.training else cache

# for multi query or multi groups attention
if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)

B = query.size(0)
Beams = 1
if B != k.size(0):
Expand Down Expand Up @@ -645,18 +618,6 @@ def forward(
# see above
k, v, new_cache = self._update_kv_and_cache(k, v, cache)

if self.h_kv != self.h:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=1,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=1,
)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
Expand Down
Loading