Regression: llama.cpp produces nonsensical outputs when using batched decoding on Metal #6173
Comments
Could you test if #6177 behaves correctly in your tests?
@ggerganov Thanks so much for the quick look at this! It did improve the behavior, but something is still wrong. Applying the change from #6177:
As you can see, when the model continues after the initial response, it prompts itself "What is 2+20?" and then responds "22+20=42", which doesn't make sense given that prompt. If I additionally revert the padding change (see the sketch below), the model prompts itself with a different question and answers it correctly:
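For reference, the revert being tested has roughly this effect (a self-contained sketch, not the actual llama.cpp source; the pad helper mimics ggml's GGML_PAD rounding, and the concrete numbers are just examples):

```cpp
// Sketch: how padding vs. no padding changes the kv_self.n value that
// attention operates over. Assumed values, not taken from the repro.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// same rounding ggml's GGML_PAD does for power-of-two n: round x up to a multiple of n
static uint32_t pad(uint32_t x, uint32_t n) { return (x + n - 1) & ~(n - 1); }

int main() {
    const uint32_t n_ctx    = 2048; // context size (assumed)
    const uint32_t cell_max = 41;   // highest occupied KV cell count (example)

    const uint32_t padded   = std::min(n_ctx, std::max(32u, pad(cell_max, 32)));
    const uint32_t reverted = std::min(n_ctx, std::max(32u, cell_max));

    printf("padded kv_self.n = %u, reverted kv_self.n = %u\n", padded, reverted);
    // -> padded kv_self.n = 64, reverted kv_self.n = 41
    return 0;
}
```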
I'm not entirely sure how to interpret this result, but my intuition is that the model answering such a simple question incorrectly indicates something is going wrong. If I run the same query (which was partially hallucinated) without parallelization, it answers correctly:
Worth noting again that without multiple sequences, the model answers in a very different (and subjectively better) way. However, this may be a separate issue, as this difference in behavior was present in the original implementation of batching and is not part of this regression.
Thanks for the detailed look. I've updated #6177 - I think it should be good now. Could you give it another try and let me know if you agree?
It looks good now - thank you!
When using batched decoding with >1 parallel sequences, llama.cpp produces nonsensical outputs. Here is an example:
However, if I use only 1 parallel sequence instead of 2, the output becomes reasonable:
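In case it helps, this is roughly what decoding two parallel sequences looks like against the C API (a minimal sketch, not the exact repro used above; the model path, prompt, and context size are placeholders, and error handling is mostly omitted):

```cpp
// Sketch: prefill the same prompt into two parallel sequences (seq_id 0 and 1)
// and run a single llama_decode() over the combined batch.
#include "llama.h"
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    llama_backend_init();

    llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 2048;
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // tokenize the prompt (add_special = true to prepend BOS)
    const char * prompt = "What is 2+2?";
    std::vector<llama_token> toks(64);
    const int n = llama_tokenize(model, prompt, (int) strlen(prompt),
                                 toks.data(), (int) toks.size(), true, false);
    toks.resize(n);

    // one batch holding both sequences: the same tokens at the same
    // positions, tagged with seq_id 0 and 1 respectively
    llama_batch batch = llama_batch_init(2 * n, 0, 1);
    for (int s = 0; s < 2; ++s) {
        for (int i = 0; i < n; ++i) {
            const int j = batch.n_tokens++;
            batch.token[j]     = toks[i];
            batch.pos[j]       = i;
            batch.n_seq_id[j]  = 1;
            batch.seq_id[j][0] = s;
            batch.logits[j]    = (i == n - 1); // logits only for the last token
        }
    }

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode failed\n");
        return 1;
    }
    // ... sample each sequence's continuation from llama_get_logits_ith(ctx, i)

    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```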
I manually bisected and found that the problem was introduced by @ggerganov's change d7b800b (#4280). Indeed, after reverting the GGML_PAD that was added to kv_self.n, the model output becomes reasonable even with multiple batched sequences:

I'm not familiar enough with the details here to understand the utility or necessity of the GGML_PAD operation. Any idea why it is causing this issue? Is it possible that we should omit it for Metal specifically? (A toy illustration of how padding could go wrong is included after the notes below.)

Notes:
Thank you for all of the wonderful work going into this project!
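Speculating about the mechanism (this is only my reading, not a confirmed diagnosis): if attention spans all kv_self.n cells after padding, the padded slots have to be masked out of the KQ scores before softmax, or whatever stale values happen to sit in them skew the attention weights. A toy illustration:

```cpp
// Sketch: softmax over KQ scores where kv_self.n was padded past the cells
// actually in use. The concrete numbers are invented for illustration.
#include <cmath>
#include <cstdio>
#include <vector>

// naive softmax over KQ scores; `used` is the number of real (non-padded) cells
static std::vector<float> softmax(std::vector<float> scores, size_t used, bool mask_padding) {
    if (mask_padding) {
        for (size_t i = used; i < scores.size(); ++i) scores[i] = -INFINITY;
    }
    float sum = 0.f;
    for (float & s : scores) { s = expf(s); sum += s; }
    for (float & s : scores) s /= sum;
    return scores;
}

int main() {
    // 2 real cells, padded out to 4; the last two hold stale garbage
    std::vector<float> kq = { 1.0f, 0.5f, /* stale: */ 3.0f, -2.0f };

    for (bool mask : { true, false }) {
        std::vector<float> w = softmax(kq, 2, mask);
        printf("mask=%d weights: %.3f %.3f %.3f %.3f\n", mask, w[0], w[1], w[2], w[3]);
    }
    // with masking, all weight falls on the 2 real cells; without it,
    // the stale score of 3.0 dominates the attention distribution
    return 0;
}
```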