Skip to content

Commit

Permalink
model.decodde work && recognize.py work
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Oct 23, 2023
1 parent 0a98847 commit e987b00
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
2 changes: 1 addition & 1 deletion wenet/paraformer/ali_paraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def decode(self, methods: List[str], speech: torch.Tensor,
results = []
out, out_lens = self.forward(speech, speech_lens)
for (i, value) in enumerate(out.argmax(-1).numpy()):
results.append(DecodeResult(value.numpy()[:out_lens[i]]))
results.append(DecodeResult(value[:out_lens[i]]))

results_dict['paraformer_greedy_search'] = results
return results_dict
17 changes: 6 additions & 11 deletions wenet/paraformer/ali_paraformer/test_infer_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import yaml
from wenet.cif.predictor import Predictor
from wenet.paraformer.ali_paraformer.model import (
AliParaformer,
SanmDecoer,
SanmEncoder,
)
from wenet.transformer.cmvn import GlobalCMVN
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.cmvn import load_cmvn
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.init_model import init_model

Expand Down Expand Up @@ -62,9 +54,12 @@ def main():
feats = feats.unsqueeze(0)
feats_lens = torch.tensor([feats.size(1)], dtype=torch.int64)

out, token_nums = model(feats, feats_lens)
print("".join([char_dict[id] for id in out.argmax(-1)[0].numpy()]))
print(token_nums)
decode_results = model.decode(['paraformer_greedy_search'], feats,
feats_lens)
print("".join([
char_dict[id]
for id in decode_results['paraformer_greedy_search'][0].tokens
]))

if args.output_file:
script_model = torch.jit.script(model)
Expand Down

0 comments on commit e987b00

Please sign in to comment.