
llama_get_kv_cache_token_count() deprecation hurts debugging - suggested API enhancement #4035

Closed
WeirdConstructor opened this issue Nov 11, 2023 · 4 comments
Labels
enhancement New feature or request stale

Comments

@WeirdConstructor
Contributor

Prerequisites

Please answer the following questions for yourself before submitting an issue.

  • I am running the latest code. Development is very rapid so there are no tagged versions as of now.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new bug or useful enhancement to share.

Feature Description

I would love to have more insight into the state of the KV cache. I see the following deprecation in llama.h:

    // Returns the number of tokens in the KV cache
    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
            "avoid using this, it will be removed in the future, "
            "instead - count the tokens in user code");

I understand that the API of this function is insufficient with the new sequence IDs. But telling developers that they should just not make any mistakes and "count right" in user code feels a bit like mocking them :-)

This is what I propose, mostly for debugging purposes:

  • llama_get_kv_seq_token_count(ctx, seq_id) returns the number of tokens in the given sequence.
  • llama_get_kv_seq_token_count_in_range(ctx, seq_id, p0, p1) returns the number of tokens of that sequence in the range [p0, p1).
  • llama_get_kv_seq_min_pos(ctx, seq_id) returns the minimum position of that sequence.
  • llama_get_kv_seq_max_pos(ctx, seq_id) returns the maximum position of that sequence.

I know that llama_get_kv_seq_min_pos() and llama_get_kv_seq_max_pos() would not detect holes in the sequence, but they would still be immensely useful for debugging. For detecting holes, llama_get_kv_seq_token_count_in_range() would help. Even better would of course be a function that returns all positions of a sequence, but that would be more cumbersome to expose through a C API and more complicated to implement.
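Sketched as C declarations, the new functions could look roughly like this (names, parameter lists and return types are only a suggestion, none of this exists in llama.h today):

    // all of these only read the cache, they do not modify it
    LLAMA_API int32_t   llama_get_kv_seq_token_count         (const struct llama_context * ctx, llama_seq_id seq_id);
    LLAMA_API int32_t   llama_get_kv_seq_token_count_in_range(const struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
    LLAMA_API llama_pos llama_get_kv_seq_min_pos             (const struct llama_context * ctx, llama_seq_id seq_id);
    LLAMA_API llama_pos llama_get_kv_seq_max_pos             (const struct llama_context * ctx, llama_seq_id seq_id);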

I also propose changing the return values of the following functions (sketched below):

  • int32_t llama_kv_cache_tokens_rm(...) to return the number of deleted tokens.
  • int32_t llama_kv_cache_seq_rm(...) to return the number of deleted tokens.
  • int32_t llama_kv_cache_seq_cp(...) to return the number of copied tokens.
  • int32_t llama_kv_cache_seq_shift(...) to return the number of shifted tokens.
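As a rough sketch, only the return type would change; the parameter lists below are meant to follow the current llama.h and should be double-checked there:

    // returns the number of cells that were actually removed / copied
    LLAMA_API int32_t llama_kv_cache_seq_rm(
            struct llama_context * ctx,
                    llama_seq_id   seq_id,
                       llama_pos   p0,
                       llama_pos   p1);

    LLAMA_API int32_t llama_kv_cache_seq_cp(
            struct llama_context * ctx,
                    llama_seq_id   seq_id_src,
                    llama_seq_id   seq_id_dst,
                       llama_pos   p0,
                       llama_pos   p1);

llama_kv_cache_tokens_rm and llama_kv_cache_seq_shift would change in the same way.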

Motivation

As a developer I make mistakes, especially when trying to understand an unfamiliar API. Not being able to check whether the state of the KV cache is what I expect it to be is seriously limiting.

Possible Implementation

  • llama_get_kv_seq_token_count(ctx, seq_id): loop over the cells and count the occupied cells that carry the given sequence ID (a rough sketch follows after this list).
  • llama_get_kv_seq_token_count_in_range(ctx, seq_id, p0, p1): same as above, but only consider cells whose position lies in the given range.
  • llama_get_kv_seq_min_pos(ctx, seq_id): loop over the cells and find the minimum position among cells with that sequence ID.
  • llama_get_kv_seq_max_pos(ctx, seq_id): loop over the cells and find the maximum position among cells with that sequence ID.
  • Add a counter to llama_kv_cache_tokens_rm, llama_kv_cache_seq_rm, llama_kv_cache_seq_cp and llama_kv_cache_seq_shift and return its value.
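A rough sketch of the first one, modelled on the cell layout in llama.cpp (a position and a set of sequence IDs per cell); the other query functions would use the same loop. This is only a sketch against the internals as I understand them, not tested:

    int32_t llama_get_kv_seq_token_count(struct llama_context * ctx, llama_seq_id seq_id) {
        const llama_kv_cache & cache = ctx->kv_self;

        int32_t count = 0;
        for (uint32_t i = 0; i < cache.size; ++i) {
            // cells with pos < 0 are unused
            if (cache.cells[i].pos < 0) {
                continue;
            }
            // count the cell if it belongs to the requested sequence
            if (cache.cells[i].seq_id.count(seq_id) > 0) {
                count++;
            }
        }
        return count;
    }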
@WeirdConstructor WeirdConstructor added the enhancement New feature or request label Nov 11, 2023
@WeirdConstructor WeirdConstructor changed the title llama_get_kv_cache_token_count() deprecation is mocking developers llama_get_kv_cache_token_count() deprecation is mocking developers - suggested API enhancement Nov 11, 2023
@KerfuffleV2
Collaborator

I think you're reading a mocking tone into it that didn't really exist. It's just describing how to deal with the problem without that function.

Making the KV cache modification functions return the number of affected tokens seems like a good change.

I'm not against the other debugging functions either, though they probably wouldn't get used too often. They also shouldn't be used for anything other than debugging because, for example, finding the minimum position of a sequence currently requires walking every KV cache cell. It's much better for the application to keep track of that itself.

@WeirdConstructor
Contributor Author

> I think you're reading a mocking tone into it that didn't really exist. It's just describing how to deal with the problem without that function.

Yes indeed, you are right :-) I changed the title.

> Making the KV cache modification functions return the number of affected tokens seems like a good change.
>
> I'm not against the other debugging functions either, though they probably wouldn't get used too often. They also shouldn't be used for anything other than debugging because, for example, finding the minimum position of a sequence currently requires walking every KV cache cell. It's much better for the application to keep track of that itself.

Of course the goal is to have the application perform the right KV cache transformations without needing feedback from searching the cache. It's just that off-by-one bugs and mixed-up indices/positions tend to creep in everywhere. Keeping track of all that is easier said than done.

I also wish I could dump the contents of the KV cache as tokens. I wonder if it would make sense to store the token IDs from the batch in the cache cells, so that you could dump the contents of a sequence and see the tokens at their positions. For debugging this would be amazing.
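Purely hypothetical sketch of what that could look like, assuming llama_kv_cell got an extra llama_token field (it does not have one today) that is filled in when the batch is copied into the cache:

    // hypothetical: assumes a new `llama_token tok` member in llama_kv_cell
    void llama_kv_cache_dump_seq(struct llama_context * ctx, llama_seq_id seq_id) {
        const llama_kv_cache & cache = ctx->kv_self;

        for (uint32_t i = 0; i < cache.size; ++i) {
            const auto & cell = cache.cells[i];
            if (cell.pos < 0 || cell.seq_id.count(seq_id) == 0) {
                continue;
            }
            // print position and token ID; the ID could additionally be
            // converted back to text with the tokenizer API
            printf("seq=%d pos=%4d tok=%6d\n", seq_id, cell.pos, cell.tok);
        }
    }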

@WeirdConstructor WeirdConstructor changed the title llama_get_kv_cache_token_count() deprecation is mocking developers - suggested API enhancement llama_get_kv_cache_token_count() deprecation hurts debugging - suggested API enhancement Nov 12, 2023
@WeirdConstructor
Contributor Author

I wanted to note this recently implemented debug print function that I used to debug some of my KV cache shenanigans.
It just prints to stdout, which was sufficient for my purposes, and it basically replaced the API enhancements mentioned above.

(from https://github.com/WeirdConstructor/llama.cpp/blob/bbc391bc2cf1777be61fa607ecf32b5c83da4656/llama.cpp#L1644-L1687 )

void llama_kv_cache_debug_print(
    struct llama_context * ctx, const std::string &tag
) {
    llama_kv_cache &cache = ctx->kv_self;

    // track the previously printed cell so that runs of consecutive
    // positions with the same sequence IDs can be compressed
    int prev_pos = -1;
    int prev_i = -1;
    std::string prev_seqs;

    printf("[%10s|c.idx]      pos       seq\n", "tag");
    for (uint32_t i = 0; i < cache.size; ++i) {
        // skip unused cells
        if (cache.cells[i].pos < 0)
            continue;
        // build a comma separated list of the sequence IDs in this cell
        std::string seqs;
        for (auto seq : cache.cells[i].seq_id) {
            if (seqs.size() > 0)
                seqs += ",";
            seqs += std::to_string(seq);
        }
        if (seqs.size() > 0) {
            // a new run starts when the position is not consecutive or the
            // sequence IDs change; close and print the previous run first
            if ((prev_pos + 1) != cache.cells[i].pos || prev_seqs != seqs) {
                if (prev_i >= 0) {
                    printf("[%10s|    :]        :         |\n", tag.c_str());
                    printf("[%10s| %4d] pos=%4d seq=%5s\n",
                        tag.c_str(), prev_i, prev_pos, prev_seqs.c_str());
                }

                prev_i = i;
                prev_pos = cache.cells[i].pos;
                prev_seqs = seqs;

                printf("[%10s| %4d] pos=%4d seq=%5s\n",
                    tag.c_str(), i, cache.cells[i].pos, seqs.c_str());
            } else if (prev_pos + 1 == cache.cells[i].pos) {
                prev_pos = cache.cells[i].pos;
                prev_i = i;
            }
        }
    }

    printf("[%10s|    :]        :         |\n", tag.c_str());
    printf("[%10s| %4d] pos=%4d seq=%5s (end)\n",
        tag.c_str(), prev_i, prev_pos, prev_seqs.c_str());
}
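It is called after a batch of cache operations, with the tag being a free-form label for the call site, for example:

    llama_kv_cache_debug_print(ctx, "Z");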

It generates the following output, which visualizes the sequences in the kv cache:

[       tag|c.idx]      pos       seq
[         Z|    0] pos=   0 seq=    0
[         Z|    :]        :         |
[         Z|   61] pos=  61 seq=    0
[         Z|   62] pos= 300 seq=    1
[         Z|    :]        :         |
[         Z|  147] pos= 385 seq=    1
[         Z|  148] pos=  62 seq=    0
[         Z|    :]        :         |
[         Z|  151] pos=  65 seq=    0 (end)

It helped me a lot in finding off-by-one bugs in the p0 and p1 handling.


github-actions bot commented Apr 2, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

@github-actions github-actions bot closed this as completed Apr 2, 2024