From 738ace394a6f8cf0174e90a97185d9e512c0e200 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 13 May 2023 09:08:52 +0300 Subject: [PATCH] llama : free ggml context in set / copy state data (close #1425) --- llama.cpp | 48 ++++++++++++++++++++++++++++-------------------- llama.h | 2 +- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/llama.cpp b/llama.cpp index 0a47faa9d738d..f52671b67c636 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2450,8 +2450,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) { } // Copies the state to the specified destination address -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { - uint8_t * out = dest; +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + uint8_t * out = dst; // copy rng { @@ -2511,7 +2511,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { if (kv_size) { const size_t elt_size = ggml_element_size(kv_self.k); + char buffer[4096]; + ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; @@ -2535,10 +2537,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); ggml_graph_compute(cpy_ctx, &gf); + + ggml_free(cpy_ctx); } } - const size_t written = out - dest; + const size_t written = out - dst; const size_t max_size = llama_get_state_size(ctx); LLAMA_ASSERT(written <= max_size); @@ -2548,15 +2552,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { // Sets the state reading from the specified source address size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { - const uint8_t * in = src; + const uint8_t * inp = src; // set rng { size_t rng_size; char rng_buf[LLAMA_MAX_RNG_STATE]; - memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size); - memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE; + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; std::stringstream rng_ss; rng_ss.str(std::string(&rng_buf[0], rng_size)); @@ -2570,30 +2574,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { size_t logits_cap; size_t logits_size; - memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap); - memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size); + memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); + memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); if (logits_size) { ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); + memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); } - in += logits_cap * sizeof(float); + inp += logits_cap * sizeof(float); } // set embeddings { size_t embedding_size; - memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size); + memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); if (embedding_size) { - memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); - in += embedding_size * sizeof(float); + memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); + inp += embedding_size * sizeof(float); } } @@ -2608,25 +2612,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { size_t kv_size; int kv_ntok; - memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); - memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); + memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); + memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); if (kv_size) { LLAMA_ASSERT(kv_self.buf.size == kv_size); const size_t elt_size = ggml_element_size(kv_self.k); + char buffer[4096]; + ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_cgraph gf{}; gf.n_threads = 1; ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); - kin3d->data = (void *) in; - in += ggml_nbytes(kin3d); + kin3d->data = (void *) inp; + inp += ggml_nbytes(kin3d); ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); - vin3d->data = (void *) in; - in += ggml_nbytes(vin3d); + vin3d->data = (void *) inp; + inp += ggml_nbytes(vin3d); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, n_embd, kv_ntok, n_layer, @@ -2639,12 +2645,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); ggml_graph_compute(cpy_ctx, &gf); + + ggml_free(cpy_ctx); } ctx->model.kv_self.n = kv_ntok; } - const size_t nread = in - src; + const size_t nread = inp - src; const size_t max_size = llama_get_state_size(ctx); LLAMA_ASSERT(nread <= max_size); diff --git a/llama.h b/llama.h index 1a65cd5892389..ca05645b974de 100644 --- a/llama.h +++ b/llama.h @@ -134,7 +134,7 @@ extern "C" { // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest); + LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); // Set the state reading from the specified address // Returns the number of bytes read