Skip to content

Commit

Permalink
speed up by caching model context
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle committed Jun 23, 2024
1 parent d4e3ecd commit 5276004
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 26 deletions.
1 change: 0 additions & 1 deletion core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ pub fn get_model_path() -> Result<PathBuf> {
#[derive(Deserialize, Serialize)]
pub struct TranscribeOptions {
pub path: PathBuf,
pub model_path: PathBuf,
pub lang: Option<String>,
pub verbose: bool,

Expand Down
34 changes: 19 additions & 15 deletions core/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,36 @@ use crate::audio;
use crate::config::TranscribeOptions;
use crate::transcript::{Segment, Transcript};
use eyre::{bail, Context, Ok, OptionExt, Result};
use std::path::Path;
use std::sync::Mutex;
use std::time::Instant;
pub use whisper_rs::SegmentCallbackData;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};

pub use whisper_rs::WhisperContext;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContextParameters};
type ProgressCallbackType = once_cell::sync::Lazy<Mutex<Option<Box<dyn Fn(i32) + Send + Sync>>>>;
static PROGRESS_CALLBACK: ProgressCallbackType = once_cell::sync::Lazy::new(|| Mutex::new(None));

/// Open a whisper model file and build a reusable `WhisperContext`.
///
/// Validates that `model_path` exists before handing it to whisper.cpp,
/// so callers get a clear error instead of a cryptic loader failure.
///
/// # Errors
/// Fails when the file is missing, the path is not valid UTF-8, or
/// whisper.cpp cannot open the model.
pub fn create_context(model_path: &Path) -> Result<WhisperContext> {
    log::debug!("open model...");

    // Guard clause: surface a missing file explicitly.
    if !model_path.exists() {
        bail!("whisper file doesn't exist")
    }

    // whisper-rs takes a &str, so the path must be valid UTF-8.
    let path_str = model_path
        .to_str()
        .ok_or_eyre("can't convert model option to str")?;

    let context = WhisperContext::new_with_params(path_str, WhisperContextParameters::default())
        .context("failed to open model")?;

    Ok(context)
}

