Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : save and restore kv cache for single seq id #6341

Merged
merged 34 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
662aaea
llama : save and restore kv cache for single seq id
kaetemi Mar 27, 2024
5462817
remove trailing whitespace
kaetemi Mar 27, 2024
ab1c46a
respond error in case there's no space in the kv cache
kaetemi Mar 27, 2024
02a1840
add kv seq save restore to test case
kaetemi Mar 27, 2024
b8e8fac
add --slot-save-path arg to enable save restore and restrict save loc…
kaetemi Mar 27, 2024
b182f8f
Returning 0 for some cases, instead of asserting.
martindevans Mar 27, 2024
a2b48b9
cleanup error cases
kaetemi Mar 27, 2024
c4443d7
rename sequence state functions
kaetemi Mar 28, 2024
4d5356b
rename state get set functions
kaetemi Mar 28, 2024
bbcbf47
add previous function names back in with DEPRECATED notice
kaetemi Mar 29, 2024
8b5ae29
update doc
kaetemi Mar 29, 2024
a71ec3d
adjust endpoints to preferred style
kaetemi Mar 29, 2024
bf1d493
fix restoring zero cell count
kaetemi Mar 29, 2024
8ab1a17
handle seq rm return value
kaetemi Mar 29, 2024
0d22136
unused param
kaetemi Mar 29, 2024
29f18c2
keep in the size check
kaetemi Mar 29, 2024
f2e41b3
fix return types
kaetemi Mar 29, 2024
92c4681
add server test case for slot save restore
kaetemi Mar 29, 2024
60f685f
cleanup
kaetemi Mar 29, 2024
d38eef4
add cake
kaetemi Mar 30, 2024
ea717f7
cleanup style
kaetemi Mar 30, 2024
b509b8b
add special
kaetemi Mar 30, 2024
129b6ff
removing a whole sequence never fails
kaetemi Mar 30, 2024
8af7211
move sequence state file functionality from server to llama to match …
kaetemi Mar 30, 2024
3d6fa5b
catch exceptions on save as well
kaetemi Apr 1, 2024
b3f6da3
error log messages
kaetemi Apr 1, 2024
be714a0
check types for stricter restore
kaetemi Apr 1, 2024
0ccfbf2
update server doc
kaetemi Apr 1, 2024
205c44c
readme : update API changes date
ggerganov Apr 4, 2024
d9fd0d7
Merge branch 'master' into feature/save-restore-seq
kaetemi Apr 4, 2024
f2a4777
strict filename validation
kaetemi Apr 5, 2024
4a4f399
move include, reject bom as well
kaetemi Apr 5, 2024
2fbf0c3
also reject empty filename
kaetemi Apr 5, 2024
bf94e9f
reject whitespace and trailing dot
kaetemi Apr 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Apr 4] State and session file functions reorganized under `llama_state_*` https://github.com/ggerganov/llama.cpp/pull/6341
- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
Expand Down
73 changes: 72 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <unordered_set>
#include <vector>
#include <cinttypes>
#include <codecvt>

#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
Expand All @@ -27,7 +28,6 @@
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <codecvt>
#include <locale>
#include <windows.h>
#include <fcntl.h>
Expand Down Expand Up @@ -1500,6 +1500,77 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
GGML_UNREACHABLE();
}

// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
bool validate_file_name(const std::string & filename) {
if (!filename.length()) {
// Empty filename invalid
return false;
}
if (filename.length() > 255) {
// Limit at common largest possible filename on Linux filesystems
// to avoid unnecessary further validation
// (On systems with smaller limits it will be caught by the OS)
return false;
}

std::u32string filename_utf32;
try {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
filename_utf32 = converter.from_bytes(filename);

// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
// or invalid encodings were encountered. Reject such attempts
std::string filename_reencoded = converter.to_bytes(filename_utf32);
if (filename_reencoded != filename) {
return false;
}
} catch (const std::exception &) {
return false;
}

// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
for (char32_t c : filename_utf32) {
if (c <= 0x1F // Control characters (C0)
|| c == 0x7F // Control characters (DEL)
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|| c == 0xFF0E // Fullwidth Full Stop (period equivalent)
|| c == 0x2215 // Division Slash (forward slash equivalent)
|| c == 0x2216 // Set Minus (backslash equivalent)
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
return false;
}
}

// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
// Unicode and other whitespace is not affected, only 0x20 space
if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
return false;
}

// Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
if (filename.find("..") != std::string::npos) {
return false;
}

// Reject "."
if (filename == ".") {
return false;
}

return true;
}

//
// String utils
//
Expand Down
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ std::string gpt_random_prompt(std::mt19937 & rng);

void process_escapes(std::string& input);

bool validate_file_name(const std::string & filename);

//
// String utils
//
Expand Down
6 changes: 3 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
// The file exists and is not empty
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
return 1;
}
Expand Down Expand Up @@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
// optionally save the session on first sample (for faster prompt loading next time)
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
need_to_save_session = false;
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());

LOG("saved session to %s\n", path_session.c_str());
}
Expand Down Expand Up @@ -935,7 +935,7 @@ int main(int argc, char ** argv) {

if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}

llama_print_timings(ctx);
Expand Down
101 changes: 95 additions & 6 deletions examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ int main(int argc, char ** argv) {

std::string result0;
std::string result1;
std::string result2;

// init
llama_model * model;
Expand All @@ -44,8 +45,8 @@ int main(int argc, char ** argv) {

// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
const size_t written = llama_copy_state_data(ctx, state_mem.data());
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data());

FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
Expand Down Expand Up @@ -97,13 +98,13 @@ int main(int argc, char ** argv) {

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));

FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_set_state_data(ctx2, state_mem.data())) {
if (read != llama_state_set_data(ctx2, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
Expand Down Expand Up @@ -141,16 +142,104 @@ int main(int argc, char ** argv) {
n_past += 1;
}

printf("\n");
printf("\n\n");

llama_free(ctx2);
llama_free_model(model);

if (result0 != result1) {
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
return 1;
}

// make new context
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));

printf("\nsingle seq run: %s", params.prompt.c_str());

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));

FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx3, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx3);
llama_free_model(model);
return 1;
}

fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
}

// restore state (last tokens)
n_past = n_past_saved;

// save seq 0 and load into seq 1
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
llama_free_model(model);
return 1;
}
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);

// erase whole kv
llama_kv_cache_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__);

// restore kv into seq 1
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);
llama_free_model(model);
return 1;
}
fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
}

// third run with seq 1 instead of 0
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx3);
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
auto next_token = llama_sample_token(ctx3, &candidates_p);
auto next_token_str = llama_token_to_piece(ctx3, next_token);

printf("%s", next_token_str.c_str());
result2 += next_token_str;

if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx3);
llama_free_model(model);
return 1;
}
n_past += 1;
}

printf("\n");

llama_free(ctx3);
llama_free_model(model);

if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
return 1;
}

fprintf(stderr, "\n%s : success\n", __func__);

return 0;
Expand Down
52 changes: 52 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/
- `-n N, --n-predict N`: Set the maximum tokens to predict. Default: `-1`
- `--slots-endpoint-disable`: To disable slots state monitoring endpoint. Slots state may contain user data, prompts included.
- `--metrics`: enable prometheus `/metrics` compatible endpoint. Default: disabled
- `--slot-save-path PATH`: Specifies the path where the state of slots (the prompt cache) can be stored. If not provided, the slot management endpoints will be disabled.
- `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template)
- `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled
- `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json`
Expand Down Expand Up @@ -517,6 +518,57 @@ Available metrics:
- `llamacpp:requests_processing`: Number of requests processing.
- `llamacpp:requests_deferred`: Number of requests deferred.

- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.

*Options:*

`filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.

### Result JSON

```json
{
"id_slot": 0,
"filename": "slot_save_file.bin",
"n_saved": 1745,
"n_written": 14309796,
"timings": {
"save_ms": 49.865
}
}
```

- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.

*Options:*

`filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.

### Result JSON

```json
{
"id_slot": 0,
"filename": "slot_save_file.bin",
"n_restored": 1745,
"n_read": 14309796,
"timings": {
"restore_ms": 42.937
}
}
```

- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.

### Result JSON

```json
{
"id_slot": 0,
"n_erased": 1745
}
```

## More examples

### Change system prompt on runtime
Expand Down
Loading
Loading