Commit 325e5ef

update test-backend-ops

slaren committed Mar 29, 2024
1 parent 93db37e · commit 325e5ef

Showing 1 changed file with 2 additions and 96 deletions.

tests/test-backend-ops.cpp
@@ -979,17 +979,13 @@ struct test_mul_mat_id : public test_case {

     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        std::vector<ggml_tensor *> mats;
-        for (int i = 0; i < n_mats; i++) {
-            ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
-            mats.push_back(a);
-        }
+        ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
         if (v) {
             ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
         }
         ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
-        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b);
         return out;
     }

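The hunk above tracks an interface change in ggml_mul_mat_id: the expert matrices are now passed as a single 3-D tensor, with one expert per slice along dimension 2, instead of an array of 2-D tensors plus an explicit count. A minimal standalone sketch of the new calling convention; the shapes and F32 types are illustrative, not taken from the test:

#include "ggml.h"

int main() {
    // small scratch context for the sketch; the size is arbitrary
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    const int k = 32, m = 16, n = 4, n_mats = 8;

    // all experts stacked in one 3-D tensor: ne = (k, m, n_mats)
    struct ggml_tensor * mats = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, m, n_mats);
    // one row of candidate expert ids per token
    struct ggml_tensor * ids  = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
    struct ggml_tensor * b    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, n);

    // new signature: the stacked tensor replaces the mats.data()/n_mats pair
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, /*id =*/ 0, b);
    (void) out; // (m, n) result per token; graph building/evaluation omitted here

    ggml_free(ctx);
    return 0;
}
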
@@ -1477,91 +1473,6 @@ struct test_leaky_relu : public test_case {
     }
 };
 
-// Mixtral MOE
-struct test_moe : public test_case {
-    const int n_experts;
-    const int n_experts_per_tok;
-    const int n_tokens;
-    const int n_embd;
-    const int n_ff;
-
-    std::string op_desc(ggml_tensor * t) override {
-        return "MOE";
-
-        GGML_UNUSED(t);
-    }
-
-    std::string vars() override {
-        return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
-    }
-
-    test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
-        : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
-    }
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
-
-        std::vector<ggml_tensor *> ffn_up_exp(n_experts);
-        std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
-        std::vector<ggml_tensor *> ffn_down_exp(n_experts);
-
-        for (int i = 0; i < n_experts; ++i) {
-            ffn_up_exp[i]   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-            ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-            ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
-        }
-
-        ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
-
-        ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
-        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
-
-        // select experts
-        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
-
-        ggml_tensor * weights = ggml_get_rows(ctx,
-                ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
-
-        weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
-
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
-
-        weights = ggml_div(ctx, weights, weights_sum);
-
-        // compute expert outputs
-        ggml_tensor * moe_out = nullptr;
-
-        for (int i = 0; i < n_experts_per_tok; ++i) {
-            ggml_tensor * cur_expert;
-
-            ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
-
-            ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
-
-            cur_gate = ggml_silu(ctx, cur_gate);
-
-            cur_expert = ggml_mul(ctx, cur_up, cur_gate);
-
-            cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
-
-            cur_expert = ggml_mul(ctx, cur_expert,
-                    ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
-
-            if (i == 0) {
-                moe_out = cur_expert;
-            } else {
-                moe_out = ggml_add(ctx, moe_out, cur_expert);
-            }
-        }
-
-        cur = moe_out;
-
-        return cur;
-    }
-};
-
-
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
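
The removed test_moe graph implements Mixtral-style routing: gating logits per token, a scaled softmax (ggml_soft_max_ext with a 1/sqrt(n_embd) scale), top-k expert selection, renormalization of the selected weights, and a weighted sum of SiLU-gated expert FFN outputs. As a reading aid, here is a standalone sketch of the routing math for a single token in plain C++, with made-up logits and the softmax scale omitted:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const int n_experts = 8, n_experts_per_tok = 2;
    std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.5f, 1.5f, 0.0f, -0.5f, 0.3f};

    // softmax over the expert logits (the test additionally scales them by 1/sqrt(n_embd))
    float max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(n_experts);
    float sum = 0.0f;
    for (int i = 0; i < n_experts; i++) { probs[i] = std::exp(logits[i] - max_l); sum += probs[i]; }
    for (float & p : probs) { p /= sum; }

    // top-k expert indices, highest probability first (ggml_top_k in the graph)
    std::vector<int> idx(n_experts);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + n_experts_per_tok, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });

    // renormalize the selected weights so they sum to 1
    // (ggml_sum_rows + ggml_div in the graph)
    float wsum = 0.0f;
    for (int i = 0; i < n_experts_per_tok; i++) { wsum += probs[idx[i]]; }
    for (int i = 0; i < n_experts_per_tok; i++) {
        std::printf("expert %d weight %.4f\n", idx[i], probs[idx[i]] / wsum);
    }
    return 0;
}

Each selected expert then computes silu(gate_i(x)) * up_i(x), projects the result through down_i, and the expert outputs are combined with these weights, which is what the per-expert ggml_mul_mat_id loop in the removed code expresses.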
@@ -2182,11 +2093,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
-#if !defined(__SANITIZE_THREAD__)
-    // FIXME: these tests use too much memory with thread sanitizer
-    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
-    //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
-#endif
     test_cases.emplace_back(new test_llama(1));
     test_cases.emplace_back(new test_llama(2));
     test_cases.emplace_back(new test_falcon(1));
