Skip to content

Commit

Permalink
Add models preferences (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Apr 28, 2023
1 parent 0a4be2b commit cb5ad74
Show file tree
Hide file tree
Showing 15 changed files with 583 additions and 247 deletions.
146 changes: 57 additions & 89 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
from PyQt6 import QtGui
from PyQt6.QtCore import (QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QPoint,
QUrlQuery, QMetaObject, QEvent)
QUrlQuery, QMetaObject, QEvent, QThreadPool)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QDialogButtonBox, QFileDialog, QLabel, QLineEdit,
QMainWindow, QMessageBox, QPlainTextEdit,
QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
QProgressDialog, QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox, QTableWidget,
QMenuBar, QFormLayout, QTableWidgetItem,
QHeaderView, QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)
from whisper import tokenizer

from buzz.cache import TasksCache
Expand All @@ -30,7 +29,8 @@
from .assets import get_asset_path
from .icon import Icon
from .locale import _
from .model_loader import ModelLoader, WhisperModelSize, ModelType, TranscriptionModel
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, get_local_model_path, \
ModelDownloader
from .paths import file_paths_as_title
from .recording import RecordingAmplitudeListener
from .settings.settings import Settings, APP_NAME
Expand All @@ -43,6 +43,8 @@
FileTranscriberQueueWorker, FileTranscriptionTask, RecordingTranscriber, LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE)
from .widgets.line_edit import LineEdit
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from .widgets.model_type_combo_box import ModelTypeComboBox
from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
from .widgets.preferences_dialog import PreferencesDialog
from .widgets.toolbar import ToolBar
Expand Down Expand Up @@ -163,33 +165,6 @@ def set_recording(self):
self.setDefault(False)


class DownloadModelProgressDialog(QProgressDialog):
    """Application-modal progress dialog for a model download.

    The label shows the completed percentage together with an estimate of the
    remaining time, extrapolated from the time elapsed since the dialog was
    created.
    """
    start_time: datetime

    def __init__(self, parent: Optional[QWidget], *args) -> None:
        """Create the dialog with a 0-100 range and a Cancel button."""
        initial_label = _('Downloading model (0%, unknown time remaining)')
        cancel_label = _('Cancel')
        super().__init__(initial_label, cancel_label, 0, 100, parent, *args)

        # Setting this to a high value to avoid showing the dialog for models that
        # are checked locally but set progress to 0 immediately, i.e. Hugging Face or Faster Whisper models
        self.setMinimumDuration(10_000)

        self.setWindowModality(Qt.WindowModality.ApplicationModal)
        self.start_time = datetime.now()
        # Lock the dialog to the size it has right now.
        self.setFixedSize(self.size())

    def set_fraction_completed(self, fraction_completed: float) -> None:
        """Update the progress bar and, once progress is positive, the label.

        :param fraction_completed: completed share of the download, where 1.0
            corresponds to the dialog's maximum value.
        """
        self.setValue(int(fraction_completed * self.maximum()))

        # Without any progress there is nothing to extrapolate from (and the
        # division below would fail), so the label is left untouched.
        if not fraction_completed > 0.0:
            return

        elapsed_seconds = (datetime.now() - self.start_time).total_seconds()
        estimated_total_seconds = elapsed_seconds / fraction_completed
        remaining_seconds = estimated_total_seconds - elapsed_seconds

        prefix = _('Downloading model')
        percentage = f' ({fraction_completed :.0%}, '
        remaining_text = humanize.naturaldelta(remaining_seconds)
        self.setLabelText(prefix + percentage + remaining_text + ')')


def show_model_download_error_dialog(parent: QWidget, error: str):
message = parent.tr(
'An error occurred while loading the Whisper model') + \
Expand All @@ -200,9 +175,8 @@ def show_model_download_error_dialog(parent: QWidget, error: str):


