Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow exporting a view of the KV cache #4180

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cinttypes>
Expand Down Expand Up @@ -1386,3 +1387,66 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}

//
// KV cache utils
//

// Dump an ASCII map of the KV cache: one character per cell, giving the number
// of sequences occupying that cell ('0'-'9', capped at 9), row_size cells per row.
// NOTE: removed review-page text that had been pasted into the middle of this
// function, and dropped `c_curr`, which was advanced but never dereferenced.
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
        view.n_cells, view.n_max_seq, view.used_cells, view.token_count);
    struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
    for (int i = 0; i < view.n_cells; i++, cs_curr += view.n_max_seq) {
        if (i % row_size == 0) {
            printf("\n%5d: ", i);
        }
        // count how many of this cell's sequence slots are in use (seq_id >= 0)
        int seq_count = 0;
        for (int j = 0; j < view.n_max_seq; j++) {
            if (cs_curr[j].seq_id >= 0) { seq_count++; }
        }
        putchar(int('0' + (std::min(9, seq_count))));
    }
    printf("\n=== Done dumping\n");
}

void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count);

std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;
struct llama_kv_cache_view_cell_sequence * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j].seq_id < 0) { continue; }
if (seqs.find(cs_curr[j].seq_id) == seqs.end()) {
seqs[cs_curr[j].seq_id] = seqs.size();
if (seqs.size() >= 10) { break; }
}
}
if (seqs.size() >= 10) { break; }
}
printf("=== Sequence legend: ");
for (const auto & it : seqs) {
printf("%zu=%d, ", it.second, it.first);
}

c_curr = view.cells;
cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j].seq_id >= 0) {
const auto & it = seqs.find(cs_curr[j].seq_id);
putchar(it != seqs.end() ? int('0' + it->second) : '+');
} else {
putchar('.');
}
}
putchar(' ');
}
printf("\n=== Done dumping\n");
}
7 changes: 7 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,10 @@ std::string get_sortable_timestamp();
void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);

//
// KV cache utils
//

// Dump the KV cache view with the number of sequences currently in each cell.
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);

// Dump the KV cache view showing individual sequences in each cell (at most 10 sequences get legend indices).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 80);
5 changes: 5 additions & 0 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ int main(int argc, char ** argv) {
int32_t n_total_gen = 0;
int32_t n_cache_miss = 0;

struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);

const auto t_main_start = ggml_time_us();

LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
Expand Down Expand Up @@ -201,6 +203,9 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");

while (true) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
KerfuffleV2 marked this conversation as resolved.
Show resolved Hide resolved

llama_batch_clear(batch);

// decode any currently ongoing sequences
Expand Down
67 changes: 67 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8805,6 +8805,73 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
}
}

struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view result = {
/*.n_cells*/ 0,
/*.n_max_seq*/ n_max_seq,
/*.token_count*/ 0,
/*.used_cells*/ llama_get_kv_cache_used_cells(ctx),
/*.cells*/ nullptr,
/*.cells_sequences*/ nullptr,
};
return result;
}

// Release the buffers owned by the view and reset the pointers so the view can
// be freed more than once safely. free(nullptr) is a no-op, so no guards needed.
void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
    free(view->cells);
    view->cells = nullptr;
    free(view->cells_sequences);
    view->cells_sequences = nullptr;
}

// Refresh the view from the current state of ctx->kv_self, (re)allocating the
// cell buffers if the cache grew. Also recomputes token_count / used_cells and
// cross-checks used_cells against the cache's own bookkeeping.
void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
        view->n_cells = int32_t(ctx->kv_self.size);
        // On realloc failure the assert aborts, so losing the old pointer is moot.
        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
        view->cells = (struct llama_kv_cache_view_cell *)p;
        p = realloc(view->cells_sequences, sizeof(struct llama_kv_cache_view_cell_sequence) * view->n_max_seq * view->n_cells);
        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
        view->cells_sequences = (struct llama_kv_cache_view_cell_sequence *)p;
    }

    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
    llama_kv_cache_view_cell * c_curr = view->cells;
    struct llama_kv_cache_view_cell_sequence * cs_curr = view->cells_sequences;
    int32_t used_cells = 0;
    int32_t token_count = 0;

    for (uint32_t i = 0; i < ctx->kv_self.size; i++, c_curr++, cs_curr += view->n_max_seq) {
        // was ctx->kv_self.cells[i] — use the kv_cells reference consistently,
        // and make the size_t -> int32_t narrowing explicit
        token_count += int32_t(kv_cells[i].seq_id.size());
        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;

        // copy up to n_max_seq sequence ids for this cell
        int seq_idx = 0;
        for (const llama_seq_id it : kv_cells[i].seq_id) {
            if (seq_idx >= view->n_max_seq) {
                break;
            }
            cs_curr[seq_idx].seq_id = it;
            seq_idx++;
        }
        if (seq_idx != 0) {
            used_cells++;
        }
        // mark the remaining slots of this cell as unused
        for (; seq_idx < view->n_max_seq; seq_idx++) {
            cs_curr[seq_idx].seq_id = -1;
        }
    }
    view->token_count = token_count;
    view->used_cells = used_cells;
    if (uint32_t(used_cells) != ctx->kv_self.used) {
        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
            __func__, ctx->kv_self.used, used_cells);
    }
}

int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
int result = 0;

Expand Down
23 changes: 23 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,29 @@ extern "C" {
// KV cache
//

// Information about a single cell in the KV cache view.
struct llama_kv_cache_view_cell {
    // The position of this cell (pos + delta of the underlying cache cell).
    llama_pos pos;
};

// One sequence slot of a cell in the KV cache view.
struct llama_kv_cache_view_cell_sequence {
    // The sequence id occupying this slot, or a negative value if the slot is unused.
    llama_seq_id seq_id;
};

// An updateable view of the KV cache, intended for debugging/visualization.
// Populate with llama_kv_cache_view_update(); release with llama_kv_cache_view_free().
// NOTE: removed review-page text that had been pasted into this struct, and
// normalized the pointer declaration style to match the rest of the header.
struct llama_kv_cache_view {
    int32_t n_cells;     // number of cells the buffers can hold
    int32_t n_max_seq;   // maximum sequences tracked per cell
    int32_t token_count; // total tokens (a cell with k sequences contributes k)
    int32_t used_cells;  // cells with at least one sequence assigned
    struct llama_kv_cache_view_cell * cells;                    // n_cells entries
    struct llama_kv_cache_view_cell_sequence * cells_sequences; // n_cells * n_max_seq entries
};

// Create an empty KV cache view that can track up to n_max_seq sequences per cell. (debugging only)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);

// Free the buffers owned by a KV cache view. (debugging only)
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);

// Update the KV cache view with the current state of the KV cache. (debugging only)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);

// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
Expand Down
Loading