[refactor] add chat prompts.
wangzhaode committed Jul 4, 2024
1 parent 289fa44 commit 9f0a3ce
Showing 2 changed files with 63 additions and 18 deletions.
11 changes: 9 additions & 2 deletions include/llm.hpp
@@ -214,6 +214,10 @@ class LlmConfig {
         return llm_config_.value("attention_mask", "int");
     }
 
+    std::string chat_template() const {
+        return llm_config_.value("chat_template", "");
+    }
+
     std::string prompt_template() const {
         return llm_config_.value("prompt_template", "");
     }
@@ -222,6 +226,7 @@
 
 class Llm {
 public:
+    using PromptItem = std::pair<std::string, std::string>; // <role, content>
     Llm(std::shared_ptr<LlmConfig> config) : config_(config) {}
     virtual ~Llm() {
         modules_.clear();
@@ -232,8 +237,10 @@ class Llm {
     virtual void load();
     VARP forward(const std::vector<int>& input_ids);
     int sample(VARP logits, const std::vector<int>& pre_ids);
-    std::string apply_chat_template(const std::string& input_str) const;
-    std::string response(const std::string& input_str, std::ostream* os = &std::cout, const char* end_with = nullptr);
+    std::string apply_prompt_template(const std::string& user_content) const;
+    std::string apply_chat_template(const std::vector<PromptItem>& chat_prompts) const;
+    std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr);
+    std::string response(const std::vector<PromptItem>& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr);
     void generate_init();
     std::string generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with);
     std::vector<int> generate(const std::vector<int>& input_ids, int max_new_tokens = -1);
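The header change introduces `PromptItem` as a role/content pair and splits templating into `apply_prompt_template` (a single query) and `apply_chat_template` (a full history), with a matching `response` overload. A minimal sketch of the intended call pattern, assuming an already-constructed and loaded `Llm` instance named `llm` (model setup elided):

```cpp
// Sketch only; obtaining `llm` from a LlmConfig is not shown here.
std::vector<Llm::PromptItem> history;
history.emplace_back("system", "You are a helpful assistant.");
history.emplace_back("user", "What is MNN?");
std::string answer = llm.response(history);   // multi-turn overload, uses chat_template
history.emplace_back("assistant", answer);    // carry context into the next turn
std::string once = llm.response("Say hi.");   // single-turn overload, uses prompt_template
```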
70 changes: 54 additions & 16 deletions src/llm.cpp
@@ -193,31 +193,60 @@ int Llm::sample(VARP logits, const std::vector<int>& pre_ids) {
     return token_id;
 }
 
-std::string Llm::apply_chat_template(const std::string& input_str) const {
-    auto prompt = config_->prompt_template();
-    if (prompt.empty()) return input_str;
+static std::string apply_template(std::string prompt_template, const std::string& content, const std::string& role = "") {
+    if (prompt_template.empty()) return content;
+    if (!role.empty()) {
+        const std::string placeholder = "%r";
+        size_t start_pos = prompt_template.find(placeholder);
+        if (start_pos == std::string::npos) return content;
+        prompt_template.replace(start_pos, placeholder.length(), role);
+    }
     const std::string placeholder = "%s";
-    size_t start_pos = prompt.find(placeholder);
-    if (start_pos == std::string::npos) return input_str;
-    prompt.replace(start_pos, placeholder.length(), input_str);
-    return prompt;
+    size_t start_pos = prompt_template.find(placeholder);
+    if (start_pos == std::string::npos) return content;
+    prompt_template.replace(start_pos, placeholder.length(), content);
+    return prompt_template;
 }
 
+std::string Llm::apply_prompt_template(const std::string& user_content) const {
+    auto chat_prompt = config_->prompt_template();
+    return apply_template(chat_prompt, user_content);
+}
+
+std::string Llm::apply_chat_template(const std::vector<PromptItem>& chat_prompts) const {
+    auto chat_template = config_->chat_template();
+    std::string prompt_result;
+    auto iter = chat_prompts.begin();
+    for (; iter != chat_prompts.end() - 1; ++iter) {
+        prompt_result += apply_template(chat_template, iter->second, iter->first);
+    }
+    if (iter->first == "user") {
+        prompt_result += apply_prompt_template(iter->second);
+    } else {
+        prompt_result += apply_template(chat_template, iter->second, iter->first);
+    }
+    return prompt_result;
+}
+
 void Llm::chat() {
+    std::vector<PromptItem> history;
+    history.push_back(std::make_pair("system", "You are a helpful assistant."));
     while (true) {
         std::cout << "\nQ: ";
-        std::string input_str;
-        std::cin >> input_str;
-        if (input_str == "/exit") {
+        std::string user_str;
+        std::cin >> user_str;
+        if (user_str == "/exit") {
             break;
         }
-        if (input_str == "/reset") {
-            // reset();
+        if (user_str == "/reset") {
+            history.resize(1);
             std::cout << "\nA: reset done." << std::endl;
             continue;
         }
         std::cout << "\nA: " << std::flush;
-        response(input_str);
+        history.emplace_back(std::make_pair("user", user_str));
+        auto assistant_str = response(history);
+        history.emplace_back(std::make_pair("assistant", assistant_str));
         std::cout << std::endl;
     }
 }
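To make the new helper concrete: `apply_template` substitutes the first `%r` in the template with the role and the first `%s` with the content, and `apply_chat_template` applies it to every turn except a trailing user turn, which is routed through `apply_prompt_template` instead. A self-contained sketch with a hypothetical ChatML-style template (the real string comes from the `chat_template` config key added above):

```cpp
#include <iostream>
#include <string>

// Condensed copy of the diff's substitution helper, for illustration only.
static std::string apply_template(std::string prompt_template, const std::string& content,
                                  const std::string& role = "") {
    if (prompt_template.empty()) return content;
    if (!role.empty()) {
        size_t start_pos = prompt_template.find("%r");   // first %r -> role
        if (start_pos == std::string::npos) return content;
        prompt_template.replace(start_pos, 2, role);
    }
    size_t start_pos = prompt_template.find("%s");       // first %s -> content
    if (start_pos == std::string::npos) return content;
    prompt_template.replace(start_pos, 2, content);
    return prompt_template;
}

int main() {
    // Hypothetical ChatML-style value; real models ship their own template string.
    const std::string chat_template = "<|im_start|>%r\n%s<|im_end|>\n";
    std::cout << apply_template(chat_template, "You are a helpful assistant.", "system")
              << apply_template(chat_template, "Hello!", "user");
    // Prints:
    // <|im_start|>system
    // You are a helpful assistant.<|im_end|>
    // <|im_start|>user
    // Hello!<|im_end|>
}
```

Note that `apply_chat_template` dereferences `chat_prompts.end() - 1`, so it must not be called with an empty history; the history-based `response` overload below guards against this before calling it.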
@@ -292,7 +321,7 @@ std::string Llm::generate(const std::vector<int>& input_ids, std::ostream* os, const char* end_with) {
 }
 
 std::vector<int> Llm::tokenizer(const std::string& query) {
-    auto prompt = apply_chat_template(query);
+    auto prompt = apply_prompt_template(query);
    auto input_ids = tokenizer_->encode(prompt);
     return input_ids;
 }
@@ -304,6 +333,15 @@ std::string Llm::response(const std::string& query, std::ostream* os, const char* end_with) {
     return generate(input_ids, os, end_with);
 }
 
+std::string Llm::response(const std::vector<PromptItem>& chat_prompts, std::ostream* os, const char* end_with) {
+    if (chat_prompts.empty()) { return ""; }
+    generate_init();
+    if (!end_with) { end_with = "\n"; }
+    auto prompt = apply_chat_template(chat_prompts);
+    auto input_ids = tokenizer_->encode(prompt);
+    return generate(input_ids, os, end_with);
+}
+
 void Llm::print_speed() {
     auto prefill_s = prefill_us_ * 1e-6;
     auto decode_s = decode_us_ * 1e-6;
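Putting the pieces together: with the hypothetical `chat_template` from the sketch above plus a `prompt_template` of `<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n` (again an assumed, ChatML-style value), a short history would be rendered like this before tokenization:

```cpp
// Hypothetical templates; real values come from the model's config:
//   chat_template   = "<|im_start|>%r\n%s<|im_end|>\n"
//   prompt_template = "<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n"
std::vector<Llm::PromptItem> history = {
    {"system", "You are a helpful assistant."},
    {"user",   "Hello!"},
};
// llm.response(history) templates the history via apply_chat_template:
//   <|im_start|>system
//   You are a helpful assistant.<|im_end|>
//   <|im_start|>user
//   Hello!<|im_end|>
//   <|im_start|>assistant
// ...then encodes the prompt and streams generation to `os` (default std::cout).
```

Because the trailing user turn goes through `apply_prompt_template`, the rendered prompt ends with the assistant prefix, ready for generation to continue.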
@@ -475,7 +513,7 @@ std::vector<int> Lvlm::url_encode(const std::string& url) {
 }
 
 std::vector<int> Lvlm::tokenizer(const std::string& query) {
-    auto prompt = apply_chat_template(query);
+    auto prompt = apply_prompt_template(query);
     // split query
     std::regex img_regex("<img>(.*?)</img>");
     std::string::const_iterator searchStart(prompt.cbegin());
@@ -619,7 +657,7 @@ std::vector<int> Embedding::tokenizer(const std::string& query) {
     if (query.size() <= 256) {
         prompt = "为这个句子生成表示以用于检索相关文章:" + query;
     }
-    prompt = apply_chat_template(prompt);
+    prompt = apply_prompt_template(prompt);
     auto ids = tokenizer_->encode(prompt);
     return ids;
 }
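For reference, the Chinese instruction literal above translates roughly to "Generate a representation of this sentence for retrieving related articles:", the standard BGE-style retrieval prefix; queries longer than 256 bytes are embedded without it.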
