Commit

context shift fixed
FSSRepo committed Oct 16, 2023
1 parent 2d9f11d commit d7eca25
Showing 1 changed file with 39 additions and 20 deletions.
59 changes: 39 additions & 20 deletions examples/server/server.cpp
@@ -757,8 +757,8 @@ struct llama_server_context
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
slot.sent_count += result.text_to_send.size();
// add the token to slot queue and cache
slot.addTokenString(result);
}
slot.addTokenString(result);
if (slot.multibyte_pending > 0)
{
slot.multibyte_pending -= token_str.size();
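
The hunk above appears to move slot.addTokenString(result) outside the enclosing block, so the streamed piece is queued regardless of that branch, while multibyte_pending keeps holding output back until a partial UTF-8 sequence is completed. A minimal sketch of that UTF-8 bookkeeping, as a hypothetical standalone helper (not code from this commit): inspect the trailing bytes and report how many continuation bytes are still owed.

    // Hypothetical helper, not part of this commit: how many bytes of an incomplete
    // trailing UTF-8 sequence are still missing from s (0 if the string ends cleanly).
    static size_t utf8_missing_bytes(const std::string & s) {
        size_t i = s.size(), cont = 0;
        while (i > 0 && (static_cast<unsigned char>(s[i - 1]) & 0xC0) == 0x80) { --i; ++cont; }
        if (i == 0) return 0;                                   // only continuation bytes; nothing to wait for
        const unsigned char lead = static_cast<unsigned char>(s[i - 1]);
        const size_t expect = (lead & 0x80) == 0x00 ? 1
                            : (lead & 0xE0) == 0xC0 ? 2
                            : (lead & 0xF0) == 0xE0 ? 3
                            : (lead & 0xF8) == 0xF0 ? 4 : 1;
        const size_t have = cont + 1;                           // lead byte plus continuations seen so far
        return expect > have ? expect - have : 0;               // e.g. a lone 0xE2 lead byte reports 2 missing
    }
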
@@ -925,8 +925,8 @@ struct llama_server_context
}

// context shift takes effect only when there is a single slot
if(slots.size() == 1) {
llama_client_slot slot = slots[0];
if(params.n_parallel == 1) {
llama_client_slot &slot = slots[0];
if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
{
// Shift context
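
The guard above now keys off params.n_parallel and takes the slot by reference, so the shift mutates the live slot instead of a copy. The shifted body is collapsed in this view; as a hedged sketch (not necessarily this fork's exact arithmetic), a context shift in llama.cpp-style servers keeps the first n_keep positions, drops a block after them, and slides the rest of the KV cache left:

    // Hedged sketch of a typical llama.cpp context shift (the exact arithmetic in this
    // fork is collapsed in the diff and may differ): keep the first n_keep positions,
    // discard the next n_discard, and slide the remainder of the sequence left.
    const int n_keep    = params.n_keep;          // assumption: no separate system-prompt offset
    const int n_left    = slot.n_past - n_keep;
    const int n_discard = n_left / 2;

    llama_kv_cache_seq_rm   (ctx, slot.id, n_keep, n_keep + n_discard);
    llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

    // keep the slot's token bookkeeping in sync with the shifted cache
    slot.cache_tokens.erase(slot.cache_tokens.begin() + n_keep,
                            slot.cache_tokens.begin() + n_keep + n_discard);
    slot.n_past -= n_discard;
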
@@ -1028,22 +1028,16 @@ struct llama_server_context

slot.num_prompt_tokens = prompt_tokens.size();

slot.n_past = slot.params.cache_prompt ? common_part(slot.cache_tokens, prompt_tokens) : 0;

slot.cache_tokens = prompt_tokens;

if (slot.n_past == slot.num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
slot.n_past--;
}

slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;

if(!slot.params.cache_prompt) {
if(!slot.params.cache_prompt) {
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
slot.n_past = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
} else {
LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
if (params.n_keep < 0 && params.n_parallel == 1)
{
params.n_keep = (int)slot.num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
//if input prompt is too big, truncate like normal
if (slot.num_prompt_tokens >= (size_t)n_ctx)
{
@@ -1059,14 +1053,26 @@
});
slot.truncated = true;
prompt_tokens = new_tokens;
slot.num_prompt_tokens = prompt_tokens.size();
}
const size_t ps = slot.num_prompt_tokens;
std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
}

llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);

slot.cache_tokens = prompt_tokens;

if (slot.n_past == slot.num_prompt_tokens) {
// we have to evaluate at least 1 token to generate logits.
printf("we have to evaluate at least 1 token to generate logits\n");
slot.n_past--;
}

LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
@@ -1185,7 +1191,7 @@ struct llama_server_context
}
}

if(kv_cache_free < 0) {
if(kv_cache_free < 0 && params.n_parallel > 1) {
LOG_TEE("\nError: kv cache is full, increase context size.");
return false;
}
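
The "kv cache is full" abort above is now limited to true multi-slot runs (params.n_parallel > 1); with a single slot, the context-shift path is expected to free space instead. Purely as an illustration of the bookkeeping such a check implies (kv_cache_free is computed elsewhere and its real definition is not shown in this diff):

    // Hypothetical illustration only; kv_cache_free is computed elsewhere in server.cpp.
    // Free capacity is whatever remains of n_ctx after the system prompt and every
    // slot's cached tokens; a negative value means the shared cache cannot hold them all.
    int kv_cache_free = n_ctx - num_tokens_system;
    for (const auto & s : slots) {
        kv_cache_free -= (int) s.cache_tokens.size();
    }
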
@@ -1581,6 +1587,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
}

static void slot_print_timings(struct llama_client_slot * slot) {
LOG_TEE("\n");
LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
LOG_TEE("%s: total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
}

static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
{
const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
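
The new slot_print_timings() helper reports per-slot prompt and generation throughput. The tokens-per-second figure is just the count divided by elapsed seconds, written as 1e3 / t_ms * n; a quick check of that arithmetic with made-up numbers (both divisions assume non-zero counts):

    // Made-up numbers, purely to sanity-check the formula used by slot_print_timings():
    const double t_prompt_ms = 250.0;                          // prompt processing time
    const int    n_prompt    = 100;                            // tokens processed
    const double ms_per_tok  = t_prompt_ms / n_prompt;         // 2.5 ms per token
    const double tok_per_sec = 1e3 / t_prompt_ms * n_prompt;   // 400 tokens per second
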
@@ -1606,7 +1621,7 @@ static json format_generation_settings(llama_server_context &llama, llama_client
{"penalize_nl", slot->sparams.penalize_nl},
{"stop", slot->params.antiprompt},
{"n_predict", slot->params.n_predict},
// {"n_keep", slot.params.n_keep},
{"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos},
{"stream", slot->params.stream},
{"logit_bias", slot->sparams.logit_bias},
@@ -1730,7 +1745,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", -1);
llama.params.n_keep = json_value(body, "n_keep", 0);
slot->params.seed = json_value(body, "seed", default_params.seed);
slot->params.grammar = json_value(body, "grammar", default_params.grammar);
slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
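
The request-side default for n_keep changes from -1 to 0, so after this commit only an explicitly negative n_keep asks to keep the whole prompt (via the params.n_keep < 0 branch added in the prompt-loading hunk), and that only applies when n_parallel == 1. For reference, json_value() is the usual "field or default" accessor; a hedged sketch of that pattern (the actual template is defined elsewhere in server.cpp):

    // Hedged sketch of the json_value() helper pattern (the actual template is defined
    // elsewhere in server.cpp): return the typed field if present and non-null, else the default.
    template <typename T>
    static T json_value(const json & body, const std::string & key, const T & default_value) {
        return body.count(key) && !body.at(key).is_null()
             ? body.at(key).get<T>()
             : default_value;
    }
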
@@ -2089,6 +2104,7 @@ int main(int argc, char **argv)
}

const json data = format_final_response(llama, slot, completion_text, probs);
slot_print_timings(slot);
slot->release();
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
@@ -2131,6 +2147,7 @@ int main(int argc, char **argv)
slot->generated_token_probs.begin(),
slot->generated_token_probs.begin() + sent_token_probs_index)
);
slot_print_timings(slot);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
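
In both streaming branches the final chunk is now preceded by slot_print_timings() before being framed as a server-sent-events record ("data: " + JSON + blank line). A hedged sketch of what a client does with such a chunk; read_one_sse_line() is a hypothetical transport helper, and the "content"/"stop" field names are assumptions about the payload shape:

    // Hedged client-side sketch: each streamed chunk arrives as "data: <json>\n\n".
    std::string output;
    std::string line = read_one_sse_line(stream);               // hypothetical transport helper
    if (line.rfind("data: ", 0) == 0) {
        const json chunk = json::parse(line.substr(6));
        output += chunk.value("content", "");
        if (chunk.value("stop", false)) {
            // final chunk of the request; the server logged its timings just before sending it
        }
    }
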
@@ -2197,6 +2214,7 @@ int main(int argc, char **argv)
}

const json data = format_final_response(llama, slot, completion_text, probs);
slot_print_timings(slot);
res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
} else {
@@ -2238,6 +2256,7 @@ int main(int argc, char **argv)
slot->generated_token_probs.begin(),
slot->generated_token_probs.begin() + sent_token_probs_index)
);
slot_print_timings(slot);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +