
Always sort logits before nucleus sampling #812

Merged (2 commits) on Apr 7, 2023

Conversation

@ivanstepanovftw (Collaborator) commented Apr 6, 2023

Logits are not sorted before nucleus sampling when top_k is 0 or out of bounds. Reported and solved by @Piezoid in #779 (comment).

Other changes:
- Since the logits are now sorted, the first index always holds the token with the maximum probability.
- Removed the explicit normalization, because std::discrete_distribution already divides by the sum of the weights.

@Piezoid (Contributor) commented Apr 6, 2023

The normalization may not be needed; std::discrete_distribution already divides by the sum.

@ivanstepanovftw (Collaborator, Author) commented Apr 6, 2023

I can confirm that normalization is not needed.

Tested with the following program; the output is the same for the weights {10, 10, 10, 100} and {100, 100, 100, 1000}:

#include <iostream>
#include <iomanip>
#include <map>
#include <random>

int main()
{
    std::random_device rd;
    std::mt19937 gen(rd());
    // The two weight sets are proportional, so they define the same
    // distribution: discrete_distribution normalizes by the sum itself.
    // std::discrete_distribution<> d({10, 10, 10, 100});
    std::discrete_distribution<> d({100, 100, 100, 1000});
    std::map<int, int> map;

    // Draw 10,000 samples and count how often each index is produced.
    for (int n = 0; n < 1e4; ++n)
        ++map[d(gen)];

    for (const auto& [num, count] : map)
        std::cout << num << " generated " << std::setw(4) << count << " times\n";
}

- fix windows build
- remove normalization since std::discrete_distribution does not require it
@Piezoid (Contributor) commented Apr 6, 2023

On a related note, I tried an optimization (which is getting out of scope for this PR): during the initial pass over the logits, where temperature scaling is applied, a heuristic can bypass tokens whose probabilities are far below the highest probability:

const float maxl = *std::max_element(plogits, plogits + n_logits);
const float minp_ratio = 1e-6; // Ratio lowest admissible prob : highest prob
const float minl = maxl / repeat_penalty + log(minp_ratio) * temp; // Assumes the best token is penalized
// Then, when plogits[i] < minl, don't push_back the pair(logit, i) into logits_id

Sampling is clearly not the most compute-heavy task, so it might not be a good idea to introduce a heuristic that could potentially fail. I was experimenting with a more complex repetition search, which could justify it.

@ggerganov ggerganov merged commit 4953e90 into ggerganov:master Apr 7, 2023
@ivanstepanovftw ivanstepanovftw deleted the broken_topk0 branch April 22, 2023 12:41