Skip to content

Commit

Permalink
Will add huggingface model id to the transcription table (#758)
Browse files Browse the repository at this point in the history
  • Loading branch information
raivisdejus authored May 26, 2024
1 parent 5ba8eaa commit 7820952
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 6 deletions.
12 changes: 10 additions & 2 deletions buzz/db/dao/transcription_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def create_transcription(self, task: FileTranscriptionTask):
task,
time_queued,
url,
whisper_model_size
whisper_model_size,
hugging_face_model_id
) VALUES (
:id,
:export_formats,
Expand All @@ -43,7 +44,8 @@ def create_transcription(self, task: FileTranscriptionTask):
:task,
:time_queued,
:url,
:whisper_model_size
:whisper_model_size,
:hugging_face_model_id
)
"""
)
Expand Down Expand Up @@ -74,6 +76,12 @@ def create_transcription(self, task: FileTranscriptionTask):
if task.transcription_options.model.whisper_model_size
else None,
)
query.bindValue(
":hugging_face_model_id",
task.transcription_options.model.hugging_face_model_id
if task.transcription_options.model.hugging_face_model_id
else None,
)
if not query.exec():
raise Exception(query.lastError().text())

Expand Down
1 change: 1 addition & 0 deletions buzz/db/entity/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Transcription(Entity):
task: str = Task.TRANSCRIBE.value
model_type: str = ModelType.WHISPER.value
whisper_model_size: str | None = None
hugging_face_model_id: str | None = None
language: str | None = None
id: str = field(default_factory=lambda: str(uuid.uuid4()))
error_message: str | None = None
Expand Down
7 changes: 5 additions & 2 deletions buzz/db/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def copy_transcriptions_from_json_to_sqlite(conn: Connection):
for task in tasks:
cursor.execute(
"""
INSERT INTO transcription (id, error_message, export_formats, file, output_folder, progress, language, model_type, source, status, task, time_ended, time_queued, time_started, url, whisper_model_size)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO transcription (id, error_message, export_formats, file, output_folder, progress, language, model_type, source, status, task, time_ended, time_queued, time_started, url, whisper_model_size, hugging_face_model_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
RETURNING id;
""",
(
Expand All @@ -43,6 +43,9 @@ def copy_transcriptions_from_json_to_sqlite(conn: Connection):
task.transcription_options.model.whisper_model_size.value
if task.transcription_options.model.whisper_model_size
else None,
task.transcription_options.model.hugging_face_model_id
if task.transcription_options.model.hugging_face_model_id
else None,
),
)
transcription_id = cursor.fetchone()[0]
Expand Down
3 changes: 2 additions & 1 deletion buzz/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ CREATE TABLE transcription (
time_queued TIMESTAMP NOT NULL,
time_started TIMESTAMP,
url TEXT,
whisper_model_size TEXT
whisper_model_size TEXT,
hugging_face_model_id TEXT
);

CREATE TABLE transcription_segment (
Expand Down
5 changes: 4 additions & 1 deletion buzz/widgets/transcription_record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from uuid import UUID

import logging
from PyQt6.QtSql import QSqlRecord

from buzz.model_loader import TranscriptionModel, ModelType, WhisperModelSize
Expand All @@ -18,6 +18,9 @@ def model(record: QSqlRecord) -> TranscriptionModel:
whisper_model_size=WhisperModelSize(record.value("whisper_model_size"))
if record.value("whisper_model_size")
else None,
hugging_face_model_id=record.value("hugging_face_model_id")
if record.value("hugging_face_model_id")
else None
)

@staticmethod
Expand Down

0 comments on commit 7820952

Please sign in to comment.