Skip to content

Commit

Permalink
[docs] Updates trace whisper model document (#3426)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Aug 19, 2024
1 parent 71540c2 commit 2524292
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions examples/docs/whisper_speech_text.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ Output:

## Trace the model

You can use the following script to get the traced model and also vocabulary files:
You can use the following script to get the traced model and also vocabulary files, this script works
with `transformers==4.38.0`

```python
from transformers import WhisperProcessor, WhisperForConditionalGeneration
Expand All @@ -35,7 +36,11 @@ import numpy as np

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
processor.tokenizer.save_pretrained("whisper-tokenizer")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", return_dict=False)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", torchscript=True, attn_implementation="eager")
model.generation_config.language = "en"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

test = []
Expand All @@ -45,13 +50,13 @@ for ele in ds:
input_features = processor(np.concatenate(test), return_tensors="pt").input_features
generated_ids = model.generate(inputs=input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Original: " + transcription)
print(f"Original: {transcription}")

# Start tracing
traced_model = torch.jit.trace_module(model, {"generate": [input_features]})
generated_ids = traced_model.generate(input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("Traced: " + transcription)
print(f"Traced: {transcription}")

torch.jit.save(traced_model, "whisper_en.pt")
```

0 comments on commit 2524292

Please sign in to comment.