mamba : dedicate an input tensor for state copy indices
This makes it easier to adapt when/if token positions
(and by extension, inp_K_shift) are no longer integers.
compilade committed Feb 26, 2024
1 parent 12de5c7 commit 3421d17
Showing 1 changed file with 89 additions and 25 deletions.
114 changes: 89 additions & 25 deletions llama.cpp
@@ -1743,6 +1743,7 @@ struct llama_layer {
 struct llama_kv_cell {
     llama_pos pos   = -1;
     llama_pos delta = 0;
+    int32_t   src   = 0; // used by recurrent state models to copy states

     std::set<llama_seq_id> seq_id;

@@ -1763,6 +1764,7 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    bool do_copy   = false;
     // with Mamba, a cell can hold the state for more than one past token
     bool unlimited = false;

@@ -2001,7 +2003,8 @@ struct llama_context {
     struct ggml_tensor * inp_K_shift;   // I32 [kv_size]
     struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;       // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask;    // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy;    // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;    // F32 [kv_size]
     struct ggml_tensor * inp_s_seq;     // I32 [kv_size, n_batch]

 #ifdef GGML_USE_MPI
@@ -2043,9 +2046,9 @@ static bool llama_kv_cache_init(

     if (cache.unlimited) {
         for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
         }
-    } // else, delta is already initialized to zero
+    }

 #ifdef GGML_USE_CLBLAST
     offload = false;
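A note on the loop above: initializing every cell's src to its own index makes the copy described by these indices a no-op, so nothing actually moves until a sequence copy rewrites some of them (and the cache update later restores the identity mapping). A tiny self-contained illustration of that invariant, using plain arrays rather than the KV cache (not code from this commit):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<int32_t> src = {0, 1, 2, 3};   // identity mapping, as in llama_kv_cache_init
        std::vector<int>   state = {10, 11, 12, 13};

        // the gather the copy indices describe: state_new[i] = state_old[src[i]]
        std::vector<int> copied(state.size());
        for (size_t i = 0; i < state.size(); ++i) {
            copied[i] = state[src[i]];
        }

        assert(copied == state); // identity src leaves every state unchanged
        return 0;
    }
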
@@ -2296,19 +2299,20 @@ static void llama_kv_cache_seq_cp(

     if (cache.unlimited) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
             // intent to "copy from"
             // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;

-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
             if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                 cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
             } else {
                 cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
             }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;

             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
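The chain handling above is the subtle part: because the copy reads cells[seq_id_src].src before recording it, copying from a sequence that is itself pending a copy still ends up pointing at the original state. A standalone sketch of just that rule, with simplified types (a hypothetical helper, not part of the diff):

    #include <cstdint>
    #include <vector>

    struct cell { int32_t src = 0; };

    // Record the intent "dst's state should become src's state".
    // The actual row copy is deferred to the s_copy graph; only the index moves here.
    void seq_cp(std::vector<cell> & cells, int32_t seq_id_src, int32_t seq_id_dst) {
        seq_id_src = cells[seq_id_src].src;  // take the source of the source
        cells[seq_id_dst].src = seq_id_src;  // applied later by build_s_copy()
    }

Starting from the identity mapping, seq_cp(cells, 0, 1) followed by seq_cp(cells, 1, 2) leaves both cells[1].src and cells[2].src pointing at cell 0, so the eventual gather never has to follow a chain.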
@@ -5352,6 +5356,25 @@ struct llm_build_context {
         return gf;
     }

+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states  = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -7816,16 +7839,6 @@ struct llm_build_context {
         ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], (d_conv-1)*(d_inner), kv_self.size);
         ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], (d_state)*(d_inner), kv_self.size);

-        // do copies between states when needed (nothing to do with rope or shifts)
-        // TODO: do this in a another graph, a bit like build_k_shift
-        if (kv_self.has_shift) {
-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-            ssm_states  = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
-        }
-
         // clear states of sequences which are starting at the beginning of this batch
         {
             ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
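For context, the block removed here used to perform the state copy inside the Mamba build graph by reusing inp_K_shift; after this commit the same gather and write-back live in build_s_copy() above, driven by the dedicated inp_s_copy indices. What that graph computes per layer can be modeled with plain arrays (an illustrative sketch; names such as n_cells and row_size are assumptions, not identifiers from llama.cpp):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Per-layer model of build_s_copy(): gather whole state rows by index, then
    // write the result back over the original buffer, mirroring ggml_get_rows()
    // followed by ggml_cpy() into kv_self.k_l[il] / kv_self.v_l[il].
    void apply_s_copy(std::vector<float> & states,          // n_cells rows of row_size floats
                      const std::vector<int32_t> & s_copy,  // source cell index per cell (inp_s_copy)
                      size_t n_cells, size_t row_size) {
        std::vector<float> gathered(states.size());
        for (size_t i = 0; i < n_cells; ++i) {
            std::copy_n(states.begin() + s_copy[i] * row_size, row_size,
                        gathered.begin() + i * row_size);
        }
        states = gathered; // in-place overwrite, like the ggml_cpy back into the cache
    }

Running the copy in its own graph also means it only happens when do_copy is set, instead of being built into the model graph whenever has_shift happened to be true.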
@@ -7978,6 +7991,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }

+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
@@ -8113,6 +8143,18 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }

+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -8227,17 +8269,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }

     if (kv_self.unlimited) {
         const int64_t n_kv = kv_self.n;

         {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;

             // states which are not affected by the current batch are left untouched
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id seq_id = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
-                bool has_self_seq = kv_cell.has_seq_id(seq_id);
+                llama_seq_id    seq_id       = i + lctx.kv_self.head;
+                llama_kv_cell & kv_cell      = lctx.kv_self.cells[seq_id];
+                bool            has_self_seq = kv_cell.has_seq_id(seq_id);

                 data[i] = (float) has_self_seq;

@@ -8739,7 +8781,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;

             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
             }
         }
     }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
+            }
+        }
+    }
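Seen from the API side, the copy is therefore deferred: llama_kv_cache_seq_cp() only records the intent (the src index plus do_copy), and the pending copies are applied here by the s_copy graph before every cell's src is reset to the identity mapping. A hedged usage sketch (the wrapper name is made up; llama_kv_cache_seq_cp() is the existing public call, and exactly when llama_kv_cache_update_internal() gets invoked is outside this diff):

    #include "llama.h"

    // Fork the recurrent state of sequence `src` into sequence `dst`.
    // This only marks cells[dst] as "copy from src"; the actual row copy runs
    // when the s_copy graph is computed on a later KV-cache update.
    static void fork_sequence(struct llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
        llama_kv_cache_seq_cp(ctx, src, dst, -1, -1); // negative p0/p1 = whole sequence
    }
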
@@ -12418,7 +12480,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                /* .mem_size   */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc   */ true,
             };
@@ -12433,6 +12495,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
             ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             if (ctx->kv_self.unlimited) {
+                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                 ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                 ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
             }
@@ -12446,6 +12509,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_set_name(ctx->inp_mean, "inp_mean");
             ggml_set_name(ctx->inp_cls, "inp_cls");
             if (ctx->kv_self.unlimited) {
+                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                 ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                 ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
             }
