Implement stochastic speculative sampling #5625

Merged

Changes from 1 commit
33 changes: 26 additions & 7 deletions examples/speculative/speculative.cpp
@@ -5,6 +5,7 @@
#include <cstdio>
#include <string>
#include <vector>
#include <set>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -46,6 +47,7 @@ int main(int argc, char ** argv) {
}
std::default_random_engine rng(params.seed);
std::uniform_real_distribution<> u_dist;
std::uniform_int_distribution<> u_int_dist;

#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log"));
@@ -188,12 +190,15 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_tgt[0] = 0;

while (true) {
std::set<int> active_seqs = {};

// print current draft sequences
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}

active_seqs.insert(s);
const auto & tokens = drafts[s].tokens;

LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
@@ -219,22 +224,24 @@ int main(int argc, char ** argv) {
float p_tgt = 0, p_dft = 0;

// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].active) {
continue;
}

while (active_seqs.size() > 0) {
// randomly select a sequence to verify from active sequences
int s = *std::next(active_seqs.begin(), u_int_dist(rng) % active_seqs.size());
if (i_dft >= (int) drafts[s].tokens.size()) {
drafts[s].active = false;
active_seqs.erase(s);
continue;
}
if (accept) {
// if we already accepted a token, we can skip the rest
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
drafts[s].active = false;
}
active_seqs.erase(s);
continue;
}

LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
float r = u_dist(rng);
llama_token_data_array dist_dft = drafts[s].dist[i_dft];
// acquire the token probabilities assigned by the draft and target models
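
The comparison that uses r is elided by the collapsed hunk, but the standard stochastic speculative sampling rule is to accept the drafted token with probability min(1, p_tgt / p_dft). A hedged sketch of that check, where p_tgt and p_dft stand for the probabilities the target and draft models assign to the drafted token (the function name and the handling of p_dft == 0 are my assumptions, not taken from the patch):

    #include <random>

    // accept the drafted token with probability min(1, p_tgt / p_dft)
    static bool accept_draft_token(float p_tgt, float p_dft, std::default_random_engine & rng) {
        std::uniform_real_distribution<> u_dist;   // defaults to [0, 1)
        const float r = u_dist(rng);
        return p_dft > 0.0f && r <= p_tgt / p_dft;
    }

If the target model assigns the token at least as much probability as the draft model, the ratio is >= 1 and the token is always kept, so the stochastic test only costs acceptance rate where the draft was overconfident.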
@@ -290,13 +297,19 @@ int main(int argc, char ** argv) {
});
}

for(int i = s; i < n_seq_dft; i++) {
active_seqs.erase(s);
for(int i = 0; i < n_seq_dft; i++) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
active_seqs.erase(s);
}
}
}

}

if (!accept) {
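
The hunk above elides the rejection path, where stochastic speculative sampling rebuilds the target distribution as the residual max(0, p_tgt - p_dft), renormalized, before resampling a replacement token. A simplified sketch of that idea over plain probability vectors, assuming both vectors cover the same vocabulary (this is not the patch's llama_token_data_array layout, and the fallback for an all-zero residual is my assumption):

    #include <algorithm>
    #include <random>
    #include <vector>

    // sample from the residual distribution max(0, p_tgt - p_dft), renormalized
    static int sample_residual(const std::vector<float> & p_tgt,
                               const std::vector<float> & p_dft,
                               std::default_random_engine & rng) {
        std::vector<float> residual(p_tgt.size());
        float sum = 0.0f;
        for (size_t i = 0; i < p_tgt.size(); ++i) {
            residual[i] = std::max(0.0f, p_tgt[i] - p_dft[i]);
            sum += residual[i];
        }
        if (sum <= 0.0f) {
            residual = p_tgt;   // degenerate case: fall back to the target distribution
        }
        // std::discrete_distribution normalizes the weights itself
        std::discrete_distribution<int> pick(residual.begin(), residual.end());
        return pick(rng);
    }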
@@ -380,16 +393,22 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_keep(ctx_tgt, 0);
}

std::set<llama_token_data *> freed_addrs;
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].active = false;
drafts[s].tokens.clear();
drafts[s].i_batch_tgt.clear();
// free dist and clear
for (int i = 0; i < drafts[s].dist.size(); i++) {
if (freed_addrs.find(drafts[s].dist[i].data) != freed_addrs.end()) {
continue;
}
free(drafts[s].dist[i].data);
freed_addrs.insert(drafts[s].dist[i].data);
}
drafts[s].dist.clear();
}
freed_addrs.clear();
// note: will be erased after the speculation phase
drafts[0].tokens.push_back(token_id);
drafts[0].dist.push_back(llama_token_data_array{});
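
Several draft sequences that branch from the same prefix can end up pointing at the same dist buffer, which is why the cleanup above records already-freed pointers in freed_addrs before calling free. The same pattern in isolation (the container of float pointers is a simplification, not the patch's llama_token_data arrays):

    #include <cstdlib>
    #include <set>
    #include <vector>

    // free each malloc'd buffer at most once, even when several entries point at it
    static void free_shared_buffers(std::vector<float *> & buffers) {
        std::set<float *> freed;
        for (float * p : buffers) {
            if (p == nullptr || freed.count(p) > 0) {
                continue;   // already freed, or never allocated
            }
            free(p);
            freed.insert(p);
        }
        buffers.clear();
    }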