Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batched : add bench tool #3545

Merged
merged 7 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ models-mnt
/server
/simple
/batched
/batched-bench
/export-lora
/finetune
/speculative
Expand Down
13 changes: 11 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml simple batched save-load-state server embd-input-test gguf llama-bench baby-llama beam-search speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench baby-llama beam-search \
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe
TEST_TARGETS = \
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe

# Code coverage output files
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
Expand Down Expand Up @@ -557,6 +563,9 @@ simple: examples/simple/simple.cpp build-info.h ggml.
batched: examples/batched/batched.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

batched-bench: examples/batched-bench/batched-bench.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

Expand Down
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ else()
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple)
add_subdirectory(batched)
add_subdirectory(batched-bench)
add_subdirectory(speculative)
add_subdirectory(parallel)
add_subdirectory(embd-input)
Expand Down
5 changes: 5 additions & 0 deletions examples/batched-bench/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set(TARGET batched-bench)
add_executable(${TARGET} batched-bench.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
51 changes: 51 additions & 0 deletions examples/batched-bench/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# llama.cpp/example/batched-bench

Benchmark the batched decoding performance of `llama.cpp`

## Usage

There are 2 modes of operation:

- `prompt not shared` - each batch has a separate prompt of size `PP` (i.e. `N_KV = B*(PP + TG)`)
- `prompt is shared` - there is a common prompt of size `PP` used by all batches (i.e. `N_KV = PP + B*TG`)

```bash
./batched-bench MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>

# LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared
./batched-bench ./models/llama-7b/ggml-model-f16.gguf 16384 0 99

# LLaMA 7B, Q8_0, N_KV_MAX = 16384 (8GB), prompt is shared
./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 16384 1 99

# custom set of batches
./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32
```

## Sample results

- `PP` - prompt tokens per batch
- `TG` - generated tokens per batch
- `B` - number of batches
- `N_KV` - required KV cache size
- `T_PP` - prompt processing time (i.e. time to first token)
- `S_PP` - prompt processing speed (`(B*PP)/T_PP` or `PP/T_PP`)
- `T_TG` - time to generate all batches
- `S_TG` - text generation speed (`(B*TG)/T_TG`)
- `T` - total time
- `S` - total speed (i.e. all tokens / total time)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|-------|--------|------|--------|----------|----------|----------|----------|----------|----------|
| 128 | 128 | 1 | 256 | 0.108 | 1186.64 | 3.079 | 41.57 | 3.187 | 80.32 |
| 128 | 128 | 2 | 512 | 0.198 | 1295.19 | 5.029 | 50.90 | 5.227 | 97.95 |
| 128 | 128 | 4 | 1024 | 0.373 | 1373.96 | 6.878 | 74.44 | 7.251 | 141.23 |
| 128 | 128 | 8 | 2048 | 0.751 | 1363.27 | 7.344 | 139.43 | 8.095 | 252.99 |
| 128 | 128 | 16 | 4096 | 1.570 | 1304.68 | 8.455 | 242.23 | 10.024 | 408.60 |
| 128 | 128 | 32 | 8192 | 3.408 | 1201.73 | 8.801 | 465.40 | 12.209 | 670.96 |
| 128 | 256 | 1 | 384 | 0.107 | 1196.70 | 6.329 | 40.45 | 6.436 | 59.67 |
| 128 | 256 | 2 | 768 | 0.194 | 1317.45 | 10.239 | 50.00 | 10.433 | 73.61 |
| 128 | 256 | 4 | 1536 | 0.366 | 1399.03 | 13.960 | 73.35 | 14.326 | 107.22 |
| 128 | 256 | 8 | 3072 | 0.751 | 1363.92 | 15.110 | 135.54 | 15.861 | 193.69 |
| 128 | 256 | 16 | 6144 | 1.569 | 1304.93 | 18.073 | 226.64 | 19.642 | 312.80 |
| 128 | 256 | 32 | 12288 | 3.409 | 1201.35 | 19.223 | 426.15 | 22.633 | 542.93 |
251 changes: 251 additions & 0 deletions examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

// mutates the input string
static std::vector<int> parse_list(char * p) {
std::vector<int> ret;

char * q = p;

while (*p) {
if (*p == ',') {
*p = '\0';
ret.push_back(std::atoi(q));
q = p + 1;
}

++p;
}

ret.push_back(std::atoi(q));

return ret;
}

int main(int argc, char ** argv) {
gpt_params params;

if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>\n" , argv[0]);
printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
printf(" example: %s ggml-model-f16.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
return 1 ;
}

int n_kv_max = 2048;
int is_pp_shared = 0;
int n_gpu_layers = 0;
int mmq = 0;

std::vector<int> n_pp = { 128, 256, 512, 1024, 2048, 3584, 7680, };
std::vector<int> n_tg = { 128, 256, };
std::vector<int> n_pl = { 1, 2, 4, 8, 16, 32, };
//std::vector<int> n_pl = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, };

if (argc >= 2) {
params.model = argv[1];
}

if (argc >= 3) {
n_kv_max = std::atoi(argv[2]);
}

if (argc >= 4) {
is_pp_shared = std::atoi(argv[3]);
}

if (argc >= 5) {
n_gpu_layers = std::atoi(argv[4]);
}

if (argc >= 6) {
mmq = std::atoi(argv[5]);
}

if (argc >= 7) {
n_pp = parse_list(argv[6]);
}

if (argc >= 8) {
n_tg = parse_list(argv[7]);
}

if (argc >= 9) {
n_pl = parse_list(argv[8]);
}

// init LLM

llama_backend_init(params.numa);

// initialize the model

llama_model_params model_params = llama_model_default_params();

model_params.n_gpu_layers = n_gpu_layers;

llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}

llama_context_params ctx_params = llama_context_default_params();

ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = 512;
ctx_params.mul_mat_q = mmq;

ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}

llama_batch batch = llama_batch_init(n_kv_max, 0);

// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
}

return true;
};

// warm up
{
batch.n_tokens = 16;

for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
}

LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");

for ( int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
for ( int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
const int pp = n_pp[i_pp];
const int tg = n_tg[i_tg];
const int pl = n_pl[i_pl];

const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);

if (n_ctx_req > n_kv_max) {
continue;
}

batch.n_tokens = is_pp_shared ? pp : pl*pp;

for (int i = 0; i < batch.n_tokens; ++i) {
batch.token[i] = 0;
batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
}
batch.logits[batch.n_tokens - 1] = true;

const auto t_pp_start = ggml_time_us();

llama_kv_cache_tokens_rm(ctx, -1, -1);

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}

if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
}
}

const auto t_pp_end = ggml_time_us();

const auto t_tg_start = ggml_time_us();

for (int i = 0; i < tg; ++i) {
batch.n_tokens = pl;

for (int j = 0; j < pl; ++j) {
batch.token[j] = 0;
batch.pos[j] = pp + i;
batch.seq_id[j] = j;
batch.logits[j] = true;
}

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
}

const auto t_tg_end = ggml_time_us();

const int32_t n_kv = n_ctx_req;

const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
const float t = t_pp + t_tg;

const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
const float speed_tg = pl*tg / t_tg;
const float speed = n_kv / t;

LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
}
}
}

llama_print_timings(ctx);

llama_batch_free(batch);

llama_free(ctx);
llama_free_model(model);

llama_backend_free();

fprintf(stderr, "\n\n");

return 0;
}
2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

llama_context * ctx = llama_new_context_with_model(model, ctx_params);
Expand Down
Loading