From 2dc0797e64931a62a7209d0cb30913a73de949bf Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Mon, 26 Jun 2023 08:46:17 +0100 Subject: [PATCH] Refactor transcribers class (#508) --- buzz/gui.py | 3 +- buzz/recording_transcriber.py | 139 ++++++++++++++++++++++++++++++++++ buzz/transcriber.py | 128 ------------------------------- tests/transcriber_test.py | 4 +- 4 files changed, 143 insertions(+), 131 deletions(-) create mode 100644 buzz/recording_transcriber.py diff --git a/buzz/gui.py b/buzz/gui.py index 524db85ce..6c98497ff 100644 --- a/buzz/gui.py +++ b/buzz/gui.py @@ -35,8 +35,9 @@ from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, OutputFormat, Task, TranscriptionOptions, - FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL, + FileTranscriptionTask, LOADED_WHISPER_DLL, DEFAULT_WHISPER_TEMPERATURE, LANGUAGES) +from .recording_transcriber import RecordingTranscriber from .file_transcriber_queue_worker import FileTranscriberQueueWorker from .widgets.line_edit import LineEdit from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog diff --git a/buzz/recording_transcriber.py b/buzz/recording_transcriber.py new file mode 100644 index 000000000..e4d2d8423 --- /dev/null +++ b/buzz/recording_transcriber.py @@ -0,0 +1,139 @@ +import datetime +import logging +import threading +from typing import Optional + +import numpy as np +import sounddevice +import whisper +from PyQt6.QtCore import QObject, pyqtSignal +from sounddevice import PortAudioError + +from buzz import transformers_whisper +from buzz.model_loader import ModelType +from buzz.transcriber import TranscriptionOptions, WhisperCpp, whisper_cpp_params +from buzz.transformers_whisper import TransformersWhisper + + +class RecordingTranscriber(QObject): + transcription = pyqtSignal(str) + finished = pyqtSignal() + error = pyqtSignal(str) + is_running = False + MAX_QUEUE_SIZE = 10 + + def __init__(self, transcription_options: TranscriptionOptions, + input_device_index: Optional[int], sample_rate: int, model_path: str, + parent: Optional[QObject] = None) -> None: + super().__init__(parent) + self.transcription_options = transcription_options + self.current_stream = None + self.input_device_index = input_device_index + self.sample_rate = sample_rate + self.model_path = model_path + self.n_batch_samples = 5 * self.sample_rate # every 5 seconds + # pause queueing if more than 3 batches behind + self.max_queue_size = 3 * self.n_batch_samples + self.queue = np.ndarray([], dtype=np.float32) + self.mutex = threading.Lock() + + def start(self): + model_path = self.model_path + + if self.transcription_options.model.model_type == ModelType.WHISPER: + model = whisper.load_model(model_path) + elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP: + model = WhisperCpp(model_path) + else: # ModelType.HUGGING_FACE + model = transformers_whisper.load_model(model_path) + + initial_prompt = self.transcription_options.initial_prompt + + logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s', + self.transcription_options, model_path, self.sample_rate, self.input_device_index) + + self.is_running = True + try: + with sounddevice.InputStream(samplerate=self.sample_rate, + device=self.input_device_index, dtype="float32", + channels=1, callback=self.stream_callback): + while self.is_running: + self.mutex.acquire() + if self.queue.size >= self.n_batch_samples: + samples = self.queue[:self.n_batch_samples] + self.queue = self.queue[self.n_batch_samples:] + self.mutex.release() + + logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s', + samples.size, self.queue.size, self.amplitude(samples)) + time_started = datetime.datetime.now() + + if self.transcription_options.model.model_type == ModelType.WHISPER: + assert isinstance(model, whisper.Whisper) + result = model.transcribe( + audio=samples, language=self.transcription_options.language, + task=self.transcription_options.task.value, + initial_prompt=initial_prompt, + temperature=self.transcription_options.temperature) + elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP: + assert isinstance(model, WhisperCpp) + result = model.transcribe( + audio=samples, + params=whisper_cpp_params( + language=self.transcription_options.language + if self.transcription_options.language is not None else 'en', + task=self.transcription_options.task.value, word_level_timings=False)) + else: + assert isinstance(model, TransformersWhisper) + result = model.transcribe(audio=samples, + language=self.transcription_options.language + if self.transcription_options.language is not None else 'en', + task=self.transcription_options.task.value) + + next_text: str = result.get('text') + + # Update initial prompt between successive recording chunks + initial_prompt += next_text + + logging.debug('Received next result, length = %s, time taken = %s', + len(next_text), datetime.datetime.now() - time_started) + self.transcription.emit(next_text) + else: + self.mutex.release() + except PortAudioError as exc: + self.error.emit(str(exc)) + logging.exception('') + return + + self.finished.emit() + + @staticmethod + def get_device_sample_rate(device_id: Optional[int]) -> int: + """Returns the sample rate to be used for recording. It uses the default sample rate + provided by Whisper if the microphone supports it, or else it uses the device's default + sample rate. + """ + whisper_sample_rate = whisper.audio.SAMPLE_RATE + try: + sounddevice.check_input_settings( + device=device_id, samplerate=whisper_sample_rate) + return whisper_sample_rate + except PortAudioError: + device_info = sounddevice.query_devices(device=device_id) + if isinstance(device_info, dict): + return int(device_info.get('default_samplerate', whisper_sample_rate)) + return whisper_sample_rate + + def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status): + # Try to enqueue the next block. If the queue is already full, drop the block. + chunk: np.ndarray = in_data.ravel() + with self.mutex: + if self.queue.size < self.max_queue_size: + self.queue = np.append(self.queue, chunk) + + @staticmethod + def amplitude(arr: np.ndarray): + return (abs(max(arr)) + abs(min(arr))) / 2 + + def stop_recording(self): + self.is_running = False diff --git a/buzz/transcriber.py b/buzz/transcriber.py index 8b4cab561..7b82e4b3d 100644 --- a/buzz/transcriber.py +++ b/buzz/transcriber.py @@ -7,7 +7,6 @@ import os import sys import tempfile -import threading from abc import abstractmethod from dataclasses import dataclass, field from multiprocessing.connection import Connection @@ -19,18 +18,15 @@ import ffmpeg import numpy as np import openai -import sounddevice import stable_whisper import tqdm import whisper from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot -from sounddevice import PortAudioError from whisper import tokenizer from . import transformers_whisper from .conn import pipe_stderr from .model_loader import TranscriptionModel, ModelType -from .transformers_whisper import TransformersWhisper # Catch exception from whisper.dll not getting loaded. # TODO: Remove flag and try-except when issue with loading @@ -101,130 +97,6 @@ class Status(enum.Enum): completed_at: Optional[datetime.datetime] = None -class RecordingTranscriber(QObject): - transcription = pyqtSignal(str) - finished = pyqtSignal() - error = pyqtSignal(str) - is_running = False - MAX_QUEUE_SIZE = 10 - - def __init__(self, transcription_options: TranscriptionOptions, - input_device_index: Optional[int], sample_rate: int, model_path: str, - parent: Optional[QObject] = None) -> None: - super().__init__(parent) - self.transcription_options = transcription_options - self.current_stream = None - self.input_device_index = input_device_index - self.sample_rate = sample_rate - self.model_path = model_path - self.n_batch_samples = 5 * self.sample_rate # every 5 seconds - # pause queueing if more than 3 batches behind - self.max_queue_size = 3 * self.n_batch_samples - self.queue = np.ndarray([], dtype=np.float32) - self.mutex = threading.Lock() - - def start(self): - model_path = self.model_path - - if self.transcription_options.model.model_type == ModelType.WHISPER: - model = whisper.load_model(model_path) - elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP: - model = WhisperCpp(model_path) - else: # ModelType.HUGGING_FACE - model = transformers_whisper.load_model(model_path) - - initial_prompt = self.transcription_options.initial_prompt - - logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s', - self.transcription_options, model_path, self.sample_rate, self.input_device_index) - - self.is_running = True - try: - with sounddevice.InputStream(samplerate=self.sample_rate, - device=self.input_device_index, dtype="float32", - channels=1, callback=self.stream_callback): - while self.is_running: - self.mutex.acquire() - if self.queue.size >= self.n_batch_samples: - samples = self.queue[:self.n_batch_samples] - self.queue = self.queue[self.n_batch_samples:] - self.mutex.release() - - logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s', - samples.size, self.queue.size, self.amplitude(samples)) - time_started = datetime.datetime.now() - - if self.transcription_options.model.model_type == ModelType.WHISPER: - assert isinstance(model, whisper.Whisper) - result = model.transcribe( - audio=samples, language=self.transcription_options.language, - task=self.transcription_options.task.value, - initial_prompt=initial_prompt, - temperature=self.transcription_options.temperature) - elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP: - assert isinstance(model, WhisperCpp) - result = model.transcribe( - audio=samples, - params=whisper_cpp_params( - language=self.transcription_options.language - if self.transcription_options.language is not None else 'en', - task=self.transcription_options.task.value, word_level_timings=False)) - else: - assert isinstance(model, TransformersWhisper) - result = model.transcribe(audio=samples, - language=self.transcription_options.language - if self.transcription_options.language is not None else 'en', - task=self.transcription_options.task.value) - - next_text: str = result.get('text') - - # Update initial prompt between successive recording chunks - initial_prompt += next_text - - logging.debug('Received next result, length = %s, time taken = %s', - len(next_text), datetime.datetime.now() - time_started) - self.transcription.emit(next_text) - else: - self.mutex.release() - except PortAudioError as exc: - self.error.emit(str(exc)) - logging.exception('') - return - - self.finished.emit() - - @staticmethod - def get_device_sample_rate(device_id: Optional[int]) -> int: - """Returns the sample rate to be used for recording. It uses the default sample rate - provided by Whisper if the microphone supports it, or else it uses the device's default - sample rate. - """ - whisper_sample_rate = whisper.audio.SAMPLE_RATE - try: - sounddevice.check_input_settings( - device=device_id, samplerate=whisper_sample_rate) - return whisper_sample_rate - except PortAudioError: - device_info = sounddevice.query_devices(device=device_id) - if isinstance(device_info, dict): - return int(device_info.get('default_samplerate', whisper_sample_rate)) - return whisper_sample_rate - - def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status): - # Try to enqueue the next block. If the queue is already full, drop the block. - chunk: np.ndarray = in_data.ravel() - with self.mutex: - if self.queue.size < self.max_queue_size: - self.queue = np.append(self.queue, chunk) - - @staticmethod - def amplitude(arr: np.ndarray): - return (abs(max(arr)) + abs(min(arr))) / 2 - - def stop_recording(self): - self.is_running = False - - class OutputFormat(enum.Enum): TXT = 'txt' SRT = 'srt' diff --git a/tests/transcriber_test.py b/tests/transcriber_test.py index c234237f8..09f54203b 100644 --- a/tests/transcriber_test.py +++ b/tests/transcriber_test.py @@ -11,11 +11,11 @@ from pytestqt.qtbot import QtBot from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel -from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, RecordingTranscriber, - Segment, Task, WhisperCpp, WhisperCppFileTranscriber, +from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, Segment, Task, WhisperCpp, WhisperCppFileTranscriber, WhisperFileTranscriber, get_default_output_file_path, to_timestamp, whisper_cpp_params, write_output, TranscriptionOptions) +from buzz.recording_transcriber import RecordingTranscriber from tests.mock_sounddevice import MockInputStream from tests.model_loader import get_model_path