llama : fix llama_copy_state_data with fragmented KV cache #5840

Merged · 1 commit · Mar 3, 2024
llama.cpp — 47 changes: 30 additions & 17 deletions
@@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
 }
 
 // find how many cells are currently in use
-static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size - 1; i > 0; --i) {
-        if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
-            return i + 1;
+static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
+    for (uint32_t i = cache.size; i > 0; --i) {
+        const llama_kv_cell & cell = cache.cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
         }
     }
 
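The rewrite fixes an off-by-one alongside the signedness change: the old loop started at cache.size - 1 and stopped before index 0, so cell 0 was never examined, and a cache whose only occupied cell was the first one fell through to the function's final return of 0. A self-contained sketch of the patched scan, using a hypothetical toy_cell in place of the real llama_kv_cell:

#include <cstdint>
#include <vector>

// hypothetical stand-in for llama_kv_cell: pos < 0 or no seq_ids means "free"
struct toy_cell {
    int32_t pos      = -1;
    int     n_seq_id =  0;
    bool is_empty() const { return n_seq_id == 0; }
};

// same scan as the patched llama_kv_cache_cell_max:
// walk from the back, return one past the last occupied cell
static uint32_t cell_max(const std::vector<toy_cell> & cells) {
    for (uint32_t i = (uint32_t) cells.size(); i > 0; --i) {
        const toy_cell & cell = cells[i - 1];
        if (cell.pos >= 0 && !cell.is_empty()) {
            return i;
        }
    }
    return 0; // cache is empty
}

With a single occupied cell at index 0 this returns 1, where the pre-patch loop never reached index 0 and reported 0.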
@@ -8178,7 +8180,7 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);
 
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
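Dropping the (int32_t) cast and switching to 32u keeps the whole min/max chain unsigned now that llama_kv_cache_cell_max returns uint32_t. GGML_PAD rounds its argument up to a multiple of its second argument (in ggml.h it is the usual (((x) + (n) - 1) & ~((n) - 1)) bit trick), so kv_self.n becomes the smallest 32-aligned prefix covering every occupied cell, clamped to [32, n_ctx]. A worked sketch of the arithmetic, with assumed example values:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// same rounding as GGML_PAD in ggml.h
#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const uint32_t n_ctx = 4096;                            // assumed context size
    const uint32_t cell_max_values[] = {0, 1, 100, 4090};   // example cell_max results

    for (uint32_t cell_max : cell_max_values) {
        const uint32_t n = std::min(n_ctx, std::max(32u, PAD(cell_max, 32)));
        printf("cell_max = %4u -> kv_self.n = %4u\n", cell_max, n);
    }
    // 0 -> 32, 1 -> 32, 100 -> 128, 4090 -> 4096
    return 0;
}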
@@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
     const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
-    const size_t s_kv_size         = sizeof(size_t);
-    const size_t s_kv_ntok         = sizeof(int);
+    const size_t s_kv_buf_size     = sizeof(size_t);
+    const size_t s_kv_head         = sizeof(uint32_t);
+    const size_t s_kv_size         = sizeof(uint32_t);
+    const size_t s_kv_used         = sizeof(uint32_t);
     const size_t s_kv              = ctx->kv_self.total_size();
+    // TODO: assume the max is more than 1 seq_id per KV cell
+    const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+    const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
 
     const size_t s_total = (
         + s_rng_size
@@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
         + s_logits
         + s_embedding_size
         + s_embedding
+        + s_kv_buf_size
+        + s_kv_head
         + s_kv_size
-        + s_kv_ntok
+        + s_kv_used
         + s_kv
+        + s_kv_cells
     );
 
     return s_total;
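The estimate now mirrors the serialized layout field by field: a size_t for the KV buffer byte count, three uint32_t header fields (head, size, used), the raw buffer itself, and one variable-length record per cell. Per the TODO, each record is budgeted as one llama_pos, one size_t count, and a single llama_seq_id; with the 4-byte llama_pos/llama_seq_id typedefs from llama.h and an 8-byte size_t, that is 16 bytes per cell. A sketch of the same accounting with assumed example values:

#include <cstdint>
#include <cstdio>

using llama_pos    = int32_t; // matches the typedef in llama.h
using llama_seq_id = int32_t; // matches the typedef in llama.h

int main() {
    const size_t kv_size     = 4096;      // assumed number of cache cells
    const size_t kv_buf_size = 512 << 20; // assumed K/V tensor bytes (512 MiB)

    const size_t s_kv_header = sizeof(size_t) + 3 * sizeof(uint32_t); // buf size + head/size/used
    const size_t s_kv_cell   = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
    const size_t s_kv_cells  = kv_size * s_kv_cell;

    printf("header %zu B, per cell %zu B, all cells %zu B, kv section bound %zu B\n",
           s_kv_header, s_kv_cell, s_kv_cells, s_kv_header + kv_buf_size + s_kv_cells);
    return 0;
}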
@@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;
 
         const uint32_t n_layer      = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx        = cparams.n_ctx;
 
         const size_t   kv_buf_size = kv_self.total_size();
-        const uint32_t kv_head     = kv_self.head;
+        const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
         const uint32_t kv_size     = kv_self.size;
         const uint32_t kv_used     = kv_self.used;
 
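This one-line change is the actual bug fix. kv_self.head marks where the next slot search starts, and after sequence removals it can be rewound below cells that are still occupied; serializing only cells[0..head) then silently drops those cells. Taking kv_head from llama_kv_cache_cell_max instead captures everything up to the highest occupied cell. A toy illustration of a fragmented cache, with hypothetical types and the same scan as the sketch above:

#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_cell { int32_t pos = -1; int n_seq_id = 0; };

int main() {
    // cells 0 and 2 hold tokens; cell 1 was freed and head rewound to 1
    // so the next insertion can reuse it
    std::vector<toy_cell> cells(4);
    cells[0] = {0, 1};
    cells[2] = {2, 1};
    const uint32_t head = 1;

    uint32_t cell_max = 0; // one past the highest occupied cell
    for (uint32_t i = (uint32_t) cells.size(); i > 0; --i) {
        if (cells[i - 1].pos >= 0 && cells[i - 1].n_seq_id > 0) {
            cell_max = i;
            break;
        }
    }

    // old code saved cells[0..head)          = 1 cell  -> cell 2 lost
    // patched code saves cells[0..cell_max)  = 3 cells -> cell 2 kept
    printf("head = %u, cell_max = %u\n", head, cell_max); // head = 1, cell_max = 3
    return 0;
}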
@@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
                 // v is not contiguous, copy row by row
                 const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
 
                 tmp_buf.resize(v_row_size);
                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
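The V cache is stored transposed: for each layer there is one row per embedding dimension (n_embd_v_gqa rows), each row holding one element per cell, so a row's byte stride is determined by the allocated cache size. Deriving the stride from n_ctx was only correct while the cache happened to be allocated at exactly n_ctx cells; kv_size keeps the copy correct if the two ever differ. A minimal float-only sketch of this strided copy (the real code goes through ggml_row_size and the backend tensor API to handle quantized types and device memory):

#include <cstring>
#include <vector>

// copy the first `head` elements of each of `rows` rows out of a matrix
// whose rows are laid out `size` elements apart
static std::vector<float> copy_v_rows(const float * v, size_t rows, size_t head, size_t size) {
    std::vector<float> out(rows * head);
    for (size_t ir = 0; ir < rows; ++ir) {
        std::memcpy(out.data() + ir * head, // packed destination rows
                    v + ir * size,          // strided source rows
                    head * sizeof(float));
    }
    return out;
}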
@@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
             }
         }
 
-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             const auto & cell = kv_self.cells[i];
 
             const llama_pos pos = cell.pos;
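The extract truncates the loop body here. For context, a sketch from memory of how this loop continues in the file (unchanged context, not part of the diff): each of the first kv_head cells contributes its pos, its seq_id count, and then each seq_id, matching the s_kv_cell budget above.

// continuation of the loop above (sketched, not part of this diff):
//     const size_t seq_id_size = cell.seq_id.size();
//
//     data_ctx->write(&pos,         sizeof(pos));
//     data_ctx->write(&seq_id_size, sizeof(seq_id_size));
//
//     for (auto seq_id : cell.seq_id) {
//         data_ctx->write(&seq_id, sizeof(seq_id));
//     }
// }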
@@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     {
         const auto & kv_self = ctx->kv_self;
         const auto & hparams = ctx->model.hparams;
-        const auto & cparams = ctx->cparams;
 
         const uint32_t n_layer      = hparams.n_layer;
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const uint32_t n_ctx        = cparams.n_ctx;
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
                 // v is not contiguous, copy row by row
                 const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
 
                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                     ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             }
         }
 
+        GGML_ASSERT(kv_self.size == kv_size);
+
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
         ctx->kv_self.used = kv_used;
 
         ctx->kv_self.cells.resize(kv_size);
 
-        for (uint32_t i = 0; i < kv_size; ++i) {
+        for (uint32_t i = 0; i < kv_head; ++i) {
             llama_pos pos;
             size_t    seq_id_size;
 
@@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
                 ctx->kv_self.cells[i].seq_id.insert(seq_id);
             }
         }
+
+        for (uint32_t i = kv_head; i < kv_size; ++i) {
+            ctx->kv_self.cells[i].pos = -1;
+            ctx->kv_self.cells[i].seq_id.clear();
+        }
     }
 
     const size_t nread = inp - src;
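The new GGML_ASSERT pins down the restore contract: because V rows are now strided by kv_size, a snapshot can only be loaded into a context whose cache has the same allocated size. The final loop resets every cell at or beyond kv_head, since the snapshot carries only the occupied prefix and stale metadata from the destination cache must not survive the restore. A round-trip through the public state API of this era looks like the sketch below, assuming ctx_src and ctx_dst were created from the same model with the same context size:

#include <cstdint>
#include <vector>

#include "llama.h"

// snapshot ctx_src and restore it into ctx_dst
static void clone_state(llama_context * ctx_src, llama_context * ctx_dst) {
    // llama_get_state_size is an upper bound; with the fragmented-cache fix,
    // the actual payload stops at the last occupied KV cell
    std::vector<uint8_t> state(llama_get_state_size(ctx_src));
    const size_t written = llama_copy_state_data(ctx_src, state.data());
    state.resize(written);

    llama_set_state_data(ctx_dst, state.data());
}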