
llama : simplify Mamba with advanced batch splits #8526

Merged: 19 commits merged from compilade/batch-splits into master on Aug 21, 2024

Conversation

@compilade (Collaborator) commented on Jul 17, 2024

As promised in #7531 (comment), I've been extracting the advanced batch splits out of the Jamba PR (#7531).

I've also backported the contiguous allocation of recurrent state slots, which makes it possible to also include the changes from #7531 which simplify the ggml operators used specifically for Mamba. Hopefully this isn't too much at once.

See #7531 (comment) for an explanation of the batch splits.

Summary

  • ggml.c
    • Simplify ggml_ssm_conv and ggml_ssm_scan by assuming batched sequences have the same number of new tokens, and that the states are contiguous and ordered correctly.
    • Allow ggml_concat to work with a non-contiguous second argument.
      • The CPU implementation already supported this, but it was guarded with an assertion. Meanwhile, I think the CUDA implementation already supports this too, and does not prevent its usage (not totally sure), so I did not change it.
  • llama.cpp
    • Advanced batch splits handled with lctx.sbatch for persistent buffers
      • Refactor the "helpers for smoother batch API transition" by handling them in llama_sbatch, which avoids repeated allocations by re-using the same buffers.
      • Simple batch splits should be equivalent to the previous behavior and are made with lctx.sbatch.split_simple(n_tokens) to build a llama_ubatch with a max size of n_tokens.
      • Equal-sequence-length splits are made with lctx.sbatch.split_equal(n_tokens) and are used to simplify the operators of recurrent models (see the sketch after this summary).
      • Add llama_ubatch. Similar to llama_batch, but aware of equal-length sequences.
        • Make llama_set_inputs (and others) use llama_ubatch instead of llama_batch.
    • Make recurrent state slot allocation contiguous in llama_kv_cache_find_slot
    • Add llm_build_mamba to build a Mamba block; it is used for Mamba and will also be used for Jamba
    • Add llm_build_copy_mask_state (maybe not a good name) to abstract away the shuffling and masking of recurrent states. Used for Mamba, and it should be usable for other recurrent architectures too.
    • Simplify the sanity checks for qs.n_attention_wv in llama_model_quantize_internal to make it future proof for hybrid models.
    • Reorder the outputs when using advanced batch splits like split_equal in conjunction with llama_get_logits, because the API requires the outputs to be in the same order as in the user-provided batch, not in an order based on batch-split rules.
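To make the new split API concrete, here is a rough sketch of how a decode loop could drive it. This is based only on the names mentioned above; the initialization call (from_batch), the n_ubatch parameter, and the struct members are assumptions and may not match the actual llama.cpp code.

```cpp
// Sketch only (not the actual llama.cpp code): how llama_decode could drive
// the new split API. `from_batch`, `n_ubatch` and the members used here are
// assumptions based on the names in the summary above.
lctx.sbatch.from_batch(batch, n_embd, /*simple_split=*/!kv_self.recurrent, logits_all);

while (lctx.sbatch.n_tokens > 0) {
    // equal-length splits keep ggml_ssm_conv/ggml_ssm_scan simple;
    // simple splits match the previous behavior for non-recurrent models
    llama_ubatch ubatch = kv_self.recurrent
        ? lctx.sbatch.split_equal (n_ubatch)
        : lctx.sbatch.split_simple(n_ubatch);

    // ... find a KV/state slot for the ubatch, build the graph, set inputs, compute ...
}
```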

For simplicity, this does not include the separation of the KV cache and the recurrent state cache. Both still use the same buffers (lctx.kv_self.k_l and lctx.kv_self.v_l, as on master). The separation (necessary for hybrid models) will be introduced at the same time as Jamba.

TODO

  • Test the slot allocation of llama_kv_cache_find_slot with the --hellaswag benchmark in llama-perplexity with a Mamba model
    • This uses lots of parallel sequences in an unusual way, and so I think it's a great stress test.
  • Session file saving and reloading
    • Reloading needs to rebuild the tail metadata for recurrent states. (i.e. which cell is the end of which sequence)
    • The server tests need to pass
  • Make sure T5 still works
  • Make sure the pooled embeddings still work
    • tested bge-small with llama-embedding using parallel prompts with --pooling cls, --pooling last, and --pooling mean; results exactly match master.
  • Make sure Gemma's sliding window mask still works
  • Decide whether to rename llama_reorder_outputs to llama_output_reorder and move it close to llama_output_reserve.
    • renamed and moved

Future ideas

  • whole-sequence splits for embeddings
  • handle pooling types like cls and last within the ubatch.outputs when splitting a batch; inp_cls is redundant with inp_out_ids.

* llama : advanced batch splits

  This includes equal-sequence-length batch splits which are useful
  to simplify recurrent model operators.

* llama : always make recurrent state slots contiguous

* ggml : simplify mamba operators
github-actions bot added the ggml label (changes relating to the ggml tensor library for machine learning) on Jul 17, 2024
compilade added the refactoring and Review Complexity : Medium labels on Jul 17, 2024
compilade marked this pull request as draft on July 17, 2024 01:54
* llama : logits_all has priority over batch->logits

  Otherwise, the server embeddings tests failed.
  This was likely an existing problem but was only detected here
  because of an additional assertion.
@ggerganov (Owner) left a comment

Tested t5-small and it currently segfaults - let me know if you need help with resolving it

github-actions bot added the testing label (everything test related) on Jul 17, 2024
ggerganov force-pushed the compilade/batch-splits branch from 345d590 to 7b7db0b on July 17, 2024 18:37
compilade and others added 2 commits July 17, 2024 14:48
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@ggerganov (Owner)

> Make sure Gemma's sliding window mask still works

The following command produces identical perplexity on master and this branch:

./llama-perplexity \
    -m models/gemma-2-9b/ggml-model-f16.gguf \
    -f build/wikitext-2-raw/wiki.test.raw \
    -ngl 99 -c 8192

Is this enough to confirm the SWA functionality?

@compilade (Collaborator, Author)

> Is this enough to confirm the SWA functionality?

I think so. It might also be relevant to test SWA with parallel sequences (I think this is what using a bigger -b (and -ub?) than -c does with llama-perplexity).

@hackey commented on Jul 23, 2024

Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?

@compilade (Collaborator, Author) commented on Jul 24, 2024

> Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?

Still waiting on some upstream changes (see https://huggingface.co/mistralai/mamba-codestral-7B-v0.1/discussions/1), but otherwise I'm beginning to investigate the conversion for Mamba2 models, at least to have some GGUFs (even with no inference support) to experiment with implementing it.

First thing I'm noticing is the lack of metadata in the config.json of Mamba2 models. No state size, no convolution kernel size, no time step rank, and in the case of mamba-codestral-7B-v0.1, no indication that it's a Mamba2 model, except from the tensor names and sizes.
For the state sizes, I guess these are hardcoded in the state-spaces/mamba implementation, in which case I'll hardcode them too and/or find what is used to calculate them.

I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.
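To illustrate the kind of unification meant here (not part of this PR, and none of these names exist in llama.cpp; this is a hypothetical sketch): a single abstract I/O interface, implemented once for writing and once for reading, would let the whole-context and per-sequence paths share one traversal of the cache.

```cpp
#include <cstddef>

// Hypothetical sketch (none of these names exist in llama.cpp): one abstract
// I/O interface implemented twice, once for writing and once for reading,
// so that saving and restoring share a single traversal of the cache state.
struct llama_session_io {
    virtual ~llama_session_io() = default;

    // write or read `size` bytes, depending on the concrete implementation
    virtual void raw(void * data, size_t size) = 0;

    template <typename T> void value(T & v) { raw(&v, sizeof(v)); }
};

// The KV-cache traversal would then be written once against this interface
// and reused by both the whole-context and the per-sequence save/restore
// paths, instead of keeping 4 separate code paths in sync.
```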

@hackey commented on Jul 24, 2024

> Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?

> Still waiting on some upstream changes (see https://huggingface.co/mistralai/mamba-codestral-7B-v0.1/discussions/1), but otherwise I'm beginning to investigate the conversion for Mamba2 models, at least to have some GGUFs (even with no inference support) to experiment with implementing it.

> First thing I'm noticing is the lack of metadata in the config.json of Mamba2 models. No state size, no convolution kernel size, no time step rank, and in the case of mamba-codestral-7B-v0.1, no indication that it's a Mamba2 model, except from the tensor names and sizes. For the state sizes, I guess these are hardcoded in the state-spaces/mamba implementation, in which case I'll hardcode them too and/or find what is used to calculate them.

> I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.

I also encountered difficulties running mamba-codestral. I tried to run the model with https://github.com/state-spaces/mamba, but there is no config.json in the model repository, and mamba-codestral includes a new tokenizer (v3). Although Mistral writes that the model can be run with state-spaces/mamba, nothing worked for me.

Please see the discussion here:
NVIDIA/TensorRT-LLM#1968
A few hours ago an example for running Mamba appeared:
https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mamba

Maybe this will help development.

@ggerganov (Owner)

> For the state sizes, I guess these are hardcoded in the state-spaces/mamba implementation, in which case I'll hardcode them too and/or find what is used to calculate them.

Yes, we can hardcode initially

> I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.

Sounds good - a separate PR would be easier to review

Regarding Codestral: I want to highlight again the comment by the Mistral team about ngroups = 8: #8519 (comment). It seems important.

@awgr commented on Jul 29, 2024

> Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?

> Still waiting on some upstream changes (see https://huggingface.co/mistralai/mamba-codestral-7B-v0.1/discussions/1), but otherwise I'm beginning to investigate the conversion for Mamba2 models, at least to have some GGUFs (even with no inference support) to experiment with implementing it.

> First thing I'm noticing is the lack of metadata in the config.json of Mamba2 models. No state size, no convolution kernel size, no time step rank, and in the case of mamba-codestral-7B-v0.1, no indication that it's a Mamba2 model, except from the tensor names and sizes. For the state sizes, I guess these are hardcoded in the state-spaces/mamba implementation, in which case I'll hardcode them too and/or find what is used to calculate them.

> I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.

https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py

This includes some details that may be interesting to you.

* llama : rename llama_reorder_outputs to llama_output_reorder

  Also move it closer to llama_output_reserve.

* llama : fix pooled embeddings when using batches with equal_seqs
compilade marked this pull request as ready for review on August 8, 2024 01:20
@ggerganov (Owner) left a comment

In the future we should refactor the KV cache using object-oriented design so that the implementation of non-recurrent, recurrent and other modes are better separated and easier to read.
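A minimal sketch of what such a separation could look like. This is purely illustrative; the class and method names below are hypothetical and are not existing llama.cpp symbols.

```cpp
// Hypothetical interface (not existing llama.cpp symbols): decode and graph
// building code would only see the abstract cache, while cell bookkeeping
// differs between the implementations.
struct llama_cache_i {
    virtual ~llama_cache_i() = default;

    // reserve space for the tokens of a ubatch; returns false if it does not fit
    virtual bool find_slot(const llama_ubatch & ubatch) = 0;
    virtual void seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
    virtual void clear() = 0;
};

struct llama_cache_attn      : llama_cache_i { /* per-token KV cells     */ };
struct llama_cache_recurrent : llama_cache_i { /* contiguous state slots */ };
```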

@ggerganov (Owner) left an inline review comment

In a follow-up PR we can move the batch structs into llama-batch.h/.cpp and write some unit tests
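For example, such a unit test could look roughly like the sketch below. It assumes the batch structs end up exposed through a llama-batch.h header as suggested, and it guesses at member names like equal_seqs, n_seqs and n_seq_tokens, which may not match the final code.

```cpp
#include <cassert>

#include "llama.h"
#include "llama-batch.h" // hypothetical header from the suggested split

// helper: append one token belonging to a single sequence
static void add_token(llama_batch & batch, llama_token id, llama_pos pos, llama_seq_id seq_id) {
    const int i = batch.n_tokens;
    batch.token   [i]    = id;
    batch.pos     [i]    = pos;
    batch.n_seq_id[i]    = 1;
    batch.seq_id  [i][0] = seq_id;
    batch.logits  [i]    = false;
    batch.n_tokens++;
}

int main() {
    // two sequences of different lengths: 3 tokens in seq 0, 5 tokens in seq 1
    llama_batch batch = llama_batch_init(/*n_tokens=*/8, /*embd=*/0, /*n_seq_max=*/1);
    for (int i = 0; i < 3; ++i) add_token(batch, 1, i, 0);
    for (int i = 0; i < 5; ++i) add_token(batch, 1, i, 1);

    llama_sbatch sbatch;
    sbatch.from_batch(batch, /*n_embd=*/0, /*simple_split=*/false, /*logits_all=*/false);

    // split_equal must give every sequence in the ubatch the same number of new tokens
    llama_ubatch ubatch = sbatch.split_equal(/*n_ubatch=*/8);
    assert(ubatch.equal_seqs);
    assert(ubatch.n_seqs * ubatch.n_seq_tokens == ubatch.n_tokens);

    llama_batch_free(batch);
    return 0;
}
```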

@compilade (Collaborator, Author) commented on Aug 19, 2024

I'll be re-running a few tests before merging this in hopefully less than 2 days. There are now both Mamba-2 and RWKV-v6, which kind of need this to simplify their implementations.

Still, I don't want to accidentally have broken something with the batch splits, so I'll try to convince myself that there is no problem by running more tests.

@compilade (Collaborator, Author) commented on Aug 21, 2024

I've run some tests, and there's a problem: pooled embeddings with Mamba can't work with multiple sequences anymore.

This is because lctx.embd_seq is overwritten at each ubatch, which means it only works if everything fits in a single ubatch. That is not the case when sequences don't have the same length and are split to make them all equal in Mamba's ubatches.

This could be fixed by letting causal embeddings be split over multiple ubatch. I'll try to find a way to do this cleanly.
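To make the failure mode concrete, here is a small standalone illustration of the overwrite described above. It is illustrative only, not the actual llama.cpp code; embd_seq below merely stands in for lctx.embd_seq.

```cpp
#include <cstdio>
#include <map>
#include <vector>

// stand-in for lctx.embd_seq: one pooled embedding per sequence id
static std::map<int, std::vector<float>> embd_seq;

// stand-in for processing one ubatch: the map is rebuilt from scratch
static void process_ubatch(const std::vector<int> & seq_ids) {
    embd_seq.clear(); // <-- wipes the results of previous ubatches of the same batch
    for (int s : seq_ids) {
        embd_seq[s] = std::vector<float>(4, 0.0f); // placeholder pooled embedding
    }
}

int main() {
    // with split_equal, sequences of different lengths land in different ubatches
    process_ubatch({0}); // first ubatch contains sequence 0
    process_ubatch({1}); // second ubatch contains sequence 1

    // only sequence 1 survives; sequence 0's pooled embedding was dropped
    printf("sequences with embeddings: %zu\n", embd_seq.size()); // prints 1, not 2
    return 0;
}
```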

Where the checkbox is checked, it means the behavior is the same as on master or better.

  • perplexity
    • v0-mamba-100k
      • 1 chunk per batch
      • 4 chunks per batch
      • 4 batches per chunk
      • 4 ubatches per batch
    • v0-llama2-100k
      • 1 chunk per batch
      • 4 chunks per batch
      • 4 batches per chunk
      • 4 ubatches per batch
  • llama-embedding
    • t5-small
      • fails with GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor") as on master.
    • mamba-130M
      • ❌ Does NOT work with more than one sequence anymore
    • bge-small
  • llama-parallel -c 1024 -np 5 -ns 7 --seed 42 --temp 1
    • t5-small
      • fails to decode, as on master (fixed segfault in 652e9b0)
    • stories-MoE
      • works with -c 1024, but segfaults otherwise (as on master)
    • mamba-130M
      • works without problem
  • perplexity --hellaswag (parallel sequences of uneven length, also the only test with batches having more than one seq_id per token)
    • v0-mamba-100k
    • v0-llama2-100k
  • save load state
    • Mamba-130M
      • works (unlike on master), but needs -np 2 for the sequence load test.
    • v0-llama2-100k
    • t5-small
      • fails, as on master
  • quantization (because an assertion was changed over there)
    • OpenELM-270M
    • t5-small
    • Mamba-370M
  • Gemma2 sliding window
    • Gemma2-2B-it perplexity with -c 5120 (its sliding window is 4096), first chunk generates the same perplexity.
    • parallel sliding windows (on more than one sequence)
      • not sure how to test that

* llama : fix Mamba pooled embeddings with multiple sequences

  Until the pooled embeddings are refactored to allow splitting
  across ubatches for causal embeddings,
  recurrent models can only process a single sequence per ubatch
  when calculating pooled embeddings.

* llama : add llama_model_is_recurrent to simplify figuring that out

  This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
@compilade (Collaborator, Author)

I've fixed the pooled embeddings problem with Mamba in b264edd by making it only process a single sequence per ubatch. When the sequences are short, this is slightly slower than processing them all at once, unfortunately.

In the future, the pooled embeddings will be refactored to allow causal embeddings to be split across ubatches. It should also be possible to remove inp_cls, because it's redundant with inp_out_ids. LLAMA_POOLING_TYPE_CLS and LLAMA_POOLING_TYPE_LAST could be handled directly when splitting batches, since they only affect which tokens get their output selected. LLAMA_POOLING_TYPE_MEAN will be a bit harder to allow splitting, but since the total number of tokens per sequence per batch is known in advance, there might still be a way.
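As a rough sketch of that idea (hypothetical helper, not existing llama.cpp code): for CLS- and LAST-style pooling, the splitter only needs to flag the right token of each sequence as an output, which is the same information inp_out_ids carries.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: decide per-sequence output tokens while splitting,
// instead of building a separate inp_cls tensor later.
enum class pooling_sel { none, cls, last };

// `output` has one flag per token of the ubatch; `seq_start`/`seq_len`
// describe where one sequence's tokens sit inside it.
static void mark_seq_output(std::vector<int8_t> & output, size_t seq_start, size_t seq_len, pooling_sel p) {
    switch (p) {
        case pooling_sel::cls:  output[seq_start]               = 1; break; // first token of the sequence
        case pooling_sel::last: output[seq_start + seq_len - 1] = 1; break; // last token of the sequence
        case pooling_sel::none: /* outputs are whatever the user requested */ break;
    }
}
```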

I'm postponing that pooled embeddings refactor to another PR. I consider this ready.

compilade mentioned this pull request on Aug 21, 2024
compilade added the merge ready label on Aug 21, 2024
@ggerganov (Owner) left a comment

Great work as always!

Let's merge and resolve any remaining issues from master. I'll follow up with the initial SSM Metal kernels shortly after (#8546)

compilade merged commit a1631e5 into master on Aug 21, 2024
53 checks passed
@awgr commented on Aug 22, 2024 via email

@mann1x commented on Aug 24, 2024

@compilade I get this error when quantizing deepseek2 since the merge of this PR: #9155

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Aug 27, 2024
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
* llama : advanced batch splits

This includes equal-sequence-length batch splits which are useful
to simplify recurrent model operators.

* llama : always make recurrent state slots contiguous

* ggml : simplify mamba operators

* llama : fix integer signedness mixing

* llama : logits_all has priority over batch->logits

Otherwise, the server embeddings tests failed.
This was likely an existing problem but was only detected here
because of an additional assertion.

* llama : apply suggestions

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : fix t5 segfault

* llama : fix Mamba session save and restore

* llama : minor cosmetic changes

* llama : rename llama_reorder_outputs to llama_output_reorder

Also move it closer to llama_output_reserve.

* llama : fix pooled embeddings when using batches with equal_seqs

* minor : add struct members for clarity

ggml-ci

* llama : fix T5 segfault again

* llama : fix Mamba pooled embeddings with multiple sequences

Until the pooled embeddings are refactored to allow splitting
across ubatches for causal embeddings,
recurrent models can only process a single sequence per ubatch
when calculating pooled embeddings.

* llama : add llama_model_is_recurrent to simplify figuring that out

This will make it easier to more cleanly support RWKV-v6 and Mamba-2.

* llama : fix simple splits when the batch contains embeddings

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024