From 9f0a3ce3b573feb122666a01be6da00c2cc3e6e0 Mon Sep 17 00:00:00 2001
From: yanxing
Date: Thu, 4 Jul 2024 17:54:44 +0800
Subject: [PATCH] [refactor] add chat prompts.

---
 include/llm.hpp | 11 ++++++--
 src/llm.cpp     | 70 ++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 63 insertions(+), 18 deletions(-)

diff --git a/include/llm.hpp b/include/llm.hpp
index a57ae49e..393b6f44 100644
--- a/include/llm.hpp
+++ b/include/llm.hpp
@@ -214,6 +214,10 @@ class LlmConfig {
         return llm_config_.value("attention_mask", "int");
     }
 
+    std::string chat_template() const {
+        return llm_config_.value("chat_template", "");
+    }
+
     std::string prompt_template() const {
         return llm_config_.value("prompt_template", "");
     }
@@ -222,6 +226,7 @@
 
 class Llm {
 public:
+    using PromptItem = std::pair<std::string, std::string>; // <role, content>
     Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
     virtual ~Llm() {
         modules_.clear();
@@ -232,8 +237,10 @@ class Llm {
     virtual void load();
     VARP forward(const std::vector<int>& input_ids);
     int sample(VARP logits, const std::vector<int>& pre_ids);
-    std::string apply_chat_template(const std::string& input_str) const;
-    std::string response(const std::string& input_str, std::ostream* os = &std::cout, const char* end_with = nullptr);
+    std::string apply_prompt_template(const std::string& user_content) const;
+    std::string apply_chat_template(const std::vector<PromptItem>& chat_prompts) const;
+    std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr);
+    std::string response(const std::vector<PromptItem>& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr);
     void generate_init();
     std::string generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
     std::vector<int> generate(const std::vector<int>& input_ids, int max_new_tokens = -1);
diff --git a/src/llm.cpp b/src/llm.cpp
index ae319220..e59d55f2 100644
--- a/src/llm.cpp
+++ b/src/llm.cpp
@@ -193,31 +193,60 @@ int Llm::sample(VARP logits, const std::vector<int>& pre_ids) {
     return token_id;
 }
 
-std::string Llm::apply_chat_template(const std::string& input_str) const {
-    auto prompt = config_->prompt_template();
-    if (prompt.empty()) return input_str;
+static std::string apply_template(std::string prompt_template, const std::string& content, const std::string& role = "") {
+    if (prompt_template.empty()) return content;
+    if (!role.empty()) {
+        const std::string placeholder = "%r";
+        size_t start_pos = prompt_template.find(placeholder);
+        if (start_pos == std::string::npos) return content;
+        prompt_template.replace(start_pos, placeholder.length(), role);
+    }
     const std::string placeholder = "%s";
-    size_t start_pos = prompt.find(placeholder);
-    if (start_pos == std::string::npos) return input_str;
-    prompt.replace(start_pos, placeholder.length(), input_str);
-    return prompt;
+    size_t start_pos = prompt_template.find(placeholder);
+    if (start_pos == std::string::npos) return content;
+    prompt_template.replace(start_pos, placeholder.length(), content);
+    return prompt_template;
+}
+
+std::string Llm::apply_prompt_template(const std::string& user_content) const {
+    auto chat_prompt = config_->prompt_template();
+    return apply_template(chat_prompt, user_content);
+}
+
+std::string Llm::apply_chat_template(const std::vector<PromptItem>& chat_prompts) const {
+    auto chat_template = config_->chat_template();
+    std::string prompt_result;
+    auto iter = chat_prompts.begin();
+    for (; iter != chat_prompts.end() - 1; ++iter) {
+        prompt_result += apply_template(chat_template, iter->second, iter->first);
+    }
+    if (iter->first == "user") {
+        prompt_result += apply_prompt_template(iter->second);
+    } else {
+        prompt_result += apply_template(chat_template, iter->second, iter->first);
+    }
+    return prompt_result;
 }
 
 void Llm::chat() {
+    std::vector<PromptItem> history;
+    history.push_back(std::make_pair("system", "You are a helpful assistant."));
     while (true) {
         std::cout << "\nQ: ";
-        std::string input_str;
-        std::cin >> input_str;
-        if (input_str == "/exit") {
+        std::string user_str;
+        std::cin >> user_str;
+        if (user_str == "/exit") {
             break;
         }
-        if (input_str == "/reset") {
-            // reset();
+        if (user_str == "/reset") {
+            history.resize(1);
             std::cout << "\nA: reset done." << std::endl;
             continue;
         }
         std::cout << "\nA: " << std::flush;
-        response(input_str);
+        history.emplace_back(std::make_pair("user", user_str));
+        auto assistant_str = response(history);
+        history.emplace_back(std::make_pair("assistant", assistant_str));
         std::cout << std::endl;
     }
 }
@@ -292,7 +321,7 @@ std::string Llm::generate(const std::vector<int>& input_ids, std::ostream* os, c
 }
 
 std::vector<int> Llm::tokenizer(const std::string& query) {
-    auto prompt = apply_chat_template(query);
+    auto prompt = apply_prompt_template(query);
     auto input_ids = tokenizer_->encode(prompt);
     return input_ids;
 }
@@ -304,6 +333,15 @@ std::string Llm::response(const std::string& query, std::ostream* os, const char
     return generate(input_ids, os, end_with);
 }
 
+std::string Llm::response(const std::vector<PromptItem>& chat_prompts, std::ostream* os, const char* end_with) {
+    if (chat_prompts.empty()) { return ""; }
+    generate_init();
+    if (!end_with) { end_with = "\n"; }
+    auto prompt = apply_chat_template(chat_prompts);
+    auto input_ids = tokenizer_->encode(prompt);
+    return generate(input_ids, os, end_with);
+}
+
 void Llm::print_speed() {
     auto prefill_s = prefill_us_ * 1e-6;
     auto decode_s = decode_us_ * 1e-6;
@@ -475,7 +513,7 @@ std::vector<int> Lvlm::url_encode(const std::string& url) {
 }
 
 std::vector<int> Lvlm::tokenizer(const std::string& query) {
-    auto prompt = apply_chat_template(query);
+    auto prompt = apply_prompt_template(query);
     // split query
     std::regex img_regex("<img>(.*?)</img>");
     std::string::const_iterator searchStart(prompt.cbegin());
@@ -619,7 +657,7 @@ std::vector<int> Embedding::tokenizer(const std::string& query) {
     if (query.size() <= 256) {
         prompt = "为这个句子生成表示以用于检索相关文章:" + query;
     }
-    prompt = apply_chat_template(prompt);
+    prompt = apply_prompt_template(prompt);
     auto ids = tokenizer_->encode(prompt);
     return ids;
 }