⚡️ Whisper JAX - up to 70x faster than OpenAI Whisper #1277
- Do you have any CPU benchmarks?
- What is the difference between OpenAI and Hugging Face Transformers that causes such a big gap in inference time?
- @sanchit-gandhi, this is brilliant work. The Kaggle notebook is portable to Google Colab too, right?
- Hi, thank you very much. I just want to know: when I choose translate, I don't see the language I want to translate to. For example…
- Would this work on mobile devices?
- I guess the demo is not working anymore. When I try to use it, after I record something and click submit, the area where the results should appear just shows Error.
- I want to embed this module in my application. Is there an image on GCP that can be used directly?
-
Whisper JAX ⚡️ is a highly optimised Whisper implementation for both GPU and TPU. Try the demo here and transcribe 1 hour of audio in under 15 seconds: https://huggingface.co/spaces/sanchit-gandhi/whisper-jax
The 70x speed gain we see comes in three stages:
Let's find out more below 👇
1. Batching over un-batched
🤗 Transformers implements a batching algorithm where a single audio sample is chunked into 30s segments, and the chunks are then transcribed in batches. This batching algorithm gives up to a 7x gain over OpenAI (which transcribes chunks sequentially) with nearly no degradation to the WER.
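A minimal sketch of this chunked, batched inference using the 🤗 Transformers pipeline API (the checkpoint and audio file name here are illustrative):

```python
from transformers import pipeline

# Chunk long audio into 30s segments and transcribe them in batches,
# rather than sequentially as in the original OpenAI implementation.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v2",
    chunk_length_s=30,  # split the audio into 30s chunks
    batch_size=16,      # transcribe 16 chunks per forward pass
)
text = asr("audio.mp3")["text"]
```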
2. JAX over PyTorch
JAX is an automatic differentiation library for high-performance machine learning research. By Just-In-Time (JIT) compiling Whisper, we get a 2x speed-up over the 🤗 Transformers PyTorch implementation on GPU.
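To illustrate what JIT compilation buys (a toy function, not the actual Whisper forward pass): jax.jit traces the function once and compiles it to fused XLA kernels, so subsequent calls skip the Python interpreter overhead entirely.

```python
import jax
import jax.numpy as jnp

@jax.jit  # traced once, compiled to fused XLA kernels
def feed_forward(w1, w2, x):
    # A toy transformer-style feed-forward block.
    return w2 @ jax.nn.gelu(w1 @ x)

key = jax.random.PRNGKey(0)
w1 = jax.random.normal(key, (4096, 1024))
w2 = jax.random.normal(key, (1024, 4096))
x = jax.random.normal(key, (1024, 16))

y = feed_forward(w1, w2, x)  # first call: traces + compiles
y = feed_forward(w1, w2, x)  # later calls run the compiled kernel
```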
3. TPUs over GPUs
Tensor Processing Units (TPUs) are ML accelerators designed by Google. TPUs are purpose-built for matrix multiplications, giving them a significant advantage over more general-purpose GPUs. The result? Running Whisper JAX on a TPU v4-8 is 5x faster than on an NVIDIA A100.
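One way JAX exploits the multiple cores of a TPU v4-8 is data parallelism, e.g. via jax.pmap. A hedged sketch of the mechanism (the real pipeline parallelises the Whisper forward pass; this only shows the pattern):

```python
import jax
import jax.numpy as jnp

# Replicate a function across all local TPU/GPU devices; each device
# processes one slice of the leading batch dimension in parallel.
@jax.pmap
def square_batch(x):
    return x ** 2

n = jax.local_device_count()               # e.g. 4 chips on a TPU v4-8
batch = jnp.arange(n * 8.0).reshape(n, 8)  # leading axis = device count
out = square_batch(batch)                  # runs on all devices at once
```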
Adding it all up:
- 7x from batching
- 2x from JAX
- 5x from TPUs
The gains compound multiplicatively: 7 × 2 × 5 => 70x speed-gain overall
Table 1: Average inference time in seconds for audio files of increasing length. GPU device is a single A100 40GB GPU; TPU device is a single TPU v4-8.
Check out the repository to use the model yourself: https://github.com/sanchit-gandhi/whisper-jax
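A minimal usage sketch along the lines of the repository README (the FlaxWhisperPipline class and its arguments are taken from there; the audio path is illustrative):

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Instantiate the pipeline in half precision with batched inference.
pipeline = FlaxWhisperPipline(
    "openai/whisper-large-v2",
    dtype=jnp.bfloat16,  # half-precision weights for faster inference
    batch_size=16,       # chunks transcribed per forward pass
)

# The first call JIT-compiles the model; subsequent calls are fast.
text = pipeline("audio.mp3")
```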
All pre-trained OpenAI checkpoints are compatible! For fine-tuned checkpoints, we include instructions for converting PyTorch weights to Flax: https://github.com/sanchit-gandhi/whisper-jax#available-models-and-languages
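For reference, a hedged sketch of one way to load fine-tuned PyTorch weights into Flax via 🤗 Transformers' from_pt conversion (the checkpoint id is a placeholder; see the link above for the repository's own instructions):

```python
from transformers import FlaxWhisperForConditionalGeneration

# Load a fine-tuned PyTorch checkpoint, converting it to Flax on load.
model = FlaxWhisperForConditionalGeneration.from_pretrained(
    "your-username/whisper-finetuned",  # placeholder checkpoint id
    from_pt=True,                       # convert PyTorch -> Flax weights
)
# Save the converted Flax weights so future loads skip the conversion.
model.save_pretrained("./whisper-finetuned-flax")
```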