sampling : one sequence per sampling context #3601

Closed
wants to merge 1 commit
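Summary of the change: the per-sequence sampler map (llama_sampler_sequence_context, sequence_contexts) and the trailing seq argument of llama_sampling_sample are removed; instead, each sequence owns its own llama_sampling_context, as the examples/parallel changes below do with one context per client. The following is a minimal sketch of that calling pattern, not part of the patch; the client_state struct, the run_clients helper, and the loop bounds are illustrative assumptions based only on the signatures visible in this diff.

#include <vector>
#include "common.h"
#include "sampling.h"

// Sketch only: one sampling context per sequence instead of a shared context + seq id.
struct client_state {
    llama_sampling_context   ctx_sampling; // per-sequence sampler state (mirostat_mu, grammar)
    std::vector<llama_token> tokens_prev;  // recent tokens for repetition penalties
};

static void run_clients(llama_context * ctx, const gpt_params & params,
                        std::vector<client_state> & clients,
                        std::vector<llama_token_data> & candidates) {
    for (auto & client : clients) {
        // each sequence gets its own context; nothing is registered with a shared one
        client.ctx_sampling = llama_sampling_context_init(params, NULL);
    }

    for (size_t i = 0; i < clients.size(); ++i) {
        // the seq argument is gone; (int) i stands in for the batch index of this
        // sequence's logits (client.i_batch - i in the parallel example below)
        const llama_token id = llama_sampling_sample(
                ctx, NULL, clients[i].ctx_sampling, clients[i].tokens_prev, candidates, (int) i);
        clients[i].tokens_prev.push_back(id);
    }
}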
61 changes: 11 additions & 50 deletions common/sampling.cpp
@@ -1,50 +1,14 @@
#include "sampling.h"

llama_sampling_context::~llama_sampling_context() {
for (auto & it : sequence_contexts) {
if (it.second.grammar != NULL) {
llama_grammar_free(it.second.grammar);
it.second.grammar = NULL;
}
}
}

llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_context result;

result.params = params.sampling_params;
result.grammar = grammar;
return result;
}
llama_sampling_context result;

// Note: Creates the context if it doesn't exist, so this always return something.
llama_sampler_sequence_context & llama_sampling_get_sequence_context(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it != ctx_sampling.sequence_contexts.end()) {
return it->second;
}
llama_sampler_sequence_context new_ctx = {
2.0f * ctx_sampling.params.mirostat_tau,
ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL,
};
return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second;
}
result.params = params.sampling_params;
result.grammar = grammar;

bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq) {
const auto it = ctx_sampling.sequence_contexts.find(seq);
if (it == ctx_sampling.sequence_contexts.end()) return false;
if (it->second.grammar != NULL) {
llama_grammar_free(it->second.grammar);
it->second.grammar = NULL;
}
ctx_sampling.sequence_contexts.erase(it);
return true;
return result;
}

llama_token llama_sampling_sample(
@@ -53,8 +17,7 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context & ctx_sampling,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
-        const int idx,
-        llama_seq_id seq) {
+        const int idx) {
     const int n_ctx = llama_n_ctx(ctx);
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
@@ -115,10 +78,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq);
-
-    if (ctx_seq.grammar != NULL) {
-        llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar);
+    if (ctx_sampling.grammar != NULL) {
+        llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
     }
 
     if (temp <= 0) {
@@ -128,10 +89,10 @@ llama_token llama_sampling_sample(
         if (mirostat == 1) {
             const int mirostat_m = 100;
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu);
+            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
         } else if (mirostat == 2) {
             llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu);
+            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
@@ -158,8 +119,8 @@ llama_token llama_sampling_sample(
         }
     }
 
-    if (ctx_seq.grammar != NULL) {
-        llama_grammar_accept_token(ctx, ctx_seq.grammar, id);
+    if (ctx_sampling.grammar != NULL) {
+        llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
     }
 
     return id;
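With the map gone, the mirostat state (mirostat_mu) now lives directly on llama_sampling_context and is no longer seeded per sequence by llama_sampling_get_sequence_context. Below is a small sketch of what recycling a context for a new sequence would presumably look like; the helper is an assumption, not something this patch adds, and the seed value mirrors the removed per-sequence initializer.

#include "common.h"
#include "sampling.h"

// Sketch: reseed mirostat state when reusing one sampling context for a fresh
// sequence (assumed caller responsibility; 2.0f * tau matches the removed init).
static void reset_for_new_sequence(llama_sampling_context & ctx_sampling) {
    ctx_sampling.mirostat_mu = 2.0f * ctx_sampling.params.mirostat_tau;
}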
29 changes: 4 additions & 25 deletions common/sampling.h
@@ -34,27 +34,14 @@ typedef struct llama_sampling_params {
 
 } llama_sampling_params;
 
-// per-sequence sampler context
-typedef struct llama_sampler_sequence_context {
-    float mirostat_mu; // mirostat sampler state
-    llama_grammar * grammar;
-} llama_sampler_sequence_context;
-
 // general sampler context
 typedef struct llama_sampling_context {
-    ~llama_sampling_context();
-
-    // parameters that will be used for sampling and when creating
-    // new llama_sampler_sequence_context instances
+    // parameters that will be used for sampling
    llama_sampling_params params;
 
-    // map of sequence ids to sampler contexts
-    std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
+    // mirostat sampler state
+    float mirostat_mu;
 
-    // when non-NULL, new instances of llama_sampler_sequence_context
-    // will get a copy of the grammar here
-    // note: only the pointer is stored here, it is not a copy of
-    // the grammar and shouldn't be freed
     llama_grammar * grammar;
 } llama_sampling_context;
 
@@ -65,13 +52,6 @@ llama_sampling_context llama_sampling_context_init(
         const struct gpt_params & params,
         llama_grammar * grammar = NULL);
 
-// Fetches the sampler context for the specified sequence id (defaults to 0).
-// If the context for that sequence id doesn't already exist, it will be created with
-// default values based on the parameters in the ctx_sampling argument.
-llama_sampler_sequence_context & llama_sampling_get_sequence_context(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq = 0);
-
 // Reset the sampler context for the supplied sequence id (defaults to 0).
 // This is necessary to reuse a sequence id or free memory used by sequences
 // that are no longer required.
@@ -104,5 +84,4 @@ llama_token llama_sampling_sample(
         struct llama_sampling_context & ctx_sampling,
         const std::vector<llama_token> & last_tokens,
         std::vector<llama_token_data> & candidates,
-        const int idx = 0,
-        llama_seq_id seq = 0);
+        const int idx = 0);
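With the destructor and the per-sequence grammar copies removed, the header above implies that llama_sampling_context only borrows the grammar pointer handed to llama_sampling_context_init. The sketch below spells out that ownership; treating the caller as the owner is an inference from the removed code, not something the patch states.

#include "common.h"
#include "sampling.h"

// Sketch: the context stores the caller's grammar pointer but never copies or frees
// it, so the caller keeps ownership (inferred from the removed destructor).
static void sample_with_grammar(const gpt_params & params, llama_grammar * grammar) {
    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
    // ... run llama_sampling_sample(...) with ctx_sampling ...
    if (grammar != NULL) {
        llama_grammar_free(grammar); // freed by the caller, not by ctx_sampling
    }
}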
9 changes: 5 additions & 4 deletions examples/parallel/parallel.cpp
@@ -69,6 +69,8 @@ struct client {
     std::string response;
 
     std::vector<llama_token> tokens_prev;
+
+    llama_sampling_context ctx_sampling;
 };
 
 static void print_date_time() {
@@ -125,8 +127,6 @@ int main(int argc, char ** argv) {
     params.logits_all = true;
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
-    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
-
     // load the prompts from an external file if there are any
     if (params.prompt.empty()) {
         printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@@ -156,6 +156,7 @@ int main(int argc, char ** argv) {
         client.id = i;
         client.tokens_prev.resize(std::max(256, params.n_predict));
         std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
+        client.ctx_sampling = llama_sampling_context_init(params, NULL);
     }
 
     std::vector<llama_token_data> candidates;
@@ -341,7 +342,7 @@ int main(int argc, char ** argv) {
             //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
             //    client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-            const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
+            const llama_token id = llama_sampling_sample(ctx, NULL, client.ctx_sampling, client.tokens_prev, candidates, client.i_batch - i);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
@@ -386,7 +387,7 @@ int main(int argc, char ** argv) {
 
                 n_total_prompt += client.n_prompt;
                 n_total_gen    += client.n_decoded;
-                llama_sampling_context_reset(ctx_sampling, client.seq_id);
+
                 client.seq_id = -1;
 
 
15 changes: 8 additions & 7 deletions examples/speculative/speculative.cpp
@@ -9,6 +9,12 @@
 #include <string>
 #include <vector>
 
+struct seq_draft {
+    std::vector<llama_token> tokens;
+
+    struct llama_grammar * grammar = NULL;
+};
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -213,13 +219,8 @@ int main(int argc, char ** argv) {
             if (grammar_dft) {
                 llama_grammar_free(grammar_dft);
             }
-            // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
-            // that will need to change.
-            auto it = ctx_sampling.sequence_contexts.find(0);
-            GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
-            // This is necessary because each sequence id in sequence_contexts
-            // uses a copy of the original grammar.
-            grammar_dft = llama_grammar_copy(it->second.grammar);
+
+            grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
 
             LOG("copied target grammar to draft grammar\n");
         }
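The new seq_draft struct above carries its own grammar pointer, while the example still keeps a single grammar_dft copied straight from ctx_sampling.grammar. The sketch below shows how per-draft grammars would presumably be refreshed if the example grew multiple draft sequences; the refresh_draft_grammars helper and the drafts vector are illustrative and not part of this patch.

#include <vector>
#include "common.h"
#include "sampling.h"

// duplicated from the patch above so the sketch is self-contained
struct seq_draft {
    std::vector<llama_token> tokens;
    struct llama_grammar *   grammar = NULL;
};

// Sketch: give every draft its own copy of the target grammar so accepting tokens
// on one draft does not advance another draft's grammar state (illustrative only).
static void refresh_draft_grammars(std::vector<seq_draft> & drafts,
                                   const llama_sampling_context & ctx_sampling) {
    for (auto & draft : drafts) {
        if (draft.grammar != NULL) {
            llama_grammar_free(draft.grammar);
            draft.grammar = NULL;
        }
        if (ctx_sampling.grammar != NULL) {
            draft.grammar = llama_grammar_copy(ctx_sampling.grammar);
        }
    }
}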