-
Notifications
You must be signed in to change notification settings - Fork 665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rust] Load model on given device #3419
Conversation
@@ -22,7 +22,7 @@ private RustLibrary() {} | |||
|
|||
public static native boolean isCudaAvailable(); | |||
|
|||
public static native long loadModel(String modelPath, int dtype); | |||
public static native long loadModel(String modelPath, int dtype, String device); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why we use String, can we just use device id?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then how to distinguish cpu and cuda?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
public static native long loadModel(String modelPath, int dtype, int deviceType, int deviceId);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -234,3 +234,21 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_runInference<'local>( | |||
} | |||
} | |||
} | |||
|
|||
pub fn as_device(device: &String) -> Result<Device> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we re-use existing as_device()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I have modified the existing as_device()
a little bit, please review.
@@ -22,7 +22,7 @@ private RustLibrary() {} | |||
|
|||
public static native boolean isCudaAvailable(); | |||
|
|||
public static native long loadModel(String modelPath, int dtype); | |||
public static native long loadModel(String modelPath, int dtype, String device); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5a010fd
to
36302fb
Compare
.get_string(&model_path) | ||
.expect("Couldn't get java string!") | ||
.into(); | ||
let dtype = as_data_type(dtype).unwrap(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let dtype = as_data_type(dtype).unwrap(); | |
let dtype = as_data_type(dtype)?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Description
Brief description of what this PR is about