From ec89a668760051007c929461ed9e904d63cb33ba Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 26 Apr 2024 11:45:49 -0700 Subject: [PATCH] [rust] Avoid panic in error case (#3133) --- extensions/tokenizers/rust/src/models/mod.rs | 9 +++-- extensions/tokenizers/rust/src/ndarray/mod.rs | 36 +++++++++++-------- .../main/java/ai/djl/engine/rust/RsModel.java | 8 +++++ 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/extensions/tokenizers/rust/src/models/mod.rs b/extensions/tokenizers/rust/src/models/mod.rs index e09ed6be3a1c..0bd12dac20db 100644 --- a/extensions/tokenizers/rust/src/models/mod.rs +++ b/extensions/tokenizers/rust/src/models/mod.rs @@ -4,7 +4,7 @@ mod distilbert; use crate::ndarray::as_data_type; use crate::{cast_handle, to_handle, to_string_array}; use bert::{BertConfig, BertModel}; -use candle_core::DType; +use candle_core::{DType, Error}; use candle_core::{Device, Result, Tensor}; use candle_nn::VarBuilder; use distilbert::{DistilBertConfig, DistilBertModel}; @@ -43,7 +43,10 @@ fn load_model<'local>( // Load config let config: String = std::fs::read_to_string(model_path.join("config.json"))?; - let config: Config = serde_json::from_str(&config).unwrap(); + let config: Config = match serde_json::from_str(&config) { + Ok(conf) => conf, + Err(err) => return Err(Error::wrap(err)), + }; // Get candle device let device = if candle_core::utils::cuda_is_available() { @@ -55,7 +58,7 @@ fn load_model<'local>( }?; // Get candle dtype - let dtype = as_data_type(dtype).unwrap(); + let dtype = as_data_type(dtype)?; let safetensors_path = model_path.join("model.safetensors"); let vb = if safetensors_path.exists() { diff --git a/extensions/tokenizers/rust/src/ndarray/mod.rs b/extensions/tokenizers/rust/src/ndarray/mod.rs index 57bbedcef646..60302a8f2532 100644 --- a/extensions/tokenizers/rust/src/ndarray/mod.rs +++ b/extensions/tokenizers/rust/src/ndarray/mod.rs @@ -295,22 +295,30 @@ fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) - match device_type.as_str() { "cpu" => Ok(Device::Cpu), "gpu" => { - let mut device = CUDA_DEVICE.lock().unwrap(); - if let Some(device) = device.as_ref() { - return Ok(device.clone()); - }; - let d = Device::new_cuda(0).unwrap(); - *device = Some(d.clone()); - Ok(d) + if candle_core::utils::cuda_is_available() { + let mut device = CUDA_DEVICE.lock().unwrap(); + if let Some(device) = device.as_ref() { + return Ok(device.clone()); + }; + let d = Device::new_cuda(0).unwrap(); + *device = Some(d.clone()); + Ok(d) + } else { + Err(Error::Msg(String::from("CUDA is not available."))) + } } "mps" => { - let mut device = METAL_DEVICE.lock().unwrap(); - if let Some(device) = device.as_ref() { - return Ok(device.clone()); - }; - let d = Device::new_metal(0).unwrap(); - *device = Some(d.clone()); - Ok(d) + if candle_core::utils::metal_is_available() { + let mut device = METAL_DEVICE.lock().unwrap(); + if let Some(device) = device.as_ref() { + return Ok(device.clone()); + }; + let d = Device::new_metal(0).unwrap(); + *device = Some(d.clone()); + Ok(d) + } else { + Err(Error::Msg(String::from("metal is not available."))) + } } _ => Err(Error::Msg(format!("Invalid device type: {}", device_type))), } diff --git a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java index 4cdb3e9b5cd0..5bc447724e31 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java +++ b/extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java @@ -49,6 +49,14 @@ public void load(Path modelPath, String prefix, Map options) "Model directory doesn't exist: " + modelPath.toAbsolutePath()); } modelDir = modelPath.toAbsolutePath(); + Path config = modelDir.resolve("config.json"); + if (!Files.isRegularFile(config)) { + throw new FileNotFoundException("config.json file not found"); + } + Path file = modelDir.resolve("model.safetensors"); + if (!Files.isRegularFile(file)) { + throw new FileNotFoundException("model.safetensors file not found"); + } long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal()); block = new RsSymbolBlock((RsNDManager) manager, handle); }