Skip to content

Commit

Permalink
mamba : make the server and parallel examples work with whole sequences
Browse files Browse the repository at this point in the history
A seq_id is dedicated to the system prompt in both cases.

* llama : make llama_kv_cache_seq_rm return whether it succeeded or not
  • Loading branch information
compilade committed Feb 26, 2024
1 parent a81b94f commit 12de5c7
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 33 deletions.
20 changes: 13 additions & 7 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;

// dedicate one sequence to the system prompt
params.n_parallel += 1;

// requests to simulate
const int32_t n_seq = params.n_sequences;

Expand Down Expand Up @@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("\n");
Expand All @@ -221,15 +224,17 @@ int main(int argc, char ** argv) {

client.i_batch = batch.n_tokens;

llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);

client.n_decoded += 1;
}

if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("%s: clearing the KV cache\n", __func__);
Expand All @@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);

for (size_t i = 0; i < tokens_prompt.size(); ++i) {
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
}

// extract the logits only for the last token
Expand Down Expand Up @@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
}

// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);

const auto t_main_end = ggml_time_us();

Expand Down
43 changes: 31 additions & 12 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,12 +438,16 @@ struct llama_server_context
return false;
}

if (params.n_ctx < 2048) { // request larger context for the image embedding
if (params.n_ctx != 0 && params.n_ctx < 2048) { // request larger context for the image embedding
params.n_ctx = 2048;
}
}

// dedicate one sequence to the system prompt
params.n_parallel += 1;

std::tie(model, ctx) = llama_init_from_gpt_params(params);
params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr)
{
LOG_ERROR("unable to load model", {{"model", params.model}});
Expand Down Expand Up @@ -923,9 +927,9 @@ struct llama_server_context
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < params.n_parallel; ++i)
for (int32_t i = 1; i <= params.n_parallel; ++i)
{
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}

Expand Down Expand Up @@ -1400,7 +1404,7 @@ struct llama_server_context
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
for (int i = 0; i < (int) append_tokens.size(); ++i)
{
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id + 1 }, true);
slot.n_past += 1;
}
}
Expand Down Expand Up @@ -1636,8 +1640,8 @@ struct llama_server_context
{"n_system_tokens", system_tokens.size()},
{"n_cache_tokens", slot.cache_tokens.size()}
});
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);

for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
{
Expand Down Expand Up @@ -1689,7 +1693,7 @@ struct llama_server_context

// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
slot.n_past += 1;
}

Expand Down Expand Up @@ -1852,13 +1856,28 @@ struct llama_server_context
}
}

// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
LOG_INFO("kv cache rm [p0, end)", {
{ "slot_id", slot.id },
{ "task_id", slot.task_id },
{ "p0", p0 }
});
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
// TODO: logging
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);

// there is no common part left (except for the system prompt)
// TODO: maybe find a way to refactor this to reuse the !cache_prompt case above
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(slot.ctx_sampling);
}

LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
Expand Down Expand Up @@ -1887,7 +1906,7 @@ struct llama_server_context
ga_i += ga_w/ga_n;
}
}
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
slot_npast++;
}

Expand Down Expand Up @@ -1941,9 +1960,9 @@ struct llama_server_context
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);

slot.n_past_se -= bd;

Expand Down
38 changes: 25 additions & 13 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2231,7 +2231,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.used = 0;
}

static void llama_kv_cache_seq_rm(
static bool llama_kv_cache_seq_rm(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
Expand All @@ -2241,11 +2241,23 @@ static void llama_kv_cache_seq_rm(
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

// models like Mamba can't have a state partially erased
if (cache.unlimited) {
// can only remove whole sequences for models like Mamba
GGML_ASSERT(p0 == 0);
GGML_ASSERT((uint32_t)seq_id < cache.size);
GGML_ASSERT(cache.cells[seq_id].pos < p1);
if (seq_id >= (int64_t) cache.size) {
// could be fatal
return false;
}
if (0 <= seq_id) {
// partial intersection is invalid
if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
return false;
}
} else {
// seq_id is negative, then the range should include everything or nothing
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
return false;
}
}
}

for (uint32_t i = 0; i < cache.size; ++i) {
Expand All @@ -2269,6 +2281,8 @@ static void llama_kv_cache_seq_rm(

// If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;

return true;
}

static void llama_kv_cache_seq_cp(
Expand Down Expand Up @@ -12283,13 +12297,11 @@ struct llama_context * llama_new_context_with_model(

// Mamba only needs a constant number of KV cache cells per sequence
if (model->arch == LLM_ARCH_MAMBA) {
// Mamba needs as many KV cells as there are sequences kept at any time
// The extra cell allows dedicating a sequence id to the system prompt
// TODO: find a better way to get the max number of parallel sequences
kv_size = params.n_parallel + 1;
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_parallel);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}

GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
Expand Down Expand Up @@ -12799,8 +12811,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
llama_kv_cache_clear(ctx->kv_self);
}

void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
}

void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
Expand Down
2 changes: 1 addition & 1 deletion llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ extern "C" {
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_rm(
LLAMA_API bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
Expand Down

0 comments on commit 12de5c7

Please sign in to comment.