From d9f75f3ccf1fda0d324ab231a4ea034303dffd0f Mon Sep 17 00:00:00 2001 From: Martin Krasser Date: Sat, 5 Aug 2023 14:05:15 +0200 Subject: [PATCH 1/3] Allow passing grammar to completion endpoint --- Makefile | 2 +- examples/server/README.md | 2 ++ examples/server/server.cpp | 49 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a692a39ea85e0..23bb3d1051aba 100644 --- a/Makefile +++ b/Makefile @@ -377,7 +377,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml. save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS) +server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) diff --git a/examples/server/README.md b/examples/server/README.md index aee31ae42e517..e56ca063a9f0e 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -151,6 +151,8 @@ node . `mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1). + `grammar`: Set grammar for grammar-based sampling (default: no grammar) + `seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed). `ignore_eos`: Ignore end of stream token and continue generating (default: false). diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c0725088f1018..19bdca40430bd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,6 +1,7 @@ #include "common.h" #include "llama.h" #include "build-info.h" +#include "grammar-parser.h" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -195,6 +196,8 @@ struct llama_server_context llama_context *ctx = nullptr; gpt_params params; + llama_grammar *grammar = nullptr; + bool truncated = false; bool stopped_eos = false; bool stopped_word = false; @@ -226,6 +229,7 @@ struct llama_server_context void rewind() { params.antiprompt.clear(); + params.grammar.clear(); num_prompt_tokens = 0; num_tokens_predicted = 0; generated_text = ""; @@ -237,6 +241,7 @@ struct llama_server_context stopped_limit = false; stopping_word = ""; multibyte_pending = 0; + grammar = nullptr; n_remain = 0; n_past = 0; @@ -257,6 +262,35 @@ struct llama_server_context return true; } + void loadGrammar() + { + if (!params.grammar.empty()) { + grammar_parser::parse_state parsed_grammar; + + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + fprintf(stderr, "%s: grammar parse error\n", __func__); + return; + } + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + + { + auto it = params.logit_bias.find(llama_token_eos()); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + fprintf(stderr, + "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + } + } + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + } + void loadPrompt() { params.prompt.insert(0, 1, ' '); // always add a first space @@ -420,6 +454,10 @@ struct llama_server_context logits[llama_token_nl()] = nl_logit; } + if (grammar != nullptr) { + llama_sample_grammar(ctx, &candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling @@ -457,10 +495,15 @@ struct llama_server_context } } + if (grammar != nullptr) { + llama_grammar_accept_token(ctx, grammar, result.tok); + } + for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) { result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(result.tok); num_tokens_predicted++; @@ -947,6 +990,7 @@ static json format_generation_settings(llama_server_context &llama) {"stream", llama.stream}, {"logit_bias", llama.params.logit_bias}, {"n_probs", llama.params.n_probs}, + {"grammar", llama.params.grammar}, }; } @@ -1048,6 +1092,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.seed = body.value("seed", default_params.seed); llama.params.prompt = body.value("prompt", default_params.prompt); + llama.params.grammar = body.value("grammar", default_params.grammar); llama.params.n_probs = body.value("n_probs", default_params.n_probs); llama.params.logit_bias.clear(); @@ -1179,6 +1224,7 @@ int main(int argc, char **argv) parse_options_completion(json::parse(req.body), llama); + llama.loadGrammar(); llama.loadPrompt(); llama.beginCompletion(); @@ -1359,6 +1405,9 @@ int main(int argc, char **argv) return 1; } + if (llama.grammar != nullptr) { + llama_grammar_free(llama.grammar); + } llama_backend_free(); return 0; From b6524985df4c10442e06d26ad0b4196426a87cc3 Mon Sep 17 00:00:00 2001 From: Martin Krasser Date: Mon, 7 Aug 2023 11:46:26 +0200 Subject: [PATCH 2/3] Include review comments --- examples/server/server.cpp | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 19bdca40430bd..51197256693aa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -262,7 +262,7 @@ struct llama_server_context return true; } - void loadGrammar() + bool loadGrammar() { if (!params.grammar.empty()) { grammar_parser::parse_state parsed_grammar; @@ -270,18 +270,15 @@ struct llama_server_context parsed_grammar = grammar_parser::parse(params.grammar.c_str()); // will be empty (default) if there are parse errors if (parsed_grammar.rules.empty()) { - fprintf(stderr, "%s: grammar parse error\n", __func__); - return; + LOG_ERROR("grammar parse error", {{"grammar", params.grammar}}); + return false; } - fprintf(stderr, "%s: grammar:\n", __func__); grammar_parser::print_grammar(stderr, parsed_grammar); - fprintf(stderr, "\n"); { auto it = params.logit_bias.find(llama_token_eos()); if (it != params.logit_bias.end() && it->second == -INFINITY) { - fprintf(stderr, - "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); } } @@ -289,6 +286,7 @@ struct llama_server_context grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + return true; } void loadPrompt() @@ -1224,7 +1222,12 @@ int main(int argc, char **argv) parse_options_completion(json::parse(req.body), llama); - llama.loadGrammar(); + if (!llama.loadGrammar()) + { + res.status = 400; + return; + } + llama.loadPrompt(); llama.beginCompletion(); @@ -1376,8 +1379,12 @@ int main(int argc, char **argv) svr.set_error_handler([](const Request &, Response &res) { - res.set_content("File Not Found", "text/plain"); - res.status = 404; }); + if (res.status == 400) { + res.set_content("Invalid request", "text/plain"); + } else { + res.set_content("File Not Found", "text/plain"); + res.status = 404; + } }); // set timeouts and change hostname and port svr.set_read_timeout(sparams.read_timeout); From 8b73356b2d340645c189736b7952ae5bb6004886 Mon Sep 17 00:00:00 2001 From: Martin Krasser Date: Tue, 8 Aug 2023 09:44:28 +0200 Subject: [PATCH 3/3] Fix trailing whitespace --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 51197256693aa..4cbe5ac2c9892 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1383,7 +1383,7 @@ int main(int argc, char **argv) res.set_content("Invalid request", "text/plain"); } else { res.set_content("File Not Found", "text/plain"); - res.status = 404; + res.status = 404; } }); // set timeouts and change hostname and port