class FileTranscriberWidget(QWidget):
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
model_loader: Optional[ModelLoader] = None
transcriber_thread: Optional[QThread] = None
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
model_loader: Optional[ModelDownloader] = None
file_transcription_options: FileTranscriptionOptions
transcription_options: TranscriptionOptions
is_transcribing = False
Expand Down Expand Up @@ -291,23 +265,16 @@ def on_transcription_options_changed(self, transcription_options: TranscriptionO
def on_click_run(self):
self.run_button.setDisabled(True)

self.transcriber_thread = QThread()
self.model_loader = ModelLoader(model=self.transcription_options.model)
self.model_loader.moveToThread(self.transcriber_thread)

self.transcriber_thread.started.connect(self.model_loader.run)
self.model_loader.finished.connect(
self.transcriber_thread.quit)

self.model_loader.progress.connect(self.on_download_model_progress)

self.model_loader.error.connect(self.on_download_model_error)
self.model_loader.error.connect(self.model_loader.deleteLater)

self.model_loader.finished.connect(self.on_model_loaded)
self.model_loader.finished.connect(self.model_loader.deleteLater)
model_path = get_local_model_path(model=self.transcription_options.model)
if model_path is not None:
self.on_model_loaded(model_path)
return

self.transcriber_thread.start()
self.model_loader = ModelDownloader(model=self.transcription_options.model)
self.model_loader.signals.progress.connect(self.on_download_model_progress)
self.model_loader.signals.error.connect(self.on_download_model_error)
self.model_loader.signals.finished.connect(self.on_model_loaded)
QThreadPool().globalInstance().start(self.model_loader)

def on_model_loaded(self, model_path: str):
self.reset_transcriber_controls()
Expand All @@ -320,12 +287,13 @@ def on_download_model_progress(self, progress: Tuple[float, float]):
(current_size, total_size) = progress

if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = DownloadModelProgressDialog(parent=self)
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)

if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)

def on_download_model_error(self, error: str):
self.reset_model_download()
Expand All @@ -337,7 +305,7 @@ def reset_transcriber_controls(self):

def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
self.model_loader.stop()
self.model_loader.cancel()
self.reset_model_download()

def reset_model_download(self):
Expand All @@ -349,8 +317,8 @@ def on_word_level_timings_changed(self, value: int):
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value

def closeEvent(self, event: QtGui.QCloseEvent) -> None:
if self.transcriber_thread is not None:
self.transcriber_thread.wait()
if self.model_loader is not None:
self.model_loader.cancel()

self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task)
Expand Down Expand Up @@ -432,9 +400,9 @@ class RecordingTranscriberWidget(QWidget):
current_status: 'RecordingStatus'
transcription_options: TranscriptionOptions
selected_device_id: Optional[int]
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
transcriber: Optional[RecordingTranscriber] = None
model_loader: Optional[ModelLoader] = None
model_loader: Optional[ModelDownloader] = None
transcription_thread: Optional[QThread] = None
recording_amplitude_listener: Optional[RecordingAmplitudeListener] = None
device_sample_rate: Optional[int] = None
Expand Down Expand Up @@ -541,29 +509,35 @@ def on_record_button_clicked(self):
def start_recording(self):
    """Disable the record button and make the selected model available.

    If ``get_local_model_path`` returns a path for the selected model,
    ``on_model_loaded`` is called with it immediately. Otherwise a
    ``ModelDownloader`` is started on the global Qt thread pool, with its
    progress, error and finished signals connected to this widget's handlers.
    """
    self.record_button.setDisabled(True)

    model_path = get_local_model_path(model=self.transcription_options.model)
    if model_path is not None:
        self.on_model_loaded(model_path)
        return

    self.model_loader = ModelDownloader(model=self.transcription_options.model)
    self.model_loader.signals.progress.connect(self.on_download_model_progress)
    self.model_loader.signals.error.connect(self.on_download_model_error)
    self.model_loader.signals.finished.connect(self.on_model_loaded)
    # globalInstance() is a static accessor: call it on the class instead of
    # constructing a throwaway QThreadPool just to reach it.
    QThreadPool.globalInstance().start(self.model_loader)