pub fn transcribe(
ctx: &WhisperContext,
options: &TranscribeOptions,
progress_callback: Option<Box<dyn Fn(i32) + Send + Sync>>,
new_segment_callback: Option<Box<dyn Fn(whisper_rs::SegmentCallbackData)>>,
abort_callback: Option<Box<dyn Fn() -> bool>>,
) -> Result<Transcript> {
log::debug!("Transcribe called with {:?}", options);
if !options.model_path.exists() {
bail!("whisper file doesn't exist")
}

if !options.path.clone().exists() {
bail!("audio file doesn't exist")
Expand All @@ -40,18 +52,10 @@ pub fn transcribe(
let mut samples = vec![0.0f32; original_samples.len()];
whisper_rs::install_whisper_log_trampoline();
whisper_rs::convert_integer_to_float_audio(&original_samples, &mut samples)?;

log::debug!("open model...");
let ctx = WhisperContext::new_with_params(
options.model_path.to_str().ok_or_eyre("can't convert model option to str")?,
WhisperContextParameters::default(),
)
.context("failed to open model")?;
let mut state = ctx.create_state().context("failed to create key")?;

let mut params = FullParams::new(SamplingStrategy::default());
log::debug!("set language to {:?}", options.lang);

if let Some(true) = options.translate {
params.set_translate(true);
}
Expand Down Expand Up @@ -169,7 +173,6 @@ mod tests {
.ok_or_eyre("cant convert path to str")?
.to_owned()
.into(),
model_path: config::get_model_path()?,
lang: None,
n_threads: None,
verbose: false,
Expand All @@ -178,7 +181,8 @@ mod tests {
translate: None,
max_text_ctx: None,
};
transcribe(args, None, None, None)?;
let ctx = create_context(Path::new(&config::get_model_path().unwrap())).unwrap();
transcribe(&ctx, args, None, None, None)?;

Ok(())
}
Expand Down
8 changes: 4 additions & 4 deletions desktop/src-tauri/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ pub fn run(app: &App) {

let args = Args::parse();
let lang = language_name_to_whisper_lang(&args.language);
let mut options = TranscribeOptions {
let options = TranscribeOptions {
path: args.file,
model_path: args.model,
lang: Some(lang),
init_prompt: args.init_prompt,
n_threads: args.n_threads,
Expand All @@ -130,11 +129,12 @@ pub fn run(app: &App) {
verbose: false,
max_text_ctx: args.max_text_ctx,
};
options.model_path = prepare_model_path(&options.model_path);
let model_path = prepare_model_path(&args.model);

eprintln!("Transcribe... 🔄");
let start = Instant::now(); // Measure start time
let transcript = model::transcribe(&options, None, None, None).unwrap();
let ctx = model::create_context(&model_path).unwrap();
let transcript = model::transcribe(&ctx, &options, None, None, None).unwrap();
let elapsed = start.elapsed();
println!(
"{}",
Expand Down
42 changes: 40 additions & 2 deletions desktop/src-tauri/src/cmd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use crate::config;
use crate::setup::ModelContext;
use eyre::{bail, Context, ContextCompat, OptionExt, Result};
use serde_json::{json, Value};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use tauri::State;
use tauri::{
window::{ProgressBarState, ProgressBarStatus},
Manager,
};
use tokio::sync::Mutex;
use vibe::{model::SegmentCallbackData, transcript::Transcript};
pub mod audio;

Expand Down Expand Up @@ -142,7 +145,16 @@ pub async fn get_default_model_path() -> Result<String> {
}

#[tauri::command]
pub async fn transcribe(app_handle: tauri::AppHandle, options: vibe::config::TranscribeOptions) -> Result<Transcript> {
pub async fn transcribe(
app_handle: tauri::AppHandle,
options: vibe::config::TranscribeOptions,
model_context_state: State<'_, Mutex<Option<ModelContext>>>,
) -> Result<Transcript> {
let model_context = model_context_state.lock().await;
if model_context.is_none() {
bail!("Please load model first")
}
let ctx = model_context.as_ref().unwrap();
let app_handle_c = app_handle.clone();

let new_segment_callback = move |data: SegmentCallbackData| {
Expand Down Expand Up @@ -176,6 +188,7 @@ pub async fn transcribe(app_handle: tauri::AppHandle, options: vibe::config::Tra
// prevent panic crash. sometimes whisper.cpp crash without nice errors.
let unwind_result = catch_unwind(AssertUnwindSafe(|| {
vibe::model::transcribe(
&ctx.handle,
&options,
Some(Box::new(progress_callback)),
Some(Box::new(new_segment_callback)),
Expand Down Expand Up @@ -257,3 +270,28 @@ pub fn is_avx2_enabled() -> bool {
#[allow(clippy::comparison_to_empty)]
return env!("WHISPER_NO_AVX") != "ON";
}

#[tauri::command]
/// Ensure the app-wide cached `ModelContext` matches `model_path`,
/// (re)loading the whisper model only when necessary.
///
/// The loaded context is stored in Tauri-managed state behind an async
/// mutex so subsequent `transcribe` calls can reuse it instead of paying
/// the model-load cost every time.
///
/// Returns the (unchanged) `model_path` on success.
///
/// # Errors
/// Propagates any failure from `vibe::model::create_context` (missing
/// file, unreadable model, non-UTF-8 path).
pub async fn load_model(model_path: String, model_context_state: State<'_, Mutex<Option<ModelContext>>>) -> Result<String> {
    let mut state_guard = model_context_state.lock().await;

    // Decide whether a (re)load is needed; the two log messages preserve
    // the original diagnostics for "changed path" vs "first load".
    let needs_load = match state_guard.as_ref() {
        // Same path already loaded: keep the cached context untouched.
        Some(state) if model_path == state.path => false,
        Some(_) => {
            log::debug!("model path changed. reloading");
            true
        }
        None => {
            log::debug!("loading model first time");
            true
        }
    };

    if needs_load {
        // Expensive step: opens and parses the whisper model file.
        let context = vibe::model::create_context(Path::new(&model_path))?;
        *state_guard = Some(ModelContext {
            path: model_path.clone(),
            handle: context,
        });
    }

    Ok(model_path)
}
1 change: 1 addition & 0 deletions desktop/src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ fn main() {
.invoke_handler(tauri::generate_handler![
cmd::transcribe,
cmd::download_model,
cmd::load_model,
cmd::get_default_model_path,
cmd::get_commit_hash,
cmd::get_cuda_version,
Expand Down
10 changes: 10 additions & 0 deletions desktop/src-tauri/src/setup.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
use crate::{cli, panic_hook};
use tauri::{App, Manager};
use tokio::sync::Mutex;
use vibe::model::WhisperContext;

/// Cached whisper model context held in Tauri-managed state so repeated
/// transcriptions can reuse an already-loaded model instead of reloading it.
pub struct ModelContext {
// Filesystem path the model was loaded from; compared against incoming
// requests to detect when a different model must be (re)loaded.
pub path: String,
// Handle to the loaded whisper.cpp context (re-exported `WhisperContext`).
pub handle: WhisperContext,
}

pub fn setup(app: &App) -> Result<(), Box<dyn std::error::Error>> {
// Add panic hook
panic_hook::set_panic_hook(app.app_handle());

// Manage model context
app.manage(Mutex::new(None::<ModelContext>));

// Log some useful data
if let Ok(version) = tauri::webview_version() {
log::debug!("webview version: {}", version);
Expand Down
5 changes: 3 additions & 2 deletions desktop/src/pages/batch/viewModel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ export function viewModel() {
}
setInProgress(true)
let localIndex = 0
await invoke('load_model', { modelPath: preferences.modelPath })
try {
setCurrentIndex(localIndex)
const loopStartTime = performance.now()
Expand All @@ -108,11 +109,11 @@ export function viewModel() {
setProgress(null)
const options = {
path: file.path,
model_path: preferences.modelPath,
...preferences.modelOptions,
}
const startTime = performance.now()
const res: Transcript = await invoke('transcribe', { options })

const res: Transcript = await invoke('transcribe', { options, modelPath: preferences.modelPath })

// Calculate time
let total = Math.round((performance.now() - startTime) / 1000)
Expand Down
4 changes: 2 additions & 2 deletions desktop/src/pages/home/viewModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ export function viewModel() {
async function transcribe() {
setSegments(null)
setLoading(true)
await invoke('load_model', { modelPath: preferences.modelPath })
try {
const options = {
path: files[0].path,
model_path: preferences.modelPath,
...preferences.modelOptions,
}
const startTime = performance.now()
const res: transcript.Transcript = await invoke('transcribe', { options })
const res: transcript.Transcript = await invoke('transcribe', { options, modelPath: preferences.modelPath })

// Calcualte time
const total = Math.round((performance.now() - startTime) / 1000)
Expand Down

0 comments on commit 5276004

Please sign in to comment.