Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server : fix segfault on long system prompt #8987

Merged
merged 4 commits into from
Aug 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,28 +1136,19 @@ struct server_context {
if (!system_prompt.empty()) {
system_tokens = ::llama_tokenize(ctx, system_prompt, true);

llama_batch_clear(batch);
const int32_t n_batch = llama_n_batch(ctx);
const int32_t n_tokens_prompt = system_tokens.size();

for (int i = 0; i < (int)system_tokens.size(); ++i) {
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
}
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);

const int32_t n_batch = llama_n_batch(ctx);
llama_batch_clear(batch);

for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess there is some outdated logic in update_slots as well, because there is a similar loop there.

Copy link
Collaborator Author

@compilade compilade Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only problem I see in update_slots is when the batch size is smaller than the number of slots, in which case tokens are added to the batch anyway, which could be out-of-bounds.

llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);

When the batch size is bigger than the number of slots, it does not seem to have the same problem as the system prompt, because there is already never more than n_batch tokens added to a batch:

// add prompt tokens for processing in the current batch
// TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
if (slot.ga_n != 1) {
while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w/ga_n)*(ga_n - 1);
slot_npast -= bd;
ga_i += ga_w/ga_n;
}
}
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);

So yes, this loop:

for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {

might be a bit outdated, but I think it's still relevant for the error handling logic:

// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;

From the above I think what could be fixed is parallel generation with small batch sizes:

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a1bbafbc..b2bf5bb5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -753,13 +753,13 @@ struct server_context {
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }
 
         metrics.init();

const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
for (int32_t j = 0; j < n_tokens; ++j) {
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
}

if (llama_decode(ctx, batch_view) != 0) {
if (llama_decode(ctx, batch) != 0) {
LOG_ERROR("llama_decode() failed", {});
return;
}
Expand Down
Loading