llama : simplify Mamba with advanced batch splits #8526
Conversation
This includes equal-sequence-length batch splits which are useful to simplify recurrent model operators.
* llama : always make recurrent state slots contiguous
* ggml : simplify mamba operators
Otherwise, the server embeddings tests failed. This was likely an existing problem but was only detected here because of an additional assertion.
Tested t5-small and it currently segfaults - let me know if you need help with resolving it.
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
The following command produces identical perplexity on `master` and on this branch:

./llama-perplexity \
    -m models/gemma-2-9b/ggml-model-f16.gguf \
    -f build/wikitext-2-raw/wiki.test.raw \
    -ngl 99 -c 8192

Is this enough to confirm the SWA functionality?
I think so. Might also be relevant to test SWA with parallel sequences too.
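If it helps, one way to exercise SWA with parallel sequences could be the `llama-parallel` example; the invocation below is only a rough, untested suggestion (the `-np`/`-ns` values are arbitrary):

```sh
./llama-parallel \
    -m models/gemma-2-9b/ggml-model-f16.gguf \
    -ngl 99 -c 8192 \
    -np 4 -ns 16
```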
Guys, is there any progress in supporting Mamba-2 (I'm interested in the new mamba-codestral)?
Still waiting on some upstream changes (see https://huggingface.co/mistralai/mamba-codestral-7B-v0.1/discussions/1), but otherwise I'm beginning to investigate the conversion for Mamba-2 models, at least to have some GGUFs (even with no inference support) to experiment with implementing it. First thing I'm noticing is the lack of metadata.

I've also recently started to simplify the session file save & restore code.
I also encountered difficulties running mamba-codestral. I tried to run this model with https://github.com/state-spaces/mamba, but there is no config.json in the model repository. mamba-codestral includes a new tokenizer v3, please see the discussion here:

Maybe this will help development.
Yes, we can hardcode initially
Sounds good - a separate PR would be easier to review.

Regarding Codestral - want to highlight again the comment by the Mistral team.
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py

This includes some details that may be interesting for you.
Also move it closer to llama_output_reserve.
* llama : fix pooled embeddings when using batches with equal_seqs
In the future we should refactor the KV cache using object-oriented design so that the implementations of non-recurrent, recurrent and other modes are better separated and easier to read.
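For illustration, that separation could look roughly like the sketch below; the class and method names here are hypothetical stand-ins, not existing llama.cpp code:

```cpp
// Hypothetical sketch of a KV-cache interface with separate implementations
// for non-recurrent and recurrent models; names are illustrative only.
#include <cstdint>
#include <memory>

struct llama_ubatch_view {          // stand-in for the real llama_ubatch
    uint32_t n_tokens;
    uint32_t n_seqs;
};

class kv_cache_base {
public:
    virtual ~kv_cache_base() = default;
    // reserve cells (or state slots) for the tokens of a ubatch
    virtual bool find_slot(const llama_ubatch_view & ubatch) = 0;
    // remove a range of positions from a sequence
    virtual void seq_rm(int32_t seq_id, int32_t p0, int32_t p1) = 0;
    virtual void clear() = 0;
};

class kv_cache_attention : public kv_cache_base {
    // contiguous K/V tensors per layer, ring-buffer style cell management
    bool find_slot(const llama_ubatch_view &) override { return true; }
    void seq_rm(int32_t, int32_t, int32_t) override {}
    void clear() override {}
};

class kv_cache_recurrent : public kv_cache_base {
    // one state slot per sequence, kept contiguous for the SSM operators
    bool find_slot(const llama_ubatch_view &) override { return true; }
    void seq_rm(int32_t, int32_t, int32_t) override {}
    void clear() override {}
};

// a hybrid model could own one of each and route per-layer
std::unique_ptr<kv_cache_base> make_cache(bool recurrent) {
    if (recurrent) return std::make_unique<kv_cache_recurrent>();
    return std::make_unique<kv_cache_attention>();
}
```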
In a follow-up PR we can move the batch structs into llama-batch.h/.cpp and write some unit tests.
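For reference, a sketch of what such a llama-batch.h could roughly declare, based only on what this PR describes; every field, comment and signature beyond the `split_simple`/`split_equal` names and `llama_ubatch` itself is an assumption for illustration, not the actual layout:

```cpp
// Hypothetical llama-batch.h sketch (illustrative only, not the real header).
#pragma once

#include <cstdint>

// View over part of a batch; when equal_seqs is set, every sequence in the
// ubatch contributes the same number of new tokens.
struct llama_ubatch {
    bool     equal_seqs;   // true for ubatches produced by split_equal
    uint32_t n_tokens;     // total tokens in this ubatch
    uint32_t n_seq_tokens; // tokens per sequence (assumed field)
    uint32_t n_seqs;       // number of sequences (assumed field)
    // token ids, positions, seq ids and output flags would live here
};

// Sequence-aware batch splitter; kept in lctx.sbatch so its buffers can be
// re-used across decode calls instead of being re-allocated per batch.
struct llama_sbatch {
    // take up to n_ubatch tokens in user order (simple split)
    llama_ubatch split_simple(uint32_t n_ubatch);

    // take an equal number of new tokens from each sequence, which keeps
    // recurrent state slots contiguous (equal-sequence-length split)
    llama_ubatch split_equal(uint32_t n_ubatch);

    size_t n_tokens = 0; // tokens left to split (assumed bookkeeping)
};
```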
I'll be re-running a few tests before merging this in hopefully less than 2 days. There are now both Mamba-2 and RWKV v6 which kind of need this to simplify the implementation. Still, I don't want to accidentally have broken something with the batch splits, so I'll try to convince myself that there is no problem by running more tests.
I've run some tests, and there's a problem: pooled embeddings with Mamba can't work with multiple sequences anymore. This could be fixed by letting causal embeddings be split over multiple ubatches. Where the checkbox is checked, it means the behavior is the same as on `master`.
Until the pooled embeddings are refactored to allow splitting across ubatches for causal embeddings, recurrent models can only process a single sequence per ubatch when calculating pooled embeddings.
This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
I've fixed the pooled embeddings problem with Mamba in b264edd by making it only process a single sequence per ubatch in that case.

In the future, the pooled embeddings will be refactored to allow causal embeddings to be split across ubatches. I'm postponing that pooled embeddings refactor to another PR.

I consider this ready.
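Condensed, the restriction from b264edd amounts to something like the following; only `llama_model_is_recurrent()` is the actual API added in this PR, and the surrounding decision logic is a paraphrase rather than the real decode code:

```cpp
#include "llama.h" // for llama_model and llama_model_is_recurrent()

// Simplified illustration, not the actual decode code: decide how many
// sequences a single ubatch may contain when pooled embeddings are requested.
static uint32_t pick_n_seqs_per_ubatch(const llama_model * model,
                                       bool pooled_embeddings,
                                       uint32_t n_seqs_in_batch) {
    if (llama_model_is_recurrent(model) && pooled_embeddings) {
        // until pooled embeddings can be split across ubatches for causal
        // embeddings, a recurrent model processes one sequence at a time
        return 1;
    }
    return n_seqs_in_batch;
}
```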
Great work as always!

Let's merge and resolve any remaining issues from `master`. I'll follow up with the initial SSM Metal kernels shortly after (#8546).
@compilade btw, I have the SSD implementation on CPU, more or less, if it's interesting for you.

On Wed, Aug 21, 2024 at 2:58 PM compilade wrote: Merged #8526 into master.
I get this error quantizing deepseek2 since the merge of this PR:
* llama : advanced batch splits
  This includes equal-sequence-length batch splits which are useful to simplify recurrent model operators.
* llama : always make recurrent state slots contiguous
* ggml : simplify mamba operators
* llama : fix integer signedness mixing
* llama : logits_all has priority over batch->logits
  Otherwise, the server embeddings tests failed. This was likely an existing problem but was only detected here because of an additional assertion.
* llama : apply suggestions
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : fix t5 segfault
* llama : fix Mamba session save and restore
* llama : minor cosmetic changes
* llama : rename llama_reorder_outputs to llama_output_reorder
  Also move it closer to llama_output_reserve.
* llama : fix pooled embeddings when using batches with equal_seqs
* minor : add struct members for clarity
  ggml-ci
* llama : fix T5 segfault again
* llama : fix Mamba pooled embeddings with multiple sequences
  Until the pooled embeddings are refactored to allow splitting across ubatches for causal embeddings, recurrent models can only process a single sequence per ubatch when calculating pooled embeddings.
* llama : add llama_model_is_recurrent to simplify figuring that out
  This will make it easier to more cleanly support RWKV-v6 and Mamba-2.
* llama : fix simple splits when the batch contains embeddings

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
As promised in #7531 (comment), I've been extracting the advanced batch splits out of the Jamba PR (#7531).

I've also backported the contiguous allocation of recurrent state slots, which makes it possible to also include the changes from #7531 which simplify the `ggml` operators used specifically for Mamba. Hopefully this isn't too much at once. See #7531 (comment) for an explanation of the batch splits.

Summary

`ggml.c`
* Simplify `ggml_ssm_conv` and `ggml_ssm_scan` by assuming batched sequences have the same number of new tokens, and that the states are contiguous and ordered correctly.
* Allow `ggml_concat` to work with a non-contiguous second argument.

`llama.cpp`
* Add `lctx.sbatch` for persistent buffers of `llama_sbatch`, which allows avoiding repeated allocations by re-using the same buffers.
  * Simple splits are made with `lctx.sbatch.split_simple(n_tokens)` to build a `llama_ubatch` with a max size of `n_tokens`.
  * Equal-sequence-length splits are made with `lctx.sbatch.split_equal(n_tokens)`, and are used to simplify the operators of recurrent models.
* Add `llama_ubatch`. Similar to `llama_batch`, but aware of equal-length sequences.
* Make `llama_set_inputs` (and others) use `llama_ubatch` instead of `llama_batch`.
* Changes to `llama_kv_cache_find_slot`.
* Add `llm_build_mamba` to build a Mamba block, used for Mamba, and will be used for Jamba.
* Add `llm_build_copy_mask_state` (maybe not a good name) to abstract away the shuffling and masking of recurrent states. Used for Mamba, and it should be usable for other recurrent architectures too.
* Change how `qs.n_attention_wv` is set in `llama_model_quantize_internal` to make it future proof for hybrid models.
* Reorder the outputs when using `split_equal` in conjunction with `llama_get_logits`, because the API makes it so that the outputs should have the same order they had in the user-provided batch, not something based on batch split rules.

For simplicity, this does not include the separation of the KV cache and the recurrent state cache. Both still use the same buffers (`lctx.kv_self.k_l` and `lctx.kv_self.v_l`, as on `master`). The separation (necessary for hybrid models) will be introduced at the same time as Jamba.

TODO
* Test `llama_kv_cache_find_slot` with the `--hellaswag` benchmark in `llama-perplexity` with a Mamba model.
* Test the `tail` metadata for recurrent states (i.e. which cell is the end of which sequence).
* Test `bge-small` with `llama-embeddings` with parallel prompts with `--pooling cls`, `--pooling last` and `--pooling mean`; results exactly match `master`.
* Rename `llama_reorder_outputs` to `llama_output_reorder` and move it close to `llama_output_reserve`.

Future ideas
* Handle `cls` and `last` pooling within the `ubatch.outputs` when splitting a batch; `inp_cls` is redundant with `inp_out_ids`.
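To make the two split modes concrete, here is a small self-contained toy model of the grouping idea; it does not use the real `llama_sbatch`, and the per-sequence token counts and the grouping rule are simplified assumptions:

```cpp
// Toy illustration of "simple" vs "equal-sequence-length" splits.
// This is NOT the real llama_sbatch; it only mirrors the grouping idea.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // new tokens queued per sequence (e.g. 3 sequences of different lengths)
    std::vector<int> n_new = {5, 2, 7};
    const int n_ubatch = 6; // max tokens per ubatch

    // simple split: fill ubatches in user order, ignoring sequence boundaries
    int total = 0;
    for (int n : n_new) total += n;
    printf("simple split: ");
    for (int taken = 0; taken < total; ) {
        int step = std::min(n_ubatch, total - taken);
        printf("[%d tokens] ", step);
        taken += step;
    }
    printf("\n");

    // equal split: every ubatch takes the same number of new tokens from
    // each sequence that still has tokens left, so the per-sequence states
    // can be kept contiguous and processed together
    std::vector<int> left = n_new;
    printf("equal split:  ");
    while (true) {
        std::vector<int> active;
        for (size_t i = 0; i < left.size(); ++i) {
            if (left[i] > 0) active.push_back((int) i);
        }
        if (active.empty()) break;
        int per_seq = n_ubatch / (int) active.size();
        for (int i : active) per_seq = std::min(per_seq, left[i]);
        per_seq = std::max(per_seq, 1);
        printf("[%zu seqs x %d tokens] ", active.size(), per_seq);
        for (int i : active) left[i] -= per_seq;
    }
    printf("\n");
    return 0;
}
```

With the example numbers, the simple split yields ubatches of 6, 6 and 2 tokens, while the equal split yields ubatches of 3x2, 2x3 and 1x2 tokens; each equal-split ubatch is rectangular, which is what lets the recurrent state slots stay contiguous.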