Skip to content

Commit

Permalink
add --slot-save-path arg to enable save restore and restrict save location
Browse files Browse the repository at this point in the history
  • Loading branch information
kaetemi committed Mar 27, 2024
1 parent 02a1840 commit b8e8fac
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct server_params {

bool slots_endpoint = true;
bool metrics_endpoint = false;
std::string slot_save_path;
};

struct server_slot {
Expand Down Expand Up @@ -1628,6 +1629,7 @@ struct server_context {
const int64_t t_start = ggml_time_us();

std::string filename = task.data["filename"];
std::string filepath = task.data["filepath"];
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
Expand All @@ -1645,7 +1647,7 @@ struct server_context {
}
GGML_ASSERT(nwrite <= state_data.size());

std::ofstream outfile(filename, std::ios::binary);
std::ofstream outfile(filepath, std::ios::binary);
outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
outfile.close();

Expand Down Expand Up @@ -1678,8 +1680,9 @@ struct server_context {

const int64_t t_start = ggml_time_us();

std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params?
std::ifstream infile(filename, std::ios::binary);
std::string filename = task.data["filename"];
std::string filepath = task.data["filepath"];
std::ifstream infile(filepath, std::ios::binary);
if (!infile.is_open()) {
send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
break;
Expand Down Expand Up @@ -2392,6 +2395,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --log-disable disables logging to a file.\n");
printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
printf("\n");
printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
printf(" --override-kv KEY=TYPE:VALUE\n");
Expand Down Expand Up @@ -2798,6 +2802,16 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
sparams.slots_endpoint = false;
} else if (arg == "--metrics") {
sparams.metrics_endpoint = true;
} else if (arg == "--slot-save-path") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.slot_save_path = argv[i];
// make sure the path ends with DIRECTORY_SEPARATOR so a filename can be appended to it directly
if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
sparams.slot_save_path += DIRECTORY_SEPARATOR;
}
} else if (arg == "--chat-template") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -3300,18 +3314,24 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};

const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
const auto handle_slot_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"];
std::string filename = request_data["filename"];
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
res_error(res, "Invalid filename");
return;
}
std::string filepath = sparams.slot_save_path + filename;

server_task task;
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
task.data = {
{ "id_slot", id_slot },
{ "filename", filename },
{ "filepath", filepath }
};

const int id_task = ctx_server.queue_tasks.post(task);
Expand All @@ -3327,18 +3347,24 @@ int main(int argc, char ** argv) {
}
};

const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
const auto handle_slot_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"];
std::string filename = request_data["filename"];
if (filename.find('/') != std::string::npos || filename.find('\\') != std::string::npos || filename.find("..") != std::string::npos) {
res_error(res, "Invalid filename");
return;
}
std::string filepath = sparams.slot_save_path + filename;

server_task task;
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
task.data = {
{ "id_slot", id_slot },
{ "filename", filename },
{ "filepath", filepath }
};

const int id_task = ctx_server.queue_tasks.post(task);
Expand Down Expand Up @@ -3741,9 +3767,12 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
svr->Post("/slot/save", handle_slot_save);
svr->Post("/slot/restore", handle_slot_restore);
svr->Post("/slot/erase", handle_slot_erase);
if (!sparams.slot_save_path.empty()) {
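// register the slot save/restore/erase endpoints only when slot_save_path is set
// NOTE(review): this also gates /slot/erase - confirm that erase actually depends on the save path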
svr->Post("/slot/save", handle_slot_save);
svr->Post("/slot/restore", handle_slot_restore);
svr->Post("/slot/erase", handle_slot_erase);
}

//
// Start the server
Expand Down

0 comments on commit b8e8fac

Please sign in to comment.