From 9c4c257e5fe8511a5bcce0adf866cc61c98ca7cf Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 13 Feb 2024 19:06:18 -0500 Subject: [PATCH] mamba : multiple sequences, but one at a time This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though). The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps) Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step. * mamba : support llama_kv_cache_seq_cp This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index. * mamba : use a state mask It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)). * llama : replace the usage of n_ctx with kv_self.size in many places * mamba : use n_tokens directly instead of n_tok --- common/common.cpp | 1 + ggml.c | 26 ++-- llama.cpp | 313 +++++++++++++++++++++++++++++++++++----------- llama.h | 1 + 4 files changed, 253 insertions(+), 88 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3302caa2004593..e11663412cd9ed 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1176,6 +1176,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ctx = params.n_ctx; cparams.n_batch = params.n_batch; + cparams.n_parallel = params.n_parallel; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch == -1 ? 
params.n_threads : params.n_threads_batch; cparams.mul_mat_q = params.mul_mat_q; diff --git a/ggml.c b/ggml.c index fdeb9843487018..ef6e1757c6e0c7 100644 --- a/ggml.c +++ b/ggml.c @@ -5905,15 +5905,15 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D { - const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_tok = x->ne[1]; + const int64_t d_state = s->ne[0]; + const int64_t d_inner = s->ne[1]; + const int64_t n_tokens = x->ne[1]; GGML_ASSERT(x->ne[0] == d_inner); GGML_ASSERT(A->ne[0] == d_state); GGML_ASSERT(A->ne[1] == d_inner); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_tok); + GGML_ASSERT(B->ne[1] == n_tokens); } bool is_node = false; @@ -14178,12 +14178,12 @@ static void ggml_compute_forward_ssm_scan_f32( // first batch { - float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok} + float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens} float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok} + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens} float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data); // {d_state, n_tok} + float * B = (float *) ((char *) src4->data); // {d_state, n_tokens} // d_inner for (int i1 = 0; i1 < ir; ++i1) { float dt_soft_plus = log1pf(expf(dt[i1])); @@ -14199,12 +14199,12 @@ static void ggml_compute_forward_ssm_scan_f32( // compute state for rest of tokens, previous state comes from dest for (int i2 = 1; i2 < n_t; ++i2) { - float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok} - float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok} - float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok} - float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok} + float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens} + float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens} + float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens} + float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens} float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok} + float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens} // d_inner for (int i1 = 0; i1 < ir; ++i1) { float dt_soft_plus = log1pf(expf(dt[i1])); diff --git a/llama.cpp b/llama.cpp index 02cd296730213d..98d77f0906ce9e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1607,6 +1607,8 @@ struct llama_kv_cell { // ring-buffer of cached KV data struct llama_kv_cache { bool has_shift = false; + // with Mamba, a slot can hold the state for more than one past token + bool 
unlimited = false; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -1829,11 +1831,12 @@ struct llama_context { // input tensors ggml_backend_buffer_t buf_input = nullptr; ggml_context * ctx_input = nullptr; - struct ggml_tensor * inp_tokens; // I32 [n_batch] - struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - struct ggml_tensor * inp_pos; // I32 [n_batch] - struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] - struct ggml_tensor * inp_K_shift; // I32 [n_ctx] + struct ggml_tensor * inp_tokens; // I32 [n_batch] + struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_K_shift; // I32 [kv_size] + struct ggml_tensor * inp_state_mask; // F32 [kv_size] (only used by constant state models like Mamba) #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -1849,7 +1852,7 @@ static bool llama_kv_cache_init( const llama_model & model, ggml_type ktype, ggml_type vtype, - uint32_t n_ctx, + uint32_t kv_size, bool offload) { const struct llama_hparams & hparams = model.hparams; @@ -1857,19 +1860,23 @@ static bool llama_kv_cache_init( const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_layer = hparams.n_layer; - if (model.arch == LLM_ARCH_MAMBA) { - // only one slot is needed for Mamba - n_ctx = 1; - } - cache.has_shift = false; + // for now, only Mamba can hold state for more than one past token per slot + cache.unlimited = model.arch == LLM_ARCH_MAMBA; + cache.head = 0; - cache.size = n_ctx; + cache.size = kv_size; cache.used = 0; cache.cells.clear(); - cache.cells.resize(n_ctx); + cache.cells.resize(kv_size); + + if (cache.unlimited) { + for (uint32_t i = 0; i < cache.size; ++i) { + cache.cells[i].delta = i; + } + } // else, delta is already initialized to zero #ifdef GGML_USE_CLBLAST offload = false; @@ -1908,8 +1915,8 @@ static bool llama_kv_cache_init( for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, ktype, n_embd_k_gqa*n_ctx); - ggml_tensor * v = ggml_new_tensor_1d(ctx, vtype, n_embd_v_gqa*n_ctx); + ggml_tensor * k = ggml_new_tensor_1d(ctx, ktype, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, vtype, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); @@ -1943,11 +1950,51 @@ static bool llama_kv_cache_find_slot( const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; - // for Mamba and/or other model archs that only ever use one slot - if (n_ctx == 1) { - // hopefully no one actually uses a context size of 1 on Transformer-based models - return true; + if (cache.unlimited) { + // For unlimited context architectures (like Mamba), + // each KV cache cell can store the state for a whole sequence. 
+ + // starting point to find the minimum seq_id used in the batch + cache.head = cache.size - 1; + // likewise, to find the max seq_id in the batch + cache.used = 0; + for (uint32_t i = 0; i < n_tokens; ++i) { + for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + // make sure it's a valid seq_id + if ((uint32_t)seq_id < cache.size) { + // the number of "used" cells is simply the biggest seq_id + if (cache.used < (uint32_t)seq_id) { + cache.used = seq_id; + } + // the "head" is the smallest seq_id + if (cache.head > (uint32_t)seq_id) { + cache.head = seq_id; + } + // Assuming the tokens are in-order + if (batch.pos[i] != cache.cells[seq_id].pos + 1) { + // What should happen when the pos backtracks? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); + return false; + } + cache.cells[seq_id].pos = batch.pos[i]; + // NOTE: seq_ids are not inserted here, because they are handled when the graph is built. + } else { + // too big seq_id + // TODO: would it be possible to resize the KV cache size instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size); + return false; + } + } + } + + cache.n = cache.used - cache.head + 1; + // sanity check (max >= min) + return cache.used >= cache.head; } + // otherwise, one cell per token. if (n_tokens > n_ctx) { LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); @@ -2026,6 +2073,13 @@ static void llama_kv_cache_seq_rm( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.unlimited) { + // can only remove whole sequences for models like Mamba + GGML_ASSERT(p0 == 0); + GGML_ASSERT((uint32_t)seq_id < cache.size); + GGML_ASSERT(cache.cells[seq_id].pos < p1); + } + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { if (seq_id < 0) { @@ -2058,6 +2112,26 @@ static void llama_kv_cache_seq_cp( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.unlimited) { + if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) { + // intent to "copy from" (does not support copy chains) + cache.cells[seq_id_dst].delta = seq_id_src; + // NOTE: a sequence can't have multiple sources, but can have multiple destinations. + // For compatibility with the other KV cache API functions, + // the seq_id(s) of a slot suggests an intent to "copy to" those id(s), + // so that when a sequence is copied, it can initially be found from the source cell. 
+ cache.cells[seq_id_src].seq_id.insert(seq_id_dst); + // prevent the destination from getting cleared + cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); + // repurposed as a "need copy" flag + // (shifting can't be done anyway for this kind of KV cache) + cache.has_shift = seq_id_src != seq_id_dst; + // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet) + cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + } + return; + } + cache.head = 0; for (uint32_t i = 0; i < cache.size; ++i) { @@ -2097,6 +2171,10 @@ static void llama_kv_cache_seq_shift( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.unlimited) { + GGML_ASSERT(false); // not supported + } + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.has_shift = true; @@ -2126,6 +2204,10 @@ static void llama_kv_cache_seq_div( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + if (cache.unlimited) { + GGML_ASSERT(false); // not supported + } + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.has_shift = true; @@ -4453,6 +4535,8 @@ static void llm_build_k_shift( case LLM_ROPE_GLM: rope_type = 4; break; } + GGML_ASSERT(kv.size == n_ctx); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions @@ -4484,6 +4568,8 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(kv.size == n_ctx); + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed @@ -4693,6 +4779,8 @@ static struct ggml_tensor * llm_build_kqv( cb(kq, "kq_soft_max_ext", il); } + GGML_ASSERT(kv.size == n_ctx); + // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -4837,8 +4925,8 @@ struct llm_build_context { norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (batch.n_tokens), - n_kv (worst_case ? kv_self.size : kv_self.n), - kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), + n_kv (worst_case ? kv_self.size : kv_self.n), + kv_head (worst_case ? (kv_self.unlimited ? 
0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), do_rope_shift (worst_case || kv_self.has_shift), cb (cb), @@ -6907,8 +6995,6 @@ struct llm_build_context { struct ggml_cgraph * build_mamba() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - const int32_t n_tok = batch.n_tokens; - const int64_t d_model = n_embd; const int64_t d_inner = n_head; GGML_ASSERT(2 * d_model == d_inner); @@ -6919,33 +7005,55 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; - // {n_embd, n_tok} + GGML_ASSERT(kv_self.used - kv_self.head + 1 == 1); // TODO: support more than one sequence per batch + + // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); cb(inpL, "inp_embd", -1); for (int il = 0; il < n_layer; ++il) { // (ab)using the kv cache to store the state // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner); - ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner); + ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size); + ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size); + + // do copies between states when needed (nothing to do with rope or shifts) + // TODO: maybe hide this in a function, a bit like llm_build_k_shift + if (do_rope_shift) { + conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift); + ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift); - // reset the states when starting a new sequence - // TODO: ensure kv_self clearing is handled - if (!batch.pos || batch.pos[0] == 0) { - conv_state = ggml_scale(ctx0, conv_state, 0); - ssm_state = ggml_scale(ctx0, ssm_state, 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); } + { + ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_state_mask, 1, n_kv, lctx.inp_state_mask->nb[0], 0); + // clear states of sequences which are starting at the beginning of this batch + conv_states = ggml_mul(ctx0, + ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), + state_mask); + ssm_states = ggml_mul(ctx0, + ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), + state_mask); + } + + // TODO: support more than one sequence per batch (these could then use ggml_reshape_3d) + ggml_tensor * conv_state = ggml_view_2d(ctx0, conv_states, d_conv - 1, d_inner, + (d_conv - 1)*ggml_element_size(conv_states), 0); + ggml_tensor * ssm_state = ggml_view_2d(ctx0, ssm_states, d_state, d_inner, + (d_state)*ggml_element_size(ssm_states), 0); + // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok} + // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); // split the above in two - // => {d_inner, n_tok} + // => {d_inner, n_tokens} struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 
ggml_element_size(xz)*d_inner); @@ -6955,10 +7063,10 @@ struct llm_build_context { // The following tensor is too big in order to avoid an assertion error when making an overlapping view. // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation - // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner}, + // This could then be a tensor with ne[] = {(d_conv-1)+n_tokens, d_inner}, // but the size difference is not that big (d_conv is usually 4). - struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok); - const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x); + struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tokens); + const size_t conv_x_nb1 = (d_conv - 1 + n_tokens) * ggml_element_size(conv_x); conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0); // making x contiguous is necessary because ggml_set expects it @@ -6967,18 +7075,18 @@ struct llm_build_context { // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)), - ggml_view_tensor(ctx0, kv_self.k_l[il]))); + ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tokens*ggml_element_size(conv_x)), + ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner), kv_self.head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_x)))); // prepare convolution for all tokens in the batch with a self-overlapping view, // shifting by one column each ... depth? ... with a window of d_conv columns. - // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok} - conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0); + // {(d_conv-1)+n_tokens, d_inner} => {d_conv, d_inner, n_tokens} + conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tokens, conv_x_nb1, 1*ggml_element_size(conv_x), 0); // perform convolution - // => {1, d_inner, n_tok} + // => {1, d_inner, n_tokens} x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d)); - // => {d_inner, n_tok, 1} + // => {d_inner, n_tokens, 1} x = ggml_permute(ctx0, x, 2, 0, 1, 3); // bias @@ -6989,38 +7097,38 @@ struct llm_build_context { // ssm { - // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok} + // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); // split - struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0); - struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); + struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); - // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok} + // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens} dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); // Custom operator to 
implement some of the optimizations // described in the Annex D of the Mamba paper. // TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C) - // => {d_state, d_inner, n_tok} + // => {d_state, d_inner, n_tokens} ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B); // only store last state ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]), - ggml_view_tensor(ctx0, kv_self.v_l[il]))); + ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tokens-1)*ssm_state->nb[2]), + ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner, kv_self.head*d_state*d_inner*ggml_element_size(ssm_state)))); - // {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok} + // {d_state, d_inner, n_tokens} * {d_state, n_tokens} => {d_inner, 1, n_tokens} struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3)); - // => {d_inner, n_tok} + // => {d_inner, n_tokens} y = ggml_permute(ctx0, y, 0, 2, 1, 3); - // {d_inner, n_tok} * {d_inner} => {d_inner, n_tok} + // {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens} y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); - // {d_inner, n_embd} * {d_inner, n_tok} => {n_embd, n_tok} + // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); } @@ -7108,6 +7216,9 @@ static struct ggml_cgraph * llama_build_graph( GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); float * data = (float *) lctx.inp_KQ_mask->data; + // For Transformers, use only the previous KV cells + // of the correct sequence for each token of the batch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; @@ -7124,18 +7235,61 @@ static struct ggml_cgraph * llama_build_graph( } } } + // For Mamba (and other constant-time-and-size architectures), + // update the correct state(s)/sequence(s) for each token of the batch. + // Source and destination states are both the same for the sake of implementation simplicity. + // It would be more complex if they were sometimes the same and sometimes not. 
+ // (with Transformers, source KV cells are never the destination, + // which is also simpler, but more memory hungry) } if (llm.do_rope_shift) { - const int64_t n_ctx = llm.n_ctx; + const uint32_t kv_size = lctx.kv_self.size; GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; - for (int i = 0; i < n_ctx; ++i) { + for (uint32_t i = 0; i < kv_size; ++i) { data[i] = lctx.kv_self.cells[i].delta; } } + + if (lctx.kv_self.unlimited) { + const uint32_t kv_size = lctx.kv_self.size; + const uint32_t n_kv = lctx.kv_self.n; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_state_mask->buffer)); + float * data = (float *) lctx.inp_state_mask->data; + + // states which are not affected by the current batch are left untouched + for (uint32_t i = 0; i < n_kv; ++i) { + llama_seq_id seq_id = i + lctx.kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; + bool has_self_seq = kv_cell.has_seq_id(seq_id); + + data[i] = (float) has_self_seq; + + // ensure current sequences will be kept + if (!has_self_seq) { + kv_cell.seq_id.insert(seq_id); + } + } + // remove extraneous seq_ids when state copies are made + if (llm.do_rope_shift) { + for (uint32_t i = 0; i < kv_size; ++i) { + llama_kv_cell & kv_cell = lctx.kv_self.cells[i]; + uint32_t n_seqs = kv_cell.seq_id.size(); + bool has_self_seq = kv_cell.has_seq_id(i); + + if (has_self_seq && n_seqs > 1) { + kv_cell.seq_id.clear(); + kv_cell.seq_id.insert(i); + } else if (!has_self_seq && n_seqs > 0) { + kv_cell.seq_id.clear(); + } + } + } + } } llm.init(); @@ -7309,13 +7463,15 @@ static int llama_decode_internal( return 1; } - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min((int32_t) kv_self.size, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); - //kv_self.n = llama_kv_cache_cell_max(kv_self); + if (!kv_self.unlimited) { + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + kv_self.n = std::min((int32_t) kv_self.size, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + //kv_self.n = llama_kv_cache_cell_max(kv_self); - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + } ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); @@ -7371,7 +7527,7 @@ static int llama_decode_internal( if (kv_self.has_shift) { kv_self.has_shift = false; for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].delta = 0; + kv_self.cells[i].delta = kv_self.unlimited ? 
i : 0; } } @@ -10526,6 +10682,7 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, + /*.n_parallel =*/ 1, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED, @@ -10689,6 +10846,7 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; cparams.n_batch = params.n_batch; + // TODO: maybe add n_parallel here too cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -10733,14 +10891,19 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; + uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; ggml_type type_v = params.type_v; - // Mamba (mis)uses the KV cache to store its states + // Mamba only needs a constant number of KV cache slots per sequence if (model->arch == LLM_ARCH_MAMBA) { + // Mamba needs as many slots as there are distinct sequences processed at the same time + // The extra slot allows dedicating a sequence id to the system prompt + // TODO: find a better way to get the max number of parallel sequences + kv_size = params.n_parallel + 1; // it's probably best to keep as much precision as possible for the states type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state - type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state + type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state } GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); @@ -10822,7 +10985,7 @@ struct llama_context * llama_new_context_with_model( ctx->backends.push_back(ctx->backend_cpu); if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, - cparams.n_ctx, cparams.offload_kqv)) { + kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -10856,7 +11019,7 @@ struct llama_context * llama_new_context_with_model( // graph inputs { ggml_init_params init_params = { - /* .mem_size */ ggml_tensor_overhead()*5, + /* .mem_size */ ggml_tensor_overhead()*(5 + ctx->kv_self.unlimited), /* .mem_buffer */ nullptr, /* .no_alloc */ true, }; @@ -10865,14 +11028,18 @@ struct llama_context * llama_new_context_with_model( ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); - ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); - ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); + ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, kv_size, cparams.n_batch); + ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size); + if (ctx->kv_self.unlimited) + ctx->inp_state_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size); ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); ggml_set_name(ctx->inp_pos, "inp_pos"); ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask"); ggml_set_name(ctx->inp_K_shift, "inp_K_shift"); + if (ctx->kv_self.unlimited) + ggml_set_name(ctx->inp_state_mask, 
"inp_state_mask"); ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true)); @@ -11342,12 +11509,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; - const auto & cparams = ctx->cparams; const auto n_layer = hparams.n_layer; const auto n_embd_k_gqa = hparams.n_embd_k_gqa(); const auto n_embd_v_gqa = hparams.n_embd_v_gqa(); - const auto n_ctx = cparams.n_ctx; const size_t kv_buf_size = kv_self.total_size(); const uint32_t kv_head = kv_self.head; @@ -11371,7 +11536,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat // v is not contiguous, copy row by row tmp_buf.resize(elt_size*kv_head); for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) { - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*elt_size*n_ctx, tmp_buf.size()); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*elt_size*kv_size, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); } } @@ -11453,12 +11618,10 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; - const auto & cparams = ctx->cparams; const int n_layer = hparams.n_layer; const int n_embd_k_gqa = hparams.n_embd_k_gqa(); const int n_embd_v_gqa = hparams.n_embd_v_gqa(); - const int n_ctx = cparams.n_ctx; size_t kv_buf_size; uint32_t kv_head; @@ -11483,7 +11646,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { // v is not contiguous, copy row by row size_t v_row_size = elt_size*kv_head; for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) { - ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*elt_size*n_ctx, v_row_size); + ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*elt_size*kv_size, v_row_size); inp += v_row_size; } } diff --git a/llama.h b/llama.h index cec4158bc8e803..f57e71aaf0e5cb 100644 --- a/llama.h +++ b/llama.h @@ -211,6 +211,7 @@ extern "C" { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_parallel; // number of parallel sequences uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`