From c1bad4a5494c681b8d0acb74a20e73d314b42a2c Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Wed, 21 Feb 2024 16:49:27 +0900 Subject: [PATCH 01/17] (WIP) Implement stochastic speculative decoding --- common/sampling.cpp | 78 +++++++++++ common/sampling.h | 7 + examples/speculative/speculative.cpp | 193 +++++++++++++++++++++------ 3 files changed, 239 insertions(+), 39 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index de4331a1182d6..be19972ad940e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -295,6 +295,76 @@ static llama_token llama_sampling_sample_impl( return id; } +static llama_token_data_array llama_sample_probability_distribution_impl( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + const llama_sampling_params & params = ctx_sampling->params; + + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const float temp = params.temp; + const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + + auto & prev = ctx_sampling->prev; + auto & cur = ctx_sampling->cur; + + // Get a pointer to the logits + float * logits = llama_get_logits_ith(ctx_main, idx); + + // Declare original_logits at the beginning of the function scope + std::vector original_logits; + + // apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + if (ctx_cfg) { + float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); + llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); + } + + cur.clear(); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + + // apply penalties + const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { + const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + + llama_sample_repetition_penalties(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + llama_sample_softmax(ctx_main, &cur_p); + return cur_p; +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, @@ -304,6 +374,14 @@ llama_token llama_sampling_sample( return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); } +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx); +} + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index 95d8753942b40..48b2459d1f944 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -131,6 +131,13 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, int idx = 0); +// returns the probability that token of given id will be sampled +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 3848791d475ad..20938cb7d25e4 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -18,6 +18,7 @@ struct seq_draft { std::vector i_batch_tgt; std::vector tokens; + std::vector dist; struct llama_sampling_context * ctx_sampling; }; @@ -166,7 +167,6 @@ int main(int argc, char ** argv) { std::vector drafts(n_seq_dft); params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar - params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); @@ -196,48 +196,149 @@ int main(int argc, char ** argv) { int i_dft = 0; int s_keep = 0; + llama_token token_id; + std::string token_str; + + // loop until we fail to accept a drafted token or we run out of drafted tokens while (true) { - LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + // check if the target token matches any of the drafts + // for stochastic sampling, attempt to match the token with the drafted tokens + { + bool accept = false; + if (params.sparams.temp > 0) { + // stochastic verification + + llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + + float p_tgt, p_dft; + // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); + + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } + if (i_dft >= (int) drafts[s].tokens.size()) { + drafts[s].active = false; + continue; + } + if (accept) { + // if we already accepted a token, we can skip the rest + if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { + drafts[s].active = false; + } + continue; + } - llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); + float r = rand() / (float) RAND_MAX; + llama_token_data_array dist_dft = drafts[s].dist[i_dft]; + // acquire the probability of the token from the draft model + for (int i = 0; i < dist_tgt.size; i++) { + + if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { + p_tgt = dist_tgt.data[i].p; + } + if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { + p_dft = dist_dft.data[i].p; + } + if (p_tgt && p_dft) { + break; + } + } + LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); + if (r <= p_tgt / p_dft) { + s_keep = s; + accept = true; + token_id = drafts[s].tokens[i_dft]; + token_str = llama_token_to_piece(ctx_tgt, token_id); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + + LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); + break; + } else { + LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); + drafts[s].active = false; + + // calculate residual probability + GGML_ASSERT(dist_tgt.sorted); + GGML_ASSERT(dist_dft.sorted); + float sum_probs = 0.0f; + + // sort dist by id + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + + for (int i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + sum_probs += dist_tgt.data[i].p; + } + for (int i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p /= sum_probs; + } - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + // sort dist_tgt by p desc + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.p > b.p; + }); + } - const std::string token_str = llama_token_to_piece(ctx_tgt, id); + for(int i = s; i < n_seq_dft; i++) { + if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + // synchronize active status for sequences with the same drafted token + drafts[i].active = drafts[i].active & accept; + } + } - if (!params.use_color) { - printf("%s", token_str.c_str()); - } + } - if (id == llama_token_eos(model_tgt)) { - has_eos = true; - } + if (!accept) { + // all drafted tokens were rejected + // sample from the target model + token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_str = llama_token_to_piece(ctx_tgt, token_id); + } - ++n_predict; - // check if the target token matches any of the drafts - { - bool matches = false; + } else { + // greedy verification - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; - } + // sample from the target model + LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); - s_keep = s; - matches = true; - } else { - drafts[s].active = false; + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + + token_str = llama_token_to_piece(ctx_tgt, token_id); + + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } + + if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) { + LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str()); + + s_keep = s; + accept = true; + } else { + drafts[s].active = false; + } } } - if (matches) { + if (token_id == llama_token_eos(model_tgt)) { + has_eos = true; + } + ++n_predict; + + if (accept) { ++n_accept; ++n_past_tgt; ++n_past_dft; @@ -245,17 +346,21 @@ int main(int argc, char ** argv) { if (params.use_color) { // Color token according to its origin sequence printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); - fflush(stdout); + } else { + printf("%s", token_str.c_str()); } + fflush(stdout); continue; + } else { + printf("%s", token_str.c_str()); + fflush(stdout); + break; } } - if (params.use_color) { - printf("%s", token_str.c_str()); - } - fflush(stdout); + } - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + { + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str()); // TODO: simplify { @@ -275,21 +380,25 @@ int main(int argc, char ** argv) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); + // free dist and clear + for (int i = 0; i < drafts[s].dist.size(); i++) { + free(drafts[s].dist[i].data); + } + drafts[s].dist.clear(); } // note: will be erased after the speculation phase - drafts[0].tokens.push_back(id); + drafts[0].tokens.push_back(token_id); + drafts[0].dist.push_back(llama_token_data_array{}); drafts[0].i_batch_tgt.push_back(0); llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); + llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode (ctx_dft, batch_dft); + llama_decode(ctx_dft, batch_dft); ++n_past_dft; - - break; } if (n_predict > params.n_predict || has_eos) { @@ -367,6 +476,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].skip = true; drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].dist = drafts[s].dist; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; @@ -389,6 +499,10 @@ int main(int argc, char ** argv) { llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id); + // save cur_p into drafts[s].dist + llama_token_data *data = (llama_token_data *)malloc(sizeof(llama_token_data) * cur_p.size()); + memcpy(data, cur_p.data(), sizeof(llama_token_data) * cur_p.size()); + drafts[s].dist.push_back(llama_token_data_array{data, cur_p.size(), true}); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -440,6 +554,7 @@ int main(int argc, char ** argv) { } drafts[s].tokens.erase(drafts[s].tokens.begin()); + drafts[s].dist.erase(drafts[s].dist.begin()); } } From a9335a5c2a24a22e7c2acc804e6b14ba7385dc0d Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 22 Feb 2024 13:50:30 +0900 Subject: [PATCH 02/17] sample from residual distribution on draft accept failure --- examples/speculative/speculative.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 20938cb7d25e4..74d883410660d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -298,12 +298,12 @@ int main(int argc, char ** argv) { if (!accept) { // all drafted tokens were rejected // sample from the target model - token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + LOG("all drafted tokens were rejected, sampling from residual distribution\n"); + token_id = llama_sample_token(ctx_tgt, &dist_tgt); llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } - } else { // greedy verification From 4694edde14cb85ce884dcaa8dc5319be763c0629 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 22 Feb 2024 14:46:19 +0900 Subject: [PATCH 03/17] fix #5657: force greedy sampling with probs when temp is 0 --- examples/speculative/speculative.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 74d883410660d..86824b400005d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -167,6 +167,9 @@ int main(int argc, char ** argv) { std::vector drafts(n_seq_dft); params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar + if (params.sparams.temp == 0) { + params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + } for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); From fb18827b4eddab3029fdaccd340a48d26416c25f Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Tue, 27 Feb 2024 15:09:12 +0900 Subject: [PATCH 04/17] remove p_accept parameter --- common/common.cpp | 7 ------- common/common.h | 3 +-- examples/speculative/speculative.cpp | 9 --------- 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 10ef11829cc50..e997736f39284 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -497,12 +497,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_sequences = std::stoi(argv[i]); - } else if (arg == "--p-accept" || arg == "-pa") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.p_accept = std::stof(argv[i]); } else if (arg == "--p-split" || arg == "-ps") { if (++i >= argc) { invalid_param = true; @@ -1020,7 +1014,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); diff --git a/common/common.h b/common/common.h index 935771d44ca9c..ec4a91ee9fdf5 100644 --- a/common/common.h +++ b/common/common.h @@ -53,11 +53,10 @@ struct gpt_params { int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 8; // number of tokens to draft during speculative decoding + int32_t n_draft = 5; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode - float p_accept = 0.5f; // speculative decoding accept probability float p_split = 0.1f; // speculative decoding split probability int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 86824b400005d..209fbe1d076a4 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -38,9 +38,6 @@ int main(int argc, char ** argv) { // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; - // probability threshold for accepting a token from the draft model - const float p_accept = params.p_accept; - // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; @@ -446,12 +443,6 @@ int main(int argc, char ** argv) { k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); } - if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); - drafts[s].drafting = false; - continue; - } - std::vector sa(1, s); // attempt to split the branch if the probability is high enough From 34b942a429c66a6b4383fbe12769593d9d5b241d Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Tue, 27 Feb 2024 15:29:14 +0900 Subject: [PATCH 05/17] fix style --- examples/speculative/speculative.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 209fbe1d076a4..e6623ff007661 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -208,12 +208,11 @@ int main(int argc, char ** argv) { bool accept = false; if (params.sparams.temp > 0) { // stochastic verification - + llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - float p_tgt, p_dft; - // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); + // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; @@ -234,7 +233,6 @@ int main(int argc, char ** argv) { llama_token_data_array dist_dft = drafts[s].dist[i_dft]; // acquire the probability of the token from the draft model for (int i = 0; i < dist_tgt.size; i++) { - if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { p_tgt = dist_tgt.data[i].p; } @@ -258,7 +256,7 @@ int main(int argc, char ** argv) { } else { LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); drafts[s].active = false; - + // calculate residual probability GGML_ASSERT(dist_tgt.sorted); GGML_ASSERT(dist_dft.sorted); From 875319b32388f2867e25fafaeb56cd55c0793106 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Tue, 27 Feb 2024 15:30:52 +0900 Subject: [PATCH 06/17] remove unused variables --- common/sampling.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index be19972ad940e..2f150b0219845 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -304,14 +304,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const float temp = params.temp; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const float penalty_repeat = params.penalty_repeat; const float penalty_freq = params.penalty_freq; const float penalty_present = params.penalty_present; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; const bool penalize_nl = params.penalize_nl; auto & prev = ctx_sampling->prev; From 6afc1f60e1dac858ca5b45dcd3c528ad0afae1b4 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Wed, 28 Feb 2024 02:26:01 +0900 Subject: [PATCH 07/17] add srand() in speculative.cpp --- examples/speculative/speculative.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index e6623ff007661..5a8cdc32a10a5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -41,6 +41,12 @@ int main(int argc, char ** argv) { // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; + if (params.seed >= 0) { + srand(params.seed); + } else { + srand(time(NULL)); + } + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); From 94f6256fd0c3caa8821e5d8a8413634b994212c5 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 29 Feb 2024 00:26:23 +0900 Subject: [PATCH 08/17] replace use of rand() with mt19937 sampling --- examples/speculative/speculative.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 5a8cdc32a10a5..0c01e7d66a32a 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -41,12 +41,14 @@ int main(int argc, char ** argv) { // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; + std::mt19937 r_gen; if (params.seed >= 0) { - srand(params.seed); + r_gen = std::mt19937(params.seed); } else { - srand(time(NULL)); + r_gen = std::mt19937(time(NULL)); } - + std::uniform_int_distribution u_dist(0, RAND_MAX); + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -235,7 +237,7 @@ int main(int argc, char ** argv) { continue; } - float r = rand() / (float) RAND_MAX; + float r = u_dist(r_gen) / (float) RAND_MAX; llama_token_data_array dist_dft = drafts[s].dist[i_dft]; // acquire the probability of the token from the draft model for (int i = 0; i < dist_tgt.size; i++) { From e4896e71b5c2c369fe236324f8366ca0de34e368 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 29 Feb 2024 00:41:31 +0900 Subject: [PATCH 09/17] fixes based on review (@JohannesGaessler) --- examples/speculative/speculative.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0c01e7d66a32a..b2634890d4b3a 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -48,7 +48,7 @@ int main(int argc, char ** argv) { r_gen = std::mt19937(time(NULL)); } std::uniform_int_distribution u_dist(0, RAND_MAX); - + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -218,7 +218,7 @@ int main(int argc, char ** argv) { // stochastic verification llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - float p_tgt, p_dft; + float p_tgt = 0, p_dft = 0; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); for (int s = 0; s < n_seq_dft; ++s) { @@ -239,7 +239,7 @@ int main(int argc, char ** argv) { float r = u_dist(r_gen) / (float) RAND_MAX; llama_token_data_array dist_dft = drafts[s].dist[i_dft]; - // acquire the probability of the token from the draft model + // acquire the token probabilities assigned by the draft and target models for (int i = 0; i < dist_tgt.size; i++) { if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { p_tgt = dist_tgt.data[i].p; @@ -295,7 +295,7 @@ int main(int argc, char ** argv) { for(int i = s; i < n_seq_dft; i++) { if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { // synchronize active status for sequences with the same drafted token - drafts[i].active = drafts[i].active & accept; + drafts[i].active = drafts[i].active && accept; } } From 6b35c8b3cf17ef9cee99436b540d3ddf948b85ab Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 29 Feb 2024 13:27:29 +0900 Subject: [PATCH 10/17] fix r random generation --- examples/speculative/speculative.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index b2634890d4b3a..57893c2ebc835 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -41,13 +41,11 @@ int main(int argc, char ** argv) { // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; - std::mt19937 r_gen; - if (params.seed >= 0) { - r_gen = std::mt19937(params.seed); - } else { - r_gen = std::mt19937(time(NULL)); + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); } - std::uniform_int_distribution u_dist(0, RAND_MAX); + std::default_random_engine rng(params.seed); + std::uniform_real_distribution<> u_dist; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -237,7 +235,7 @@ int main(int argc, char ** argv) { continue; } - float r = u_dist(r_gen) / (float) RAND_MAX; + float r = u_dist(rng); llama_token_data_array dist_dft = drafts[s].dist[i_dft]; // acquire the token probabilities assigned by the draft and target models for (int i = 0; i < dist_tgt.size; i++) { From 2ad3f7c28ccac927491903bab5b9481407b4841a Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 29 Feb 2024 15:47:41 +0900 Subject: [PATCH 11/17] randomly select next sequence to verify + fix bug in memory freeing --- examples/speculative/speculative.cpp | 33 ++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 57893c2ebc835..124133a351c1a 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -46,6 +47,7 @@ int main(int argc, char ** argv) { } std::default_random_engine rng(params.seed); std::uniform_real_distribution<> u_dist; + std::uniform_int_distribution<> u_int_dist; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -188,12 +190,15 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt[0] = 0; while (true) { + std::set active_seqs = {}; + // print current draft sequences for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; } + active_seqs.insert(s); const auto & tokens = drafts[s].tokens; LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); @@ -219,12 +224,13 @@ int main(int argc, char ** argv) { float p_tgt = 0, p_dft = 0; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; - } + + while (active_seqs.size() > 0) { + // randomly select a sequence to verify from active sequences + int s = *std::next(active_seqs.begin(), u_int_dist(rng) % active_seqs.size()); if (i_dft >= (int) drafts[s].tokens.size()) { drafts[s].active = false; + active_seqs.erase(s); continue; } if (accept) { @@ -232,9 +238,10 @@ int main(int argc, char ** argv) { if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { drafts[s].active = false; } + active_seqs.erase(s); continue; } - + LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); llama_token_data_array dist_dft = drafts[s].dist[i_dft]; // acquire the token probabilities assigned by the draft and target models @@ -290,13 +297,19 @@ int main(int argc, char ** argv) { }); } - for(int i = s; i < n_seq_dft; i++) { + active_seqs.erase(s); + for(int i = 0; i < n_seq_dft; i++) { + if (i == s) { + continue; + } if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { // synchronize active status for sequences with the same drafted token drafts[i].active = drafts[i].active && accept; + if (!drafts[i].active) { + active_seqs.erase(s); + } } } - } if (!accept) { @@ -380,16 +393,22 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_keep(ctx_tgt, 0); } + std::set freed_addrs; for (int s = 0; s < n_seq_dft; ++s) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); // free dist and clear for (int i = 0; i < drafts[s].dist.size(); i++) { + if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) { + continue; + } free(drafts[s].dist[i].data); + freed_addrs.insert(drafts[s].dist[i].data); } drafts[s].dist.clear(); } + freed_addrs.clear(); // note: will be erased after the speculation phase drafts[0].tokens.push_back(token_id); drafts[0].dist.push_back(llama_token_data_array{}); From c2cd292307acc799078b5be517f2fc74d241640a Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Thu, 29 Feb 2024 16:01:34 +0900 Subject: [PATCH 12/17] fix bug in active_seqs sync --- examples/speculative/speculative.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 124133a351c1a..37c9823a744ac 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -237,8 +237,8 @@ int main(int argc, char ** argv) { // if we already accepted a token, we can skip the rest if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { drafts[s].active = false; + active_seqs.erase(s); } - active_seqs.erase(s); continue; } LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); From 7463569cad717b78fe5c62508636f01a706757ce Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Fri, 1 Mar 2024 02:24:55 +0900 Subject: [PATCH 13/17] fix uniform int distribution initialization --- examples/speculative/speculative.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 37c9823a744ac..a1fa67afc48ac 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -47,7 +47,6 @@ int main(int argc, char ** argv) { } std::default_random_engine rng(params.seed); std::uniform_real_distribution<> u_dist; - std::uniform_int_distribution<> u_int_dist; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); @@ -227,7 +226,8 @@ int main(int argc, char ** argv) { while (active_seqs.size() > 0) { // randomly select a sequence to verify from active sequences - int s = *std::next(active_seqs.begin(), u_int_dist(rng) % active_seqs.size()); + std::uniform_int_distribution u_int_dist(0, active_seqs.size() - 1); + int s = *std::next(active_seqs.begin(), u_int_dist(rng)); if (i_dft >= (int) drafts[s].tokens.size()) { drafts[s].active = false; active_seqs.erase(s); From c76135401f3f1c59960e819f240292c887f91ab7 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Sat, 2 Mar 2024 16:45:07 +0900 Subject: [PATCH 14/17] remove warnings from comparison between int and size_t --- examples/speculative/speculative.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index a1fa67afc48ac..0546e68fbdbc1 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -245,7 +245,7 @@ int main(int argc, char ** argv) { float r = u_dist(rng); llama_token_data_array dist_dft = drafts[s].dist[i_dft]; // acquire the token probabilities assigned by the draft and target models - for (int i = 0; i < dist_tgt.size; i++) { + for (size_t i = 0; i < dist_tgt.size; i++) { if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { p_tgt = dist_tgt.data[i].p; } @@ -283,11 +283,11 @@ int main(int argc, char ** argv) { return a.id < b.id; }); - for (int i = 0; i < dist_tgt.size; i++) { + for (size_t i = 0; i < dist_tgt.size; i++) { dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); sum_probs += dist_tgt.data[i].p; } - for (int i = 0; i < dist_tgt.size; i++) { + for (size_t i = 0; i < dist_tgt.size; i++) { dist_tgt.data[i].p /= sum_probs; } @@ -399,7 +399,7 @@ int main(int argc, char ** argv) { drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); // free dist and clear - for (int i = 0; i < drafts[s].dist.size(); i++) { + for (size_t i = 0; i < drafts[s].dist.size(); i++) { if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) { continue; } From 45465b21d17a38fab751afad31ef88ef6a6f104b Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Sun, 3 Mar 2024 03:09:11 +0900 Subject: [PATCH 15/17] check grammar in `llama_sample_probability_distribution_impl` --- common/sampling.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/common/sampling.cpp b/common/sampling.cpp index 2f150b0219845..776e60c82cade 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -357,6 +357,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl( } } + // apply grammar checks + if (ctx_sampling->grammar != NULL) { + llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); + } + llama_sample_softmax(ctx_main, &cur_p); return cur_p; } From 67ad517e110567ecc52404fc7d800cb17b4b5412 Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Mon, 4 Mar 2024 14:55:35 +0900 Subject: [PATCH 16/17] remove malloc code by utilizing vectors --- examples/speculative/speculative.cpp | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0546e68fbdbc1..85bc0a762ad08 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -19,7 +19,7 @@ struct seq_draft { std::vector i_batch_tgt; std::vector tokens; - std::vector dist; + std::vector> dists; struct llama_sampling_context * ctx_sampling; }; @@ -243,7 +243,7 @@ int main(int argc, char ** argv) { } LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); - llama_token_data_array dist_dft = drafts[s].dist[i_dft]; + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; // acquire the token probabilities assigned by the draft and target models for (size_t i = 0; i < dist_tgt.size; i++) { if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { @@ -393,25 +393,15 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_keep(ctx_tgt, 0); } - std::set freed_addrs; for (int s = 0; s < n_seq_dft; ++s) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); - // free dist and clear - for (size_t i = 0; i < drafts[s].dist.size(); i++) { - if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) { - continue; - } - free(drafts[s].dist[i].data); - freed_addrs.insert(drafts[s].dist[i].data); - } - drafts[s].dist.clear(); + drafts[s].dists.clear(); } - freed_addrs.clear(); // note: will be erased after the speculation phase drafts[0].tokens.push_back(token_id); - drafts[0].dist.push_back(llama_token_data_array{}); + drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); llama_batch_clear(batch_dft); @@ -493,7 +483,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].skip = true; drafts[n_seq_cur].tokens = drafts[s].tokens; - drafts[n_seq_cur].dist = drafts[s].dist; + drafts[n_seq_cur].dists = drafts[s].dists; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; @@ -516,10 +506,8 @@ int main(int argc, char ** argv) { llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id); - // save cur_p into drafts[s].dist - llama_token_data *data = (llama_token_data *)malloc(sizeof(llama_token_data) * cur_p.size()); - memcpy(data, cur_p.data(), sizeof(llama_token_data) * cur_p.size()); - drafts[s].dist.push_back(llama_token_data_array{data, cur_p.size(), true}); + // save cur_p.data into drafts[s].dists + drafts[s].dists.push_back(cur_p); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -571,7 +559,7 @@ int main(int argc, char ** argv) { } drafts[s].tokens.erase(drafts[s].tokens.begin()); - drafts[s].dist.erase(drafts[s].dist.begin()); + drafts[s].dists.erase(drafts[s].dists.begin()); } } From 056bdb3029870001fe85803b0d5b88726805312e Mon Sep 17 00:00:00 2001 From: Minsoo Cheong Date: Mon, 4 Mar 2024 15:07:40 +0900 Subject: [PATCH 17/17] add PR link to README --- examples/speculative/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/speculative/README.md b/examples/speculative/README.md index 814efa592d94f..a6608c5fe8e3a 100644 --- a/examples/speculative/README.md +++ b/examples/speculative/README.md @@ -6,3 +6,4 @@ More info: - https://github.com/ggerganov/llama.cpp/pull/2926 - https://github.com/ggerganov/llama.cpp/pull/3624 +- https://github.com/ggerganov/llama.cpp/pull/5625