chore(ggml): update code
Signed-off-by: Xin Liu <sam@secondstate.io>
apepkuss committed Sep 25, 2023
1 parent d7a6360 commit e701974
Showing 4 changed files with 74 additions and 46 deletions.
42 changes: 23 additions & 19 deletions ggml-llama-via-wasinn/README.md
@@ -18,6 +18,12 @@ Now let's build and run this example.
rustup target add wasm32-wasi
```

- Install `openblas`

```bash
apt install -y libopenblas-dev
```

- Install WasmEdge Runtime

Use the following command to install WasmEdge Runtime and the `wasi_nn-ggml` plugin:
@@ -60,33 +66,31 @@ Now let's build and run this example.

If the command runs successfully, you can find the `ggml-llama-wasm.so` file in the root directory.

- Download the Llama2 model

For simplicity, we use the `orca-mini-3b.ggmlv3.q4_0.bin` model in this example.
- Download the Llama2 model in GGUF format

```bash
curl -LO https://huggingface.co/TheBloke/orca_mini_3B-GGML/resolve/main/orca-mini-3b.ggmlv3.q4_0.bin
```

You can also download the `llama-2-7b-chat.ggmlv3.q4_0.bin` model for evaluation. Note that the model needs 16 GB of memory (peak consumption) to run.

```bash
curl -LO https://huggingface.co/localmodels/Llama-2-7B-Chat-ggml/resolve/main/llama-2-7b-chat.ggmlv3.q4_0.bin
curl -LO https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_M.gguf
```

- Build & run the `run-ggml-llama-wasm` app

```bash
cargo run -p run-ggml-llama-inference -- .:. ggml-llama-wasm.so default 'Once upon a time, '
cargo run -p run-ggml-llama-inference -- .:. ggml-llama-wasm.so default
```

Note that the argument `default` in the command above matches the model alias defined in `main.rs` of the `run-ggml-llama-inference` project. You can name it whatever you like, but the two names must stay the same.
If the command runs successfully, you can have multi-turn conversations like the one below:

```rust
...
// preload named model
let preloads = vec!["default:GGML:CPU:orca-mini-3b.ggmlv3.q4_0.bin"];
...
```

```bash
[Question]:
What is the capital of the United States?
[Answer]:
The capital of the United States is Washington, D.C. (District of Columbia).
[Question]:
What about France?
[Answer]:
The capital of France is Paris.
[Question]:
I have two apples, each costing 5 dollars. What is the total cost of these apples?
[Answer]:
The total cost of the two apples is $10.
```

If the command runs successfully, you can see the text generated by the model.
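
For context, the multi-turn behavior comes from the guest module accumulating each user turn into a running prompt in the Llama-2 chat format (see `ggml-llama-wasm/src/main.rs` below). A minimal sketch of that accumulation; `append_turn` is a hypothetical helper name, not part of the example code:

```rust
// Sketch of the prompt accumulation used for multi-turn chat
// (mirrors the logic in ggml-llama-wasm/src/main.rs).
fn append_turn(saved_prompt: &str, system_prompt: &str, user_input: &str) -> String {
    if saved_prompt.is_empty() {
        // First turn: include the system prompt in Llama-2 chat format.
        format!("[INST] {} {} [/INST]", system_prompt, user_input.trim())
    } else {
        // Later turns: append the new question to the running prompt.
        format!("{} [INST] {} [/INST]", saved_prompt, user_input.trim())
    }
}
```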
59 changes: 40 additions & 19 deletions ggml-llama-via-wasinn/ggml-llama-wasm/src/main.rs
@@ -4,30 +4,51 @@ use wasi_nn;
fn main() {
let args: Vec<String> = env::args().collect();
let model_name: &str = &args[1];
let prompt: &str = &args[2];

let graph =
wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Ggml, wasi_nn::ExecutionTarget::CPU)
.build_from_cache(model_name)
.unwrap();
println!("Loaded model into wasi-nn with ID: {:?}", graph);

let mut context = graph.init_execution_context().unwrap();
println!("Created wasi-nn execution context with ID: {:?}", context);

let tensor_data = prompt.as_bytes().to_vec();
println!("Read input tensor, size in bytes: {}", tensor_data.len());
context
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
.unwrap();

// Execute the inference.
context.compute().unwrap();
println!("Executed model inference");

// Retrieve the output.
let mut output_buffer = vec![0u8; 1000];
context.get_output(0, &mut output_buffer).unwrap();
let output = String::from_utf8(output_buffer.clone()).unwrap();
println!("Output: {}", output);

let system_prompt = String::from("<<SYS>>You are a helpful, respectful and honest assistant. Always answer as short as possible, while being safe. <</SYS>>");
let mut saved_prompt = String::new();

loop {
println!("[Question]:");
let input = read_input();
if saved_prompt == "" {
saved_prompt = format!("[INST] {} {} [/INST]", system_prompt, input.trim());
} else {
saved_prompt = format!("{} [INST] {} [/INST]", saved_prompt, input.trim());
}

let tensor_data = saved_prompt.as_bytes().to_vec();
context
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
.unwrap();

// Execute the inference.
context.compute().unwrap();

// Retrieve the output.
let mut output_buffer = vec![0u8; 1000];
let output_size = context.get_output(0, &mut output_buffer).unwrap();
let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
println!("[Answer]: {}", output);
}
}

fn read_input() -> String {
loop {
let mut answer = String::new();
std::io::stdin()
.read_line(&mut answer)
.ok()
.expect("Failed to read line");
if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
return answer;
}
}
}
4 changes: 3 additions & 1 deletion ggml-llama-via-wasinn/run-ggml-llama-inference/Cargo.toml
@@ -4,4 +4,6 @@ version = "0.1.0"
edition = "2021"

[dependencies]
wasmedge-sdk ={ git = "https://github.com/WasmEdge/wasmedge-rust-sdk.git", branch = "main", features = ["wasi_nn"] }
wasmedge-sdk = { git = "https://github.com/apepkuss/wasmedge-rust-sdk.git", branch = "feat/enbale-nn-preload", features = [
"wasi_nn",
] }
15 changes: 8 additions & 7 deletions ggml-llama-via-wasinn/run-ggml-llama-inference/src/main.rs
@@ -2,7 +2,7 @@
use wasmedge_sdk::{
config::{CommonConfigOptions, ConfigBuilder, HostRegistrationConfigOptions},
params,
plugin::PluginManager,
plugin::{ExecutionTarget, NNBackend, NNPreload, PluginManager},
Module, VmBuilder,
};

@@ -20,15 +20,16 @@ fn infer() -> Result<(), Box<dyn std::error::Error>> {
let dir_mapping = &args[1];
let wasm_file = &args[2];
let model_name = &args[3];
let prompt = &args[4];

println!("load plugin");

// load the wasi_nn-ggml plugin from the default plugin directory: /usr/local/lib/wasmedge
PluginManager::load(None)?;
// preload named model
let preloads = vec!["default:GGML:CPU:orca-mini-3b.ggmlv3.q4_0.bin"];
PluginManager::nn_preload(preloads);
PluginManager::nn_preload(vec![NNPreload::new(
"default",
NNBackend::GGML,
ExecutionTarget::CPU,
"llama-2-7b-chat.Q5_K_M.gguf",
)]);

let config = ConfigBuilder::new(CommonConfigOptions::default())
.with_host_registration_config(HostRegistrationConfigOptions::default().wasi(true))
Expand All @@ -50,7 +51,7 @@ fn infer() -> Result<(), Box<dyn std::error::Error>> {
vm.wasi_module_mut()
.expect("Not found wasi module")
.initialize(
Some(vec![wasm_file, model_name, prompt]),
Some(vec![wasm_file, model_name]),
None,
Some(vec![dir_mapping]),
);
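
For reference, the `default` alias registered with `PluginManager::nn_preload` above is the same name the host passes to the guest module, which in turn looks up the preloaded model with `build_from_cache` (see `ggml-llama-wasm/src/main.rs` above). A minimal guest-side sketch, adapted from that code:

```rust
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding};

fn main() {
    // The alias passed to build_from_cache must match the name registered
    // with PluginManager::nn_preload on the host side ("default" here).
    let graph = GraphBuilder::new(GraphEncoding::Ggml, ExecutionTarget::CPU)
        .build_from_cache("default")
        .unwrap();
    println!("Loaded model into wasi-nn with ID: {:?}", graph);
}
```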
