Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rust] Load model on given device #3419

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 16 additions & 26 deletions extensions/tokenizers/rust/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod mistral;
mod roberta;
mod xlm_roberta;

use crate::ndarray::as_data_type;
use crate::ndarray::{as_data_type, as_device};
use crate::{cast_handle, drop_handle, to_handle, to_string_array};
use bert::{BertConfig, BertForSequenceClassification, BertModel};
use camembert::{CamembertConfig, CamembertModel};
Expand Down Expand Up @@ -54,34 +54,13 @@ pub(crate) trait Model {
}
}

fn load_model<'local>(
env: &mut JNIEnv,
model_path: JString,
dtype: jint,
) -> Result<Box<dyn Model>> {
let model_path: String = env
.get_string(&model_path)
.expect("Couldn't get java string!")
.into();

fn load_model(model_path: String, dtype: DType, device: Device) -> Result<Box<dyn Model>> {
let model_path = PathBuf::from(model_path);

// Load config
let config: String = std::fs::read_to_string(model_path.join("config.json"))?;
let config: Config = serde_json::from_str(&config).map_err(Error::msg)?;

// Get candle device
let device = if candle::utils::cuda_is_available() {
Device::new_cuda(0)
} else if candle::utils::metal_is_available() {
Device::new_metal(0)
} else {
Ok(Device::Cpu)
}?;

// Get candle dtype
let dtype = as_data_type(dtype).unwrap();

// Load safetensors
let safetensors_paths: Vec<PathBuf> = std::fs::read_dir(model_path)?
.filter_map(|entry| {
Expand Down Expand Up @@ -167,14 +146,25 @@ fn load_model<'local>(

#[no_mangle]
pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_loadModel<'local>(
mut env: JNIEnv,
mut env: JNIEnv<'local>,
_: JObject,
model_path: JString,
dtype: jint,
device_type: JString,
device_id: jint,
) -> jlong {
let model = load_model(&mut env, model_path, dtype);
let model = || {
let model_path: String = env
.get_string(&model_path)
.expect("Couldn't get java string!")
.into();
let dtype = as_data_type(dtype)?;
let device = as_device(&mut env, device_type, device_id as usize)?;
load_model(model_path, dtype, device)
};
let ret = model();

match model {
match ret {
Ok(output) => to_handle(output),
Err(err) => {
env.throw(err.to_string()).unwrap();
Expand Down
29 changes: 7 additions & 22 deletions extensions/tokenizers/rust/src/ndarray/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ mod other;
mod reduce;
mod unary;

static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
static METAL_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);

#[no_mangle]
pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_getDataType(
_: JNIEnv,
Expand Down Expand Up @@ -287,32 +284,20 @@ fn to_data_type(data_type: DType) -> i32 {
}
}

fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) -> Result<Device> {
pub fn as_device<'local>(
env: &mut JNIEnv<'local>,
device_type: JString,
device_id: usize,
) -> Result<Device> {
let device_type: String = env
.get_string(&device_type)
.expect("Couldn't get java string!")
.into();

match device_type.as_str() {
"cpu" => Ok(Device::Cpu),
"gpu" => {
let mut device = CUDA_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_cuda(0).unwrap();
*device = Some(d.clone());
Ok(d)
}
"mps" => {
let mut device = METAL_DEVICE.lock().unwrap();
if let Some(device) = device.as_ref() {
return Ok(device.clone());
};
let d = Device::new_metal(0).unwrap();
*device = Some(d.clone());
Ok(d)
}
"gpu" => Device::new_cuda(device_id),
"mps" => Device::new_metal(device_id),
_ => Err(Error::Msg(format!("Invalid device type: {}", device_type))),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
}
setModelDir(modelPath);
if (block == null) {
handle.set(RustLibrary.loadModel(modelDir.toString(), dataType.ordinal()));
Device device = manager.getDevice();
handle.set(
RustLibrary.loadModel(
modelDir.toAbsolutePath().toString(),
dataType.ordinal(),
device.getDeviceType(),
device.getDeviceId()));
block = new RsSymbolBlock((RsNDManager) manager, handle.get());
} else {
loadBlock(prefix, options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ private RustLibrary() {}

public static native boolean isCudaAvailable();

public static native long loadModel(String modelPath, int dtype);
public static native long loadModel(
String modelPath, int dtype, String deviceType, int deviceId);

public static native long deleteModel(long handle);

Expand Down
Loading