This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Reserve more eval memory and use ggml scratch buffers
juho-p committed Apr 8, 2023
1 parent 44ddfe8 commit 248fc8c
Showing 2 changed files with 61 additions and 1 deletion.
ggml/src/lib.rs (35 additions, 0 deletions)
@@ -303,6 +303,23 @@ impl Context {
    pub fn used_mem(&self) -> usize {
        unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) }
    }

    /// Sets the scratch buffer that ggml uses for intermediate allocations,
    /// or disables scratch usage when `None` is passed.
    pub fn use_scratch(&self, scratch_buffer: Option<&mut Buffer>) {
        let (size, data) = if let Some(buffer) = scratch_buffer {
            (buffer.data.len(), buffer.data.as_ptr() as *mut c_void)
        } else {
            (0, std::ptr::null_mut())
        };
        // SAFETY: this just passes a (most likely uninitialized) memory
        // buffer to the ggml C API.
        unsafe {
            ggml_sys::ggml_set_scratch(
                self.ptr.as_ptr(),
                ggml_sys::ggml_scratch {
                    offs: 0,
                    size,
                    data,
                },
            );
        }
    }
}

impl Drop for Context {
@@ -315,6 +332,24 @@ impl Drop for Context {
    }
}

/// A pre-allocated buffer of (uninitialized) memory, usable as a ggml
/// scratch buffer.
pub struct Buffer {
    data: Vec<u8>,
}

impl Buffer {
    /// Creates a new buffer of the given size. The contents are left
    /// uninitialized and must be written before they are read.
    pub fn new(size: usize) -> Self {
        let mut data: Vec<u8> = Vec::with_capacity(size);
        // SAFETY: the contents are left uninitialized. Don't read them
        // before ggml has written to them.
        unsafe { data.set_len(size) };

        Buffer { data }
    }
}

/// Tensors are owned by the context. A tensor is alive as long as the
/// underlying context it was created with is alive.
pub struct Tensor {
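For illustration, a minimal sketch of how the two new pieces fit together (`ctx` is a hypothetical `ggml::Context`; the real call sites are in llama-rs below). The working assumption is that ggml always writes a scratch region before reading it, which is why handing it uninitialized memory is acceptable here:

    // Allocate 512 MiB of (uninitialized) scratch space once, up front.
    let mut scratch = ggml::Buffer::new(512 * 1024 * 1024);
    // Route intermediate tensor allocations in `ctx` through the buffer...
    ctx.use_scratch(Some(&mut scratch));
    // ...build some part of the computation graph here...
    // ...then switch back to the context's own memory pool.
    ctx.use_scratch(None);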
llama-rs/src/lib.rs (26 additions, 1 deletion)
@@ -27,6 +27,8 @@ mod util;
/// The end of text token.
pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)

const SCRATCH_SIZE: usize = 512 * 1024 * 1024; // 512 MiB per scratch buffer

/// The hyperparameters of the model.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)]
pub struct Hyperparameters {
@@ -103,6 +105,9 @@ pub struct InferenceSession {

    /// The logits that were last predicted by the network. Zeroed out otherwise.
    last_logits: Vec<f32>,

    /// Scratch buffers used during inference.
    scratch: [ggml::Buffer; 2],
}
impl InferenceSession {
    fn repetition_penalty_tokens(&self) -> &[TokenId] {
@@ -128,10 +133,15 @@ impl Clone for InferenceSession {
            mem_per_token: self.mem_per_token,
            tokens: self.tokens.clone(),
            last_logits: self.last_logits.clone(),
            // Scratch contents are transient working memory, so a clone
            // gets fresh buffers instead of copying them.
            scratch: inference_session_scratch_buffers(),
        }
    }
}

fn inference_session_scratch_buffers() -> [ggml::Buffer; 2] {
    [
        ggml::Buffer::new(SCRATCH_SIZE),
        ggml::Buffer::new(SCRATCH_SIZE),
    ]
}
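Two buffers are used because evaluation below alternates between them, presumably mirroring llama.cpp's scratch handling: while one scratch region still holds the live outputs of the previous step, the next step's intermediates are allocated from the other. A condensed view of the pattern as it appears in `evaluate` further down:

    ctx0.use_scratch(Some(&mut session.scratch[0])); // attention block
    // ...
    ctx0.use_scratch(Some(&mut session.scratch[1])); // feed-forward block
    // ...
    ctx0.use_scratch(None); // back to the context's own memory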

#[derive(serde::Serialize, Clone, PartialEq)]
/// A serializable snapshot of the inference process. Can be saved to disk.
// Keep in sync with [InferenceSession] and [InferenceSnapshot]
@@ -1116,6 +1126,7 @@ impl Model {
            mem_per_token: 0,
            tokens: vec![],
            last_logits: vec![0.0; n_vocab],
            scratch: inference_session_scratch_buffers(),
        }
    }

@@ -1150,7 +1161,13 @@

        // For the first run, we need to guess a maximum buffer size so we can measure
        // the actual memory consumption of the temporary ggml context.
-       let mut buf_size = 1024 * 1024 * 1024;
+       let mut buf_size = 1024 * 1024
+           * if n_layer >= 80 {
+               1536 // LLaMA 65B has 80 layers
+           } else if n_layer >= 60 {
+               1280 // LLaMA 30B has 60 layers
+           } else {
+               1024
+           };
        if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
            // add 10% to account for ggml object overhead
            buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
@@ -1189,6 +1206,8 @@ impl Model {
            let input_self_attention = input_layer.share();
            let mut current: ggml::Tensor;

            // The self-attention block allocates its intermediates from scratch[0].
            ctx0.use_scratch(Some(&mut session.scratch[0]));

            // norm
            {
                current = ctx0.op_rms_norm(&input_layer);
@@ -1312,6 +1331,8 @@ impl Model {
                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
            }

            // The feed-forward block switches to scratch[1]; the attention
            // output still living in scratch[0] stays intact.
            ctx0.use_scratch(Some(&mut session.scratch[1]));

            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

            // feed-forward network
@@ -1345,6 +1366,8 @@ impl Model {
            input_layer = current;
        }

        // The final norm and logits computation reuse scratch[0].
        ctx0.use_scratch(Some(&mut session.scratch[0]));

        // Used at the end to optionally extract the embeddings.
        let embeddings_tensor;

@@ -1362,6 +1385,8 @@ impl Model {
            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
        }

        // Disable the scratch buffer again; allocations below come from the
        // context's own memory pool.
        ctx0.use_scratch(None);

        // logits -> probs
        // inpL = ctx0.op_soft_max(&inpL);

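Taken together, each inference session now reserves a fixed 1 GiB of scratch space on top of the per-evaluation context buffer. Illustrative arithmetic only (actual usage depends on the model and is refined once mem_per_token has been measured):

    // Eval-time memory reserved per session after this commit, for a model
    // with n_layer >= 80 (e.g. LLaMA 65B):
    let scratch = 2 * SCRATCH_SIZE;        // 2 * 512 MiB = 1 GiB of scratch
    let ctx0_guess = 1536 * 1024 * 1024;   // first-run guess for ctx0
    let total = scratch + ctx0_guess;      // ~2.5 GiB before measurement refines it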
