Replies: 7 comments 24 replies
-
The short answer is no, and this comment from the maintainers sums it up:
The behaviour when multiple languages are present seems to be unpredictable with the current models. These past discussions may also be helpful.
-
One way of doing this is to combine speaker diarization with Whisper. Here is standalone example code showing how this works.
Notes
I tested this with a short audio clip (narrated in Portuguese with segments of Polish) and it worked well. I don't know how well pyannote will identify speakers across different languages; this is my first time trying it. This is just a demo, and obviously more code would be necessary to merge transcripts. Partial output sample showing the start/stop times, the speaker id, the language identified, and the transcript of each audio block:
From Euronews, "Parlamento polaco investiga utilização do software Pegasus" ("Polish parliament investigates use of the Pegasus software").
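A minimal sketch of the approach (assuming pyannote/speaker-diarization-3.1 and openai-whisper; the token, file name, and model size below are placeholders, not the original script):

# Minimal sketch: diarize with pyannote, then transcribe each speaker turn with whisper.
# "HF_TOKEN", "audio.mp3", and the model size are placeholders for illustration.
import whisper
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
                                    use_auth_token="HF_TOKEN")
model = whisper.load_model("medium")

diarization = pipeline("audio.mp3")       # who spoke when
audio = whisper.load_audio("audio.mp3")   # 16 kHz mono float32 array
SAMPLE_RATE = 16_000

for turn, _, speaker in diarization.itertracks(yield_label=True):
    chunk = audio[int(turn.start * SAMPLE_RATE):int(turn.end * SAMPLE_RATE)]
    result = model.transcribe(chunk)      # language is auto-detected per chunk
    print(f"{turn.start:.1f}s-{turn.end:.1f}s {speaker} "
          f"[{result['language']}]: {result['text'].strip()}")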
-
Here is updated code that supports output of SRT or VTT files from the diarized transcripts. This is a little tricky because multiple whisper transcriptions are done to create a single file. large-v3 should also work fine now. At the bottom there is a sample main() showing how to use the new classes.

import os
from typing import Any, Optional, TextIO
from pyannote.audio import Pipeline, Audio
import whisper
from whisper.utils import WriteSRT, WriteVTT
from whisper import Whisper
import torch
from math import ceil, floor


def diarize_audio(HF_AUTH_TOKEN, AUDIO_FILE):
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=HF_AUTH_TOKEN)
    # Send pyannote pipeline to GPU (when available)
    device: str = ""
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    pipeline.to(torch.device(device))
    print(f"Diarize audio on {device}")
    ### diarization = pipeline(AUDIO_FILE)
    io = Audio(mono='downmix', sample_rate=16000)
    waveform, sample_rate = io(AUDIO_FILE)
    diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate})
    return diarization


class AppendResultsMixin:
    """Class to return srt or vtt file path and open mode of write or append
    to allow incremental writing.
    """
    first_call: bool = True
    output_path: str = ''

    def get_path_and_open_mode(self, *, audio_path: str, dir: str, ext: str) -> tuple[str, str]:
        mode: str
        if self.first_call:
            audio_basename = os.path.basename(audio_path)
            audio_basename = os.path.splitext(audio_basename)[0]
            self.output_path: str = os.path.join(dir, audio_basename + "." + ext)
            self.first_call = False
            mode = 'w'  # open for write initially
        else:
            mode = 'a'  # open for append after
        return self.output_path, mode


class WriteSRTIncremental(AppendResultsMixin, WriteSRT):
    """Incrementally create an SRT file with multiple calls appending new entries
    to the file.
    """
    srt_index: int = 1  # index for srt blocks retained across multiple calls

    def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs):
        path, mode = self.get_path_and_open_mode(audio_path=audio_path, dir=self.output_dir, ext=self.extension)
        with open(path, mode, encoding="utf-8") as f:
            self.write_result(result, file=f, options=options, **kwargs)  # type: ignore

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for (start, end, text) in self.iterate_result(result, options, **kwargs):
            print(f"{self.srt_index}\n{start} --> {end}\n{text}\n", file=file, flush=True)
            self.srt_index += 1


class WriteVTTIncremental(AppendResultsMixin, WriteVTT):
    """Incrementally create a VTT file with multiple calls appending new entries
    to the file.
    """

    def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs):
        path, mode = self.get_path_and_open_mode(audio_path=audio_path, dir=self.output_dir, ext=self.extension)
        with open(path, mode, encoding="utf-8") as f:
            if mode != 'a':
                print("WEBVTT\n", file=f)
            self.write_result(result, file=f, options=options, **kwargs)  # type: ignore

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for start, end, text in self.iterate_result(result, options, **kwargs):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


class WhisperFacade:
    wmodel: Whisper

    def __init__(self, model: str, *, quantize=False) -> None:
        """Load the Whisper model and optionally quantize."""
        print("Initialize whisper")
        whisper_model = whisper.load_model(model)
        if quantize:
            print("Quantize")
            DTYPE = torch.qint8
            qmodel: Whisper = torch.quantization.quantize_dynamic(
                whisper_model, {torch.nn.Linear}, dtype=DTYPE)
            del whisper_model
            self.wmodel = qmodel
        else:
            self.wmodel = whisper_model

    def _set_timing_for(self, segment: dict[str, float],  # simplified typing
                        offset: float) -> None:
        """For speech fragments in different parts of an audio file, patch the
        whisper segment and word timing using the offset (typically the diarization offset)
        in seconds. This makes the timing accurate for subtitles when multiple
        calls to whisper are used for various parts of the audio.
        """
        s = segment
        s['start'] += offset
        s['end'] += offset
        # Update word start/stop times, if present
        if 'words' in s:
            w: dict[str, float]  # simplified typing
            for w in s['words']:  # type: ignore
                w['start'] += offset
                w['end'] += offset

    def load_audio(self, file_path: str):
        self.audio = whisper.load_audio(file_path)

    def transcribe(self, *, start: float, end: float, options: dict[str, Any]) -> dict[str, Any]:
        """Transcribe from start time to end time (both in seconds)."""
        SAMPLE_RATE = 16_000  # 16 kHz audio
        start_index = floor(start * SAMPLE_RATE)
        end_index = ceil(end * SAMPLE_RATE)
        audio_segment = self.audio[start_index:end_index]
        result = whisper.transcribe(self.wmodel, audio_segment, **options)
        #
        segments = result['segments']
        s: dict[str, float]  # simplified typing
        for s in segments:  # type: ignore
            self._set_timing_for(segment=s, offset=start)
        return result


# Demonstration of creating srt or vtt subtitles from multilanguage audio
HF_AUTH_TOKEN = "INSERT YOUR TOKEN"
AUDIO_FILE = "INSERT YOUR FILE.mp3"
torch.set_num_threads(6)  # change as appropriate


def main():
    diarization = diarize_audio(HF_AUTH_TOKEN, AUDIO_FILE)
    model = WhisperFacade(model='medium', quantize=True)
    model.load_audio(AUDIO_FILE)
    #
    writer = WriteSRTIncremental('.')
    # writer = WriteVTTIncremental('.')
    whisper_options = {"verbose": None, "word_timestamps": True,
                       "task": "transcribe", "suppress_tokens": ""}
    writer_options = {"max_line_width": 55, "max_line_count": 2, "highlight_words": False}
    print("Process diarized blocks")
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        if turn.end - turn.start < 0.5:  # Suppress short utterances (pyannote artifact)
            print(f"start={turn.start:.1f}s stop={turn.end:.1f}s IGNORED")
            continue
        result = model.transcribe(start=turn.start, end=turn.end, options=whisper_options)
        language = result['language']
        print(f"start={turn.start:.1f}s stop={turn.end:.1f}s lang={language} {speaker}")
        writer(result, AUDIO_FILE, writer_options)


if __name__ == '__main__':
    main()
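For reference, the incremental writer emits standard SRT blocks, one batch per diarized turn; the entries below are purely illustrative, not actual output:

1
00:00:00,000 --> 00:00:04,500
First subtitle text from the first diarized block

2
00:00:04,500 --> 00:00:09,000
Second subtitle text, possibly in a different language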
-
Does the diarization model support the Spanish language if my audio is exclusively in that language?
-
I have audio with a single speaker switching between their native language (Marathi) and English. Speaker diarization algorithms will not work, as it is the same speaker switching languages. Has anyone encountered this problem and found a solution?
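One possible workaround (a sketch only, not a tested solution) is to skip diarization and run Whisper's language detection on fixed-length windows, then transcribe each window with the detected language; the window length, model size, and file name below are assumptions:

# Sketch: per-window language detection for a single speaker who code-switches.
# Window length, model size, and file name are assumptions for illustration.
import whisper

model = whisper.load_model("medium")
audio = whisper.load_audio("speech.mp3")
SAMPLE_RATE = 16_000
WINDOW = 30 * SAMPLE_RATE  # whisper's native 30-second context

for i in range(0, len(audio), WINDOW):
    chunk = whisper.pad_or_trim(audio[i:i + WINDOW])
    mel = whisper.log_mel_spectrogram(chunk, n_mels=model.dims.n_mels).to(model.device)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    result = model.transcribe(audio[i:i + WINDOW], language=lang, task="transcribe")
    print(f"{i / SAMPLE_RATE:.1f}s [{lang}] {result['text'].strip()}")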
-
Whisper is at its worst when it comes to Bengali.
-
For mixed (multiple) languages in an audio file, I think Whisper should learn from Microsoft Cognitive Services, which supports mixed languages.
-
Hey Community!
I've been experimenting with Whisper (locally installed) for a project involving multi-language audio transcription. My audio samples contain several languages. However, the results have been somewhat perplexing, and I'm hoping to gain insights on what I am doing wrong.
I created a sample audio file that starts with a sentence in German, followed by two sentences in English, and concludes with two sentences in Spanish. Given that the default task for Whisper is transcribe (as per the whisper --help documentation), my expectation was a straightforward transcription, not a translation. However, what I get back from Whisper is a translation into English.
Observed Behavior:
Default Model with No Explicit Task or Language Arguments: All output text was translated into English, contrary to my expectation of a transcription retaining the original languages.
Explicit Transcription Task with No Language Argument: Similar to the first scenario, the output was entirely in English, ignoring the multi-language nature of the audio. Using --task transcribe had no effect.
Setting the Language Argument to "de": Interestingly, this produced the correct transcription, preserving all three languages (German, English, and Spanish) as spoken in the audio.
It seems that Whisper can accurately detect and switch between languages when a language argument is specified. However, the output does not align with what one might expect based on the command-line arguments used. This discrepancy leaves me unsure about the output's nature: whether it is a direct transcription or a translation.
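For what it's worth, the same three scenarios can be reproduced through the Python API, which takes the same task and language options as the CLI; the file name and model size below are placeholders:

# Sketch of the three scenarios via the Python API (file name and model size are placeholders).
import whisper

model = whisper.load_model("small")

# 1) No task or language argument: the language is detected from the first 30 seconds.
r1 = model.transcribe("mixed_langs.mp3")

# 2) Explicit transcribe task, still no language argument.
r2 = model.transcribe("mixed_langs.mp3", task="transcribe")

# 3) Language pinned to German, the API equivalent of --language de on the CLI.
r3 = model.transcribe("mixed_langs.mp3", language="de", task="transcribe")

for r in (r1, r2, r3):
    print(r["language"], r["text"][:80])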
I have noticed that when switching to the medium or large model (using the --model medium or --model large argument), the transcription detects the different languages correctly and outputs them as spoken, independent of the --task argument.
My Questions:
I appreciate any insights, experiences, or advice you can share.
Thank you in advance for your help!
Phil
System Info: