forked from jumon/whisper-finetuning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utilities.py
67 lines (54 loc) · 2.17 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from dataclasses import asdict
from typing import List
import jiwer
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
from whisper.tokenizer import Tokenizer, get_tokenizer
from dataloader import get_dataloader
def get_normalizer(multilingual: bool = False):
    """Return the text normalizer appropriate for the model's language mode.

    Multilingual models get the language-agnostic ``BasicTextNormalizer``;
    English-only models get the stronger ``EnglishTextNormalizer``.
    """
    return BasicTextNormalizer() if multilingual else EnglishTextNormalizer()
def create_special_token_mask(token_tensor, tokenizer):
    """Return a boolean mask that is True for ordinary text tokens.

    Positions holding special tokens (Whisper assigns them ids >= tokenizer.eot)
    or the -100 padding/ignore label are marked False.
    """
    # A position is special if it is at/above the end-of-text id or is the
    # ignore-index used for padded labels.
    is_special = (token_tensor >= tokenizer.eot) | (token_tensor == -100)
    return ~is_special
def decode_tokens_to_prompt(tokens_batch: torch.Tensor, tokenizer: Tokenizer) -> List[str]:
    """Remove special tokens and decode each row of a token batch to text.

    Args:
        tokens_batch: 2-D tensor of token ids, one sequence per row. (The
            original annotation said ``List[int]``, but the value is passed
            through tensor ops and iterated row-wise, so it must be a batch
            of token tensors.)
        tokenizer: Whisper tokenizer used both to identify special tokens
            (via ``tokenizer.eot``) and to decode the surviving ids.

    Returns:
        One decoded string per row of ``tokens_batch``.
    """
    # True where a token is regular text; False for special tokens / padding.
    special_token_mask = create_special_token_mask(tokens_batch, tokenizer)
    # Keep only the text tokens of each sequence.
    tokens_without_st = [
        torch.masked_select(tokens, mask)
        for tokens, mask in zip(tokens_batch, special_token_mask)
    ]
    text_prompts = [tokenizer.decode(token) for token in tokens_without_st]
    return text_prompts
# Sentinel so the default normalizer is built lazily per call instead of once
# at import time (default arguments are evaluated at function definition).
_DEFAULT_NORMALIZER = object()


def get_WER_MultipleTexts(
    transcription: list, reference: list, normalizer=_DEFAULT_NORMALIZER
) -> float:
    """Calculate WER between transcription and reference.

    Args:
        transcription: hypothesis strings, one per utterance.
        reference: ground-truth strings, one per utterance.
        normalizer: callable applied to every string before scoring.
            Defaults to ``EnglishTextNormalizer()``; pass ``None`` to skip
            normalization entirely (same semantics as the original code).

    Returns:
        The word error rate as computed by ``jiwer.wer``.
    """
    if normalizer is _DEFAULT_NORMALIZER:
        normalizer = EnglishTextNormalizer()
    if normalizer is not None:
        transcription = [normalizer(text) for text in transcription]
        reference = [normalizer(text) for text in reference]
    # jiwer expects (reference, hypothesis) order.
    wer = jiwer.wer(reference, transcription)
    return wer
def calculate_WER(predicted_tokens, reference_tokens, normalizer, tokenizer) -> float:
    """Decode predicted and reference token batches to text and score their WER.

    Special tokens are stripped during decoding; the given normalizer is
    applied to both sides before ``jiwer`` computes the word error rate.
    """
    hypothesis_texts = decode_tokens_to_prompt(predicted_tokens, tokenizer)
    truth_texts = decode_tokens_to_prompt(reference_tokens, tokenizer)
    return get_WER_MultipleTexts(hypothesis_texts, truth_texts, normalizer)