Skip to content

Commit

Permalink
[ssl/bestrq] model and numerical stability (#2060)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Oct 18, 2023
1 parent 10d798f commit 34ef62b
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions wenet/ssl/bestrq/bestrq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,6 @@ def __init__(
torch.empty(self.num_codebooks, num_embeddings))
torch.nn.init.zeros_(self.encoder_top_n_out_bias)

# mask embedding
mask_embedding_dim = num_mel_bins
self.mask_emb = torch.nn.parameter.Parameter(
torch.empty(mask_embedding_dim), requires_grad=True)
torch.nn.init.trunc_normal_(self.mask_emb, std=0.1)

# stack input: eg: fbank
self.stack_frames = self.encoder.embed.right_context + 1
self.stride = self.encoder.embed.subsampling_rate
Expand Down Expand Up @@ -245,7 +239,8 @@ def _apply_mask_signal(
device=input.device)

masks_expand = masks.unsqueeze(-1) # [B, T, 1]
mask_emb = self.mask_emb.to(input.device).view(1, 1, -1)
mask_emb = torch.normal(mean=0, std=0.1,
size=(1, 1, input.size(2))).to(input.device)
xs = torch.where(masks_expand, mask_emb, input)
return xs, masks

Expand Down

0 comments on commit 34ef62b

Please sign in to comment.