def on_model_loaded(self, model_path: str):
self.reset_recording_controls()
self.model_loader = None

self.transcription_thread = QThread()

self.model_loader = ModelLoader(model=self.transcription_options.model)
# TODO: make runnable
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
sample_rate=self.device_sample_rate,
transcription_options=self.transcription_options)
transcription_options=self.transcription_options,
model_path=model_path)

self.model_loader.moveToThread(self.transcription_thread)
self.transcriber.moveToThread(self.transcription_thread)

self.transcription_thread.started.connect(self.model_loader.run)
self.transcription_thread.started.connect(self.transcriber.start)
self.transcription_thread.finished.connect(
self.transcription_thread.deleteLater)

self.model_loader.finished.connect(self.reset_recording_controls)
self.model_loader.finished.connect(self.transcriber.start)
self.model_loader.finished.connect(self.model_loader.deleteLater)

self.model_loader.progress.connect(
self.on_download_model_progress)

self.model_loader.error.connect(self.on_download_model_error)

self.transcriber.transcription.connect(self.on_next_transcription)

self.transcriber.finished.connect(self.on_transcriber_finished)
Expand All @@ -580,12 +554,13 @@ def on_download_model_progress(self, progress: Tuple[float, float]):
(current_size, total_size) = progress

if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = DownloadModelProgressDialog(parent=self)
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)

if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_fraction_completed(fraction_completed=current_size / total_size)
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)

def set_recording_status_stopped(self):
self.record_button.set_stopped()
Expand Down Expand Up @@ -640,9 +615,7 @@ def reset_recording_controls(self):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.reset_record_button()
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.close()
self.model_download_progress_dialog = None
self.reset_model_download()

def reset_record_button(self):
self.record_button.setEnabled(True)
Expand All @@ -651,6 +624,9 @@ def on_recording_amplitude_changed(self, amplitude: float):
self.audio_meter_widget.update_amplitude(amplitude)

def closeEvent(self, event: QCloseEvent) -> None:
if self.model_loader is not None:
self.model_loader.cancel()

self.stop_recording()
if self.recording_amplitude_listener is not None:
self.recording_amplitude_listener.stop_recording()
Expand Down Expand Up @@ -1264,17 +1240,10 @@ def __init__(self, default_transcription_options: TranscriptionOptions = Transcr
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(self.on_hugging_face_model_changed)

self.model_type_combo_box = QComboBox(self)
if model_types is None:
model_types = [model_type for model_type in ModelType]
for model_type in model_types:
# Hide Whisper.cpp option if whisper.dll did not load correctly.
# See: https://github.com/chidiwilliams/buzz/issues/274, https://github.com/chidiwilliams/buzz/issues/197
if model_type == ModelType.WHISPER_CPP and LOADED_WHISPER_DLL is False:
continue
self.model_type_combo_box.addItem(model_type.value)
self.model_type_combo_box.setCurrentText(default_transcription_options.model.model_type.value)
self.model_type_combo_box.currentTextChanged.connect(self.on_model_type_changed)
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
default_model=default_transcription_options.model.model_type,
parent=self)
self.model_type_combo_box.changed.connect(self.on_model_type_changed)

self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems([size.value.title() for size in WhisperModelSize])
Expand Down Expand Up @@ -1339,8 +1308,7 @@ def reset_visible_rows(self):
model_type == ModelType.FASTER_WHISPER))
self.form_layout.setRowVisible(self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API)

def on_model_type_changed(self, text: str):
model_type = ModelType(text)
def on_model_type_changed(self, model_type: ModelType):
self.transcription_options.model.model_type = model_type
self.reset_visible_rows()
self.transcription_options_changed.emit(self.transcription_options)
Expand Down
Loading

0 comments on commit cb5ad74

Please sign in to comment.