diff --git a/buzz/db/dao/transcription_dao.py b/buzz/db/dao/transcription_dao.py index c8f2ed85d..763a7a8cb 100644 --- a/buzz/db/dao/transcription_dao.py +++ b/buzz/db/dao/transcription_dao.py @@ -30,7 +30,8 @@ def create_transcription(self, task: FileTranscriptionTask): task, time_queued, url, - whisper_model_size + whisper_model_size, + hugging_face_model_id ) VALUES ( :id, :export_formats, @@ -43,7 +44,8 @@ def create_transcription(self, task: FileTranscriptionTask): :task, :time_queued, :url, - :whisper_model_size + :whisper_model_size, + :hugging_face_model_id ) """ ) @@ -74,6 +76,12 @@ def create_transcription(self, task: FileTranscriptionTask): if task.transcription_options.model.whisper_model_size else None, ) + query.bindValue( + ":hugging_face_model_id", + task.transcription_options.model.hugging_face_model_id + if task.transcription_options.model.hugging_face_model_id + else None, + ) if not query.exec(): raise Exception(query.lastError().text()) diff --git a/buzz/db/entity/transcription.py b/buzz/db/entity/transcription.py index 9055fa25b..50cdeab06 100644 --- a/buzz/db/entity/transcription.py +++ b/buzz/db/entity/transcription.py @@ -15,6 +15,7 @@ class Transcription(Entity): task: str = Task.TRANSCRIBE.value model_type: str = ModelType.WHISPER.value whisper_model_size: str | None = None + hugging_face_model_id: str | None = None language: str | None = None id: str = field(default_factory=lambda: str(uuid.uuid4())) error_message: str | None = None diff --git a/buzz/db/helpers.py b/buzz/db/helpers.py index a447fc6da..767fae9d9 100644 --- a/buzz/db/helpers.py +++ b/buzz/db/helpers.py @@ -15,8 +15,8 @@ def copy_transcriptions_from_json_to_sqlite(conn: Connection): for task in tasks: cursor.execute( """ - INSERT INTO transcription (id, error_message, export_formats, file, output_folder, progress, language, model_type, source, status, task, time_ended, time_queued, time_started, url, whisper_model_size) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO transcription (id, error_message, export_formats, file, output_folder, progress, language, model_type, source, status, task, time_ended, time_queued, time_started, url, whisper_model_size, hugging_face_model_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id; """, ( @@ -43,6 +43,9 @@ def copy_transcriptions_from_json_to_sqlite(conn: Connection): task.transcription_options.model.whisper_model_size.value if task.transcription_options.model.whisper_model_size else None, + task.transcription_options.model.hugging_face_model_id + if task.transcription_options.model.hugging_face_model_id + else None, ), ) transcription_id = cursor.fetchone()[0] diff --git a/buzz/schema.sql b/buzz/schema.sql index 7bac73e6e..ee18b4f0e 100644 --- a/buzz/schema.sql +++ b/buzz/schema.sql @@ -14,7 +14,8 @@ CREATE TABLE transcription ( time_queued TIMESTAMP NOT NULL, time_started TIMESTAMP, url TEXT, - whisper_model_size TEXT + whisper_model_size TEXT, + hugging_face_model_id TEXT ); CREATE TABLE transcription_segment ( diff --git a/buzz/widgets/transcription_record.py b/buzz/widgets/transcription_record.py index bfa3c4ae0..782b4e428 100644 --- a/buzz/widgets/transcription_record.py +++ b/buzz/widgets/transcription_record.py @@ -1,5 +1,5 @@ from uuid import UUID - +import logging from PyQt6.QtSql import QSqlRecord from buzz.model_loader import TranscriptionModel, ModelType, WhisperModelSize @@ -18,6 +18,9 @@ def model(record: QSqlRecord) -> TranscriptionModel: whisper_model_size=WhisperModelSize(record.value("whisper_model_size")) if record.value("whisper_model_size") else None, + hugging_face_model_id=record.value("hugging_face_model_id") + if record.value("hugging_face_model_id") + else None ) @staticmethod