Skip to content

Commit

Permalink
Increase max_seq_len for mimi-pyo3 and make it configurable.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 22, 2024
1 parent 3e3e573 commit 1283b17
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion rust/mimi-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ crate-type = ["cdylib"]
anyhow = "1"
numpy = "0.21.0"
pyo3 = "0.21.0"
moshi = { path = "../moshi-core", version = "0.2.1" }
moshi = { path = "../moshi-core", version = "0.2.2" }
16 changes: 8 additions & 8 deletions rust/mimi-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ macro_rules! py_bail {
};
}

fn encodec_cfg() -> encodec::Config {
fn encodec_cfg(max_seq_len: Option<usize>) -> encodec::Config {
let seanet_cfg = seanet::Config {
dimension: 512,
channels: 1,
Expand Down Expand Up @@ -82,7 +82,7 @@ fn encodec_cfg() -> encodec::Config {
kv_repeat: 1,
conv_layout: true, // see builders.py
cross_attention: false,
max_seq_len: 4096,
max_seq_len: max_seq_len.unwrap_or(8192), // the transformer works at 25hz so this is ~5 mins.
};
encodec::Config {
channels: 1,
Expand All @@ -107,9 +107,9 @@ struct Tokenizer {

#[pymethods]
impl Tokenizer {
#[pyo3(signature = (path, *, dtype="f32"))]
#[pyo3(signature = (path, *, dtype="f32", max_seq_len=None))]
#[new]
fn new(path: std::path::PathBuf, dtype: &str) -> PyResult<Self> {
fn new(path: std::path::PathBuf, dtype: &str, max_seq_len: Option<usize>) -> PyResult<Self> {
let device = candle::Device::Cpu;
let dtype = match dtype {
"f32" => candle::DType::F32,
Expand All @@ -119,7 +119,7 @@ impl Tokenizer {
};
let vb =
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? };
let cfg = encodec_cfg();
let cfg = encodec_cfg(max_seq_len);
let encodec = encodec::Encodec::new(cfg, vb).w()?;
Ok(Self { encodec, device, dtype })
}
Expand Down Expand Up @@ -240,9 +240,9 @@ struct StreamTokenizer {

#[pymethods]
impl StreamTokenizer {
#[pyo3(signature = (path, *, dtype="f32"))]
#[pyo3(signature = (path, *, dtype="f32", max_seq_len=None))]
#[new]
fn new(path: std::path::PathBuf, dtype: &str) -> PyResult<Self> {
fn new(path: std::path::PathBuf, dtype: &str, max_seq_len: Option<usize>) -> PyResult<Self> {
let device = candle::Device::Cpu;
let dtype = match dtype {
"f32" => candle::DType::F32,
Expand All @@ -252,7 +252,7 @@ impl StreamTokenizer {
};
let vb =
unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[path], dtype, &device).w()? };
let cfg = encodec_cfg();
let cfg = encodec_cfg(max_seq_len);
let mut e_encodec = encodec::Encodec::new(cfg, vb).w()?;
let mut d_encodec = e_encodec.clone();
let (encoder_tx, e_rx) = std::sync::mpsc::channel::<Vec<f32>>();
Expand Down
2 changes: 1 addition & 1 deletion rust/moshi-backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ rcgen = "0.13.1"
http = "1.1.0"
lazy_static = "1.5.0"
log = "0.4.20"
moshi = { path = "../moshi-core", version = "0.2.1" }
moshi = { path = "../moshi-core", version = "0.2.2" }
ogg = { version = "0.9.1", features = ["async"] }
opus = "0.3.0"
rand = { version = "0.8.5", features = ["getrandom"] }
Expand Down

0 comments on commit 1283b17

Please sign in to comment.