diff --git a/rust/mimi-pyo3/Cargo.toml b/rust/mimi-pyo3/Cargo.toml index 64c9534..f7c772f 100644 --- a/rust/mimi-pyo3/Cargo.toml +++ b/rust/mimi-pyo3/Cargo.toml @@ -16,4 +16,4 @@ crate-type = ["cdylib"] anyhow = "1" numpy = "0.21.0" pyo3 = "0.21.0" -moshi = { path = "../moshi-core", version = "0.2.1" } +moshi = { path = "../moshi-core", version = "0.2.2" } diff --git a/rust/mimi-pyo3/src/lib.rs b/rust/mimi-pyo3/src/lib.rs index 5f033ea..ee696aa 100644 --- a/rust/mimi-pyo3/src/lib.rs +++ b/rust/mimi-pyo3/src/lib.rs @@ -39,7 +39,7 @@ macro_rules! py_bail { }; } -fn encodec_cfg() -> encodec::Config { +fn encodec_cfg(max_seq_len: Option) -> encodec::Config { let seanet_cfg = seanet::Config { dimension: 512, channels: 1, @@ -82,7 +82,7 @@ fn encodec_cfg() -> encodec::Config { kv_repeat: 1, conv_layout: true, // see builders.py cross_attention: false, - max_seq_len: 4096, + max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins. }; encodec::Config { channels: 1, @@ -107,9 +107,9 @@ struct Tokenizer { #[pymethods] impl Tokenizer { - #[pyo3(signature = (path, *, dtype="f32"))] + #[pyo3(signature = (path, *, dtype="f32", max_seq_len=None))] #[new] - fn new(path: std::path::PathBuf, dtype: &str) -> PyResult { + fn new(path: std::path::PathBuf, dtype: &str, max_seq_len: Option) -> PyResult { let device = candle::Device::Cpu; let dtype = match dtype { "f32" => candle::DType::F32, @@ -119,7 +119,7 @@ impl Tokenizer { }; let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? }; - let cfg = encodec_cfg(); + let cfg = encodec_cfg(max_seq_len); let encodec = encodec::Encodec::new(cfg, vb).w()?; Ok(Self { encodec, device, dtype }) } @@ -240,9 +240,9 @@ struct StreamTokenizer { #[pymethods] impl StreamTokenizer { - #[pyo3(signature = (path, *, dtype="f32"))] + #[pyo3(signature = (path, *, dtype="f32", max_seq_len=None))] #[new] - fn new(path: std::path::PathBuf, dtype: &str) -> PyResult { + fn new(path: std::path::PathBuf, dtype: &str, max_seq_len: Option) -> PyResult { let device = candle::Device::Cpu; let dtype = match dtype { "f32" => candle::DType::F32, @@ -252,7 +252,7 @@ impl StreamTokenizer { }; let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? }; - let cfg = encodec_cfg(); + let cfg = encodec_cfg(max_seq_len); let mut e_encodec = encodec::Encodec::new(cfg, vb).w()?; let mut d_encodec = e_encodec.clone(); let (encoder_tx, e_rx) = std::sync::mpsc::channel::>(); diff --git a/rust/moshi-backend/Cargo.toml b/rust/moshi-backend/Cargo.toml index b15ff42..5de5d0b 100644 --- a/rust/moshi-backend/Cargo.toml +++ b/rust/moshi-backend/Cargo.toml @@ -26,7 +26,7 @@ rcgen = "0.13.1" http = "1.1.0" lazy_static = "1.5.0" log = "0.4.20" -moshi = { path = "../moshi-core", version = "0.2.1" } +moshi = { path = "../moshi-core", version = "0.2.2" } ogg = { version = "0.9.1", features = ["async"] } opus = "0.3.0" rand = { version = "0.8.5", features = ["getrandom"] }