From bd5480cd7f2def89b4b0646f17a3f2998714c9d3 Mon Sep 17 00:00:00 2001 From: Juho Peltonen Date: Fri, 7 Apr 2023 01:08:47 +0300 Subject: [PATCH] Reserve more eval memory and use ggml scratch buffers --- ggml/src/lib.rs | 36 ++++++++++++++++++++++++++++++++++++ llama-rs/src/lib.rs | 27 ++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 22c9eee8..02f2dbc0 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -303,6 +303,23 @@ impl Context { pub fn used_mem(&self) -> usize { unsafe { ggml_sys::ggml_used_mem(self.ptr.as_ptr()) } } + + /// Set scratch buffer + pub fn use_scratch(&self, scratch_buffer: Option<&mut Buffer>) { + let (size, data) = if let Some(buffer) = scratch_buffer { + (buffer.data.len(), buffer.data.as_ptr() as *mut c_void) + } else { + (0, std::ptr::null_mut()) + }; + // SAFETY: this just passes (most likely uninitialized) memory buffer to the ggml C API + unsafe { + ggml_sys::ggml_set_scratch(self.ptr.as_ptr(), ggml_sys::ggml_scratch { + offs: 0, + size, + data, + }); + } + } } impl Drop for Context { @@ -315,6 +332,25 @@ impl Drop for Context { } } +/// Pre-allocated buffer +pub struct Buffer { + data: Vec, +} + +impl Buffer { + /// Creates new buffer + pub fn new(size: usize) -> Self { + let mut data: Vec = Vec::with_capacity(size); + // SAFETY: contents are left uninitialized. Don't use them. + #[allow(clippy::uninit_vec)] + unsafe { data.set_len(size) }; + + Buffer { + data + } + } +} + /// Tensors are owned by the context. A tensor is alive as long as the /// underlying context it was created with is alive. pub struct Tensor { diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index d5ef2a23..68e11a26 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -27,6 +27,8 @@ mod util; /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) +const SCRATCH_SIZE: usize = 512 * 1024 * 1024; // 512MB + /// The hyperparameters of the model. #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)] pub struct Hyperparameters { @@ -103,6 +105,9 @@ pub struct InferenceSession { /// The logits that were last predicted by the network. Zeroed out otherwise. last_logits: Vec, + + /// Scratch buffers + scratch: [ggml::Buffer; 2] } impl InferenceSession { fn repetition_penalty_tokens(&self) -> &[TokenId] { @@ -128,10 +133,15 @@ impl Clone for InferenceSession { mem_per_token: self.mem_per_token, tokens: self.tokens.clone(), last_logits: self.last_logits.clone(), + scratch: inference_session_scratch_buffers(), } } } +fn inference_session_scratch_buffers() -> [ggml::Buffer; 2] { + [ggml::Buffer::new(SCRATCH_SIZE), ggml::Buffer::new(SCRATCH_SIZE)] +} + #[derive(serde::Serialize, Clone, PartialEq)] /// A serializable snapshot of the inference process. Can be saved to disk. // Keep in sync with [InferenceSession] and [InferenceSnapshot] @@ -1116,6 +1126,7 @@ impl Model { mem_per_token: 0, tokens: vec![], last_logits: vec![0.0; n_vocab], + scratch: inference_session_scratch_buffers(), } } @@ -1150,7 +1161,13 @@ impl Model { // For the first run, we need to guess a maximum buffer size so we can measure // the actual memory consumption of the temporary ggml context. - let mut buf_size = 1024 * 1024 * 1024; + let mut buf_size = 1024 * 1024 * if n_layer >= 80 { + 1536 + } else if n_layer >= 60 { + 1280 + } else { + 1024 + }; if session.mem_per_token > 0 && session.mem_per_token * n > buf_size { // add 10% to account for ggml object overhead buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; @@ -1189,6 +1206,8 @@ impl Model { let input_self_attention = input_layer.share(); let mut current: ggml::Tensor; + ctx0.use_scratch(Some(&mut session.scratch[0])); + // norm { current = ctx0.op_rms_norm(&input_layer); @@ -1312,6 +1331,8 @@ impl Model { current = ctx0.op_mul_mat(&self.layers[il].wo, ¤t); } + ctx0.use_scratch(Some(&mut session.scratch[1])); + let input_feed_forward = ctx0.op_add(¤t, &input_self_attention); // feed-forward network @@ -1345,6 +1366,8 @@ impl Model { input_layer = current; } + ctx0.use_scratch(Some(&mut session.scratch[0])); + // Used at the end to optionally extract the embeddings. let embeddings_tensor; @@ -1362,6 +1385,8 @@ impl Model { input_layer = ctx0.op_mul_mat(&self.output, &input_layer); } + ctx0.use_scratch(None); + // logits -> probs // inpL = ctx0.op_soft_max(&inpL);