Skip to content

Commit

Permalink
mamba : apply suggestions from code review
Browse files Browse the repository at this point in the history
* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication

It was previously done to avoid permuting when only one token is processed
at a time (like when generating text), but permuting is cheap,
and dynamically changing the compute graph is not future-proof.

* ggml : in ggml_ssm_scan, use more appropriate asserts

* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32
  • Loading branch information
compilade committed Feb 5, 2024
1 parent 64fbce0 commit 98e6328
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
13 changes: 7 additions & 6 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5901,8 +5901,8 @@ struct ggml_tensor * ggml_ssm_scan(
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
ggml_are_same_shape(x, dt);
GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
GGML_ASSERT(ggml_are_same_shape(x, dt));
GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D

{
const int64_t d_state = s->ne[0];
Expand All @@ -5919,6 +5919,7 @@ struct ggml_tensor * ggml_ssm_scan(
bool is_node = false;

if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
GGML_ASSERT(false); // TODO: implement
is_node = true;
}

Expand Down Expand Up @@ -14177,7 +14178,7 @@ static void ggml_compute_forward_ssm_scan_f32(

// first batch
{
float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
Expand All @@ -14191,14 +14192,14 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// ssm_state * dA + dB * x
dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
}
}
}

// compute state for rest of tokens, previous state comes from dest
for (int i2 = 1; i2 < n_t; ++i2) {
float * dest = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
Expand All @@ -14212,7 +14213,7 @@ static void ggml_compute_forward_ssm_scan_f32(
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// ssm_state * dA + dB * x
dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
pdst[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
}
}
}
Expand Down
15 changes: 4 additions & 11 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7012,17 +7012,10 @@ struct llm_build_context {
ggml_view_2d(ctx0, ssm_state, d_state, d_inner, ssm_state->nb[1], (n_tok-1)*ssm_state->nb[2]),
ggml_view_tensor(ctx0, kv_self.v_l[il])));

struct ggml_tensor * y;
if (n_tok == 1) {
// row-wise dot product ("dn,n->d")
// {d_state, d_inner} * {d_state, 1} => {d_inner, 1}
y = ggml_mul_mat(ctx0, ssm_state, C);
} else {
// {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
// => {d_inner, n_tok}
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
}
// {d_state, d_inner, n_tok} * {d_state, n_tok} => {d_inner, 1, n_tok}
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, ggml_permute(ctx0, C, 0, 2, 1, 3));
// => {d_inner, n_tok}
y = ggml_permute(ctx0, y, 0, 2, 1, 3);
// {d_inner, n_tok} * {d_inner} => {d_inner, n_tok}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
Expand Down

0 comments on commit 98e6328

Please sign in to comment.