From 2f3371b61f76dd70174dd389b3b149d63a083ed9 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Thu, 15 Aug 2024 19:01:56 -0700 Subject: [PATCH 1/4] [rust] Load model on given device --- extensions/tokenizers/rust/src/models/mod.rs | 34 ++++++++++++++----- .../main/java/ai/djl/engine/rust/RsModel.java | 6 +++- .../java/ai/djl/engine/rust/RustLibrary.java | 2 +- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/extensions/tokenizers/rust/src/models/mod.rs b/extensions/tokenizers/rust/src/models/mod.rs index cfae04cec36..bbe57e152aa 100644 --- a/extensions/tokenizers/rust/src/models/mod.rs +++ b/extensions/tokenizers/rust/src/models/mod.rs @@ -58,11 +58,16 @@ fn load_model<'local>( env: &mut JNIEnv, model_path: JString, dtype: jint, + device: JString, ) -> Result> { let model_path: String = env .get_string(&model_path) .expect("Couldn't get java string!") .into(); + let device: String = env + .get_string(&device) + .expect("Couldn't get java string!") + .into(); let model_path = PathBuf::from(model_path); @@ -71,13 +76,7 @@ fn load_model<'local>( let config: Config = serde_json::from_str(&config).map_err(Error::msg)?; // Get candle device - let device = if candle::utils::cuda_is_available() { - Device::new_cuda(0) - } else if candle::utils::metal_is_available() { - Device::new_metal(0) - } else { - Ok(Device::Cpu) - }?; + let device = as_device(&device).expect("Couldn't get device!"); // Get candle dtype let dtype = as_data_type(dtype).unwrap(); @@ -171,8 +170,9 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_loadModel<'local>( _: JObject, model_path: JString, dtype: jint, + device: JString, ) -> jlong { - let model = load_model(&mut env, model_path, dtype); + let model = load_model(&mut env, model_path, dtype, device); match model { Ok(output) => to_handle(output), @@ -234,3 +234,21 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_runInference<'local>( } } } + +pub fn as_device(device: &String) -> Result { + if device.starts_with("gpu") { + if let Some(id_str) = device + .strip_prefix("gpu(") + .and_then(|s| s.strip_suffix(")")) + { + if let Ok(id) = id_str.parse::() { + return Device::new_cuda(id); + } + } + panic!("Invalid GPU format!"); + } else if device == "cpu()" { + return Ok(Device::Cpu); + } else { + panic!("Unsupported device string!"); + }; +} diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java index be9430577fc..63f6b0e3333 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java @@ -28,6 +28,7 @@ /** {@code RsModel} is the Rust implementation of {@link Model}. */ public class RsModel extends BaseModel { + private Device device; private final AtomicReference handle; /** @@ -38,6 +39,7 @@ public class RsModel extends BaseModel { */ RsModel(String name, Device device) { super(name); + this.device = device; manager = RsNDManager.getSystemManager().newSubManager(device); manager.setName("RsModel"); dataType = DataType.FLOAT16; @@ -54,7 +56,9 @@ public void load(Path modelPath, String prefix, Map options) } setModelDir(modelPath); if (block == null) { - handle.set(RustLibrary.loadModel(modelDir.toString(), dataType.ordinal())); + handle.set( + RustLibrary.loadModel( + modelDir.toString(), dataType.ordinal(), device.toString())); block = new RsSymbolBlock((RsNDManager) manager, handle.get()); } else { loadBlock(prefix, options); diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java index 852c392f13c..440fe6bde03 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java @@ -22,7 +22,7 @@ private RustLibrary() {} public static native boolean isCudaAvailable(); - public static native long loadModel(String modelPath, int dtype); + public static native long loadModel(String modelPath, int dtype, String device); public static native long deleteModel(long handle); From c0802b4b4e07c2c9e3d41f1e224d8c016261e161 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 16 Aug 2024 10:00:08 -0700 Subject: [PATCH 2/4] Update --- extensions/tokenizers/rust/src/models/mod.rs | 60 +++++-------------- extensions/tokenizers/rust/src/ndarray/mod.rs | 29 +++------ .../main/java/ai/djl/engine/rust/RsModel.java | 5 +- .../java/ai/djl/engine/rust/RustLibrary.java | 3 +- 4 files changed, 29 insertions(+), 68 deletions(-) diff --git a/extensions/tokenizers/rust/src/models/mod.rs b/extensions/tokenizers/rust/src/models/mod.rs index bbe57e152aa..4ef5b91a855 100644 --- a/extensions/tokenizers/rust/src/models/mod.rs +++ b/extensions/tokenizers/rust/src/models/mod.rs @@ -5,7 +5,7 @@ mod mistral; mod roberta; mod xlm_roberta; -use crate::ndarray::as_data_type; +use crate::ndarray::{as_data_type, as_device}; use crate::{cast_handle, drop_handle, to_handle, to_string_array}; use bert::{BertConfig, BertForSequenceClassification, BertModel}; use camembert::{CamembertConfig, CamembertModel}; @@ -54,33 +54,13 @@ pub(crate) trait Model { } } -fn load_model<'local>( - env: &mut JNIEnv, - model_path: JString, - dtype: jint, - device: JString, -) -> Result> { - let model_path: String = env - .get_string(&model_path) - .expect("Couldn't get java string!") - .into(); - let device: String = env - .get_string(&device) - .expect("Couldn't get java string!") - .into(); - +fn load_model(model_path: String, dtype: DType, device: Device) -> Result> { let model_path = PathBuf::from(model_path); // Load config let config: String = std::fs::read_to_string(model_path.join("config.json"))?; let config: Config = serde_json::from_str(&config).map_err(Error::msg)?; - // Get candle device - let device = as_device(&device).expect("Couldn't get device!"); - - // Get candle dtype - let dtype = as_data_type(dtype).unwrap(); - // Load safetensors let safetensors_paths: Vec = std::fs::read_dir(model_path)? .filter_map(|entry| { @@ -166,15 +146,25 @@ fn load_model<'local>( #[no_mangle] pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_loadModel<'local>( - mut env: JNIEnv, + mut env: JNIEnv<'local>, _: JObject, model_path: JString, dtype: jint, - device: JString, + device_type: JString, + device_id: jint, ) -> jlong { - let model = load_model(&mut env, model_path, dtype, device); + let model = || { + let model_path: String = env + .get_string(&model_path) + .expect("Couldn't get java string!") + .into(); + let dtype = as_data_type(dtype).unwrap(); + let device = as_device(&mut env, device_type, device_id as usize)?; + load_model(model_path, dtype, device) + }; + let ret = model(); - match model { + match ret { Ok(output) => to_handle(output), Err(err) => { env.throw(err.to_string()).unwrap(); @@ -234,21 +224,3 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_runInference<'local>( } } } - -pub fn as_device(device: &String) -> Result { - if device.starts_with("gpu") { - if let Some(id_str) = device - .strip_prefix("gpu(") - .and_then(|s| s.strip_suffix(")")) - { - if let Ok(id) = id_str.parse::() { - return Device::new_cuda(id); - } - } - panic!("Invalid GPU format!"); - } else if device == "cpu()" { - return Ok(Device::Cpu); - } else { - panic!("Unsupported device string!"); - }; -} diff --git a/extensions/tokenizers/rust/src/ndarray/mod.rs b/extensions/tokenizers/rust/src/ndarray/mod.rs index 781c459ec48..059b69cd4b1 100644 --- a/extensions/tokenizers/rust/src/ndarray/mod.rs +++ b/extensions/tokenizers/rust/src/ndarray/mod.rs @@ -14,9 +14,6 @@ mod other; mod reduce; mod unary; -static CUDA_DEVICE: std::sync::Mutex> = std::sync::Mutex::new(None); -static METAL_DEVICE: std::sync::Mutex> = std::sync::Mutex::new(None); - #[no_mangle] pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getDataType( _: JNIEnv, @@ -287,7 +284,11 @@ fn to_data_type(data_type: DType) -> i32 { } } -fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) -> Result { +pub fn as_device<'local>( + env: &mut JNIEnv<'local>, + device_type: JString, + device_id: usize, +) -> Result { let device_type: String = env .get_string(&device_type) .expect("Couldn't get java string!") @@ -295,24 +296,8 @@ fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) - match device_type.as_str() { "cpu" => Ok(Device::Cpu), - "gpu" => { - let mut device = CUDA_DEVICE.lock().unwrap(); - if let Some(device) = device.as_ref() { - return Ok(device.clone()); - }; - let d = Device::new_cuda(0).unwrap(); - *device = Some(d.clone()); - Ok(d) - } - "mps" => { - let mut device = METAL_DEVICE.lock().unwrap(); - if let Some(device) = device.as_ref() { - return Ok(device.clone()); - }; - let d = Device::new_metal(0).unwrap(); - *device = Some(d.clone()); - Ok(d) - } + "gpu" => Device::new_cuda(device_id), + "mps" => Device::new_metal(device_id), _ => Err(Error::Msg(format!("Invalid device type: {}", device_type))), } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java index 63f6b0e3333..a2b3b68f5ac 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java @@ -58,7 +58,10 @@ public void load(Path modelPath, String prefix, Map options) if (block == null) { handle.set( RustLibrary.loadModel( - modelDir.toString(), dataType.ordinal(), device.toString())); + modelDir.toAbsolutePath().toString(), + dataType.ordinal(), + device.getDeviceType(), + device.getDeviceId())); block = new RsSymbolBlock((RsNDManager) manager, handle.get()); } else { loadBlock(prefix, options); diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java index 440fe6bde03..ef848e0291a 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RustLibrary.java @@ -22,7 +22,8 @@ private RustLibrary() {} public static native boolean isCudaAvailable(); - public static native long loadModel(String modelPath, int dtype, String device); + public static native long loadModel( + String modelPath, int dtype, String deviceType, int deviceId); public static native long deleteModel(long handle); From 8cb82fd5f22bb35f95eed10af2c32801e3f2e22e Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 16 Aug 2024 10:01:31 -0700 Subject: [PATCH 3/4] Update --- extensions/tokenizers/rust/src/models/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/tokenizers/rust/src/models/mod.rs b/extensions/tokenizers/rust/src/models/mod.rs index 4ef5b91a855..11ade3982d4 100644 --- a/extensions/tokenizers/rust/src/models/mod.rs +++ b/extensions/tokenizers/rust/src/models/mod.rs @@ -158,7 +158,7 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_loadModel<'local>( .get_string(&model_path) .expect("Couldn't get java string!") .into(); - let dtype = as_data_type(dtype).unwrap(); + let dtype = as_data_type(dtype)?; let device = as_device(&mut env, device_type, device_id as usize)?; load_model(model_path, dtype, device) }; From 7bc202c0fbe9577f34e45647bb9bf832c4ba658c Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 16 Aug 2024 10:13:25 -0700 Subject: [PATCH 4/4] Update --- .../tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java index a2b3b68f5ac..bbc1dfd01c9 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java @@ -28,7 +28,6 @@ /** {@code RsModel} is the Rust implementation of {@link Model}. */ public class RsModel extends BaseModel { - private Device device; private final AtomicReference handle; /** @@ -39,7 +38,6 @@ public class RsModel extends BaseModel { */ RsModel(String name, Device device) { super(name); - this.device = device; manager = RsNDManager.getSystemManager().newSubManager(device); manager.setName("RsModel"); dataType = DataType.FLOAT16; @@ -56,6 +54,7 @@ public void load(Path modelPath, String prefix, Map options) } setModelDir(modelPath); if (block == null) { + Device device = manager.getDevice(); handle.set( RustLibrary.loadModel( modelDir.toAbsolutePath().toString(),