-
Notifications
You must be signed in to change notification settings - Fork 10k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama : support Mamba Selective State Space Models #5328
Conversation
@compilade ust out of curiosity, is any convolution operation performed? I see some tensors with the name |
@FSSRepo But it turns out that the desired operation in this case is exactly equivalent to making a self-overlapping view which shifts by one column at each stride in the 3rd dimension (which corresponds here to the number of tokens in the batch), and then doing a matrix multiplication with the Not sure if I'm explaining this clearly, because I did not really know anything about convolutions before working on this. (Here are the relevant lines for the "conv" step in my implementation.) I figured this out when thinking about how to process multiple tokens at a time in the "conv" step when starting from how the next |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool!
Turns out that implementing support for a novel model architecture is quite fun (well, at least when it finally works).
Glad to hear 😄
Regarding the KV questions:
IIUC one slot is needed per sequence, so in that sense the KV cache size could be interpreted as the maximum number of distinct sequences that can be processed simultaneously.
GPU kernels likely for future PRs
(What follows are some thoughts about where the number of distinct sequences should be taken from. TL;DR at the end.) For Mamba 3B, each KV slot takes So, let's say I instead take the max number of distinct sequences from the value of But then I have to change how some things are initialized, replacing Another thing regarding TL;DR: I'll try to make Mamba's KV cache size proportional to |
I've been thinking about what parts of the KV cache API can and cannot be supported for Mamba. In general, functions which operate on whole sequences or the whole KV cache can be relatively easily supported. But a lot of KV cache API functions take a range of token positions, and this cannot easily work with Mamba (too many states would need to be kept unnecessarily).
Here, "Partially" means "Only on entire sequences" (all tokens of a sequence, regardless of their position). The most problematic function is I think that most of what is currently done with position ranges (when using This is kind of a blocker for Mamba support in |
Yes, that sounds like the right way to do it
More thoughts on this are welcome |
9c4c257
to
322686e
Compare
Now that multiple sequences can be processed at once, I've been trying to make the
I think I was wrong. Some uses of position ranges do seem necessary. The I'm wondering if it's okay to make But I've been thinking of a way to calculate previous states from more recent ones. From the (2a) equation of the paper, which looks like what is done in Solving for But getting the previous
These are questions I'll ponder during the next month (so probably in another PR), after I make the |
3421d17
to
7b1ff55
Compare
8646535
to
fad8848
Compare
It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions.
This results in 8% faster token generation for Mamba-130M.
Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size. Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time. Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size). * llama : fix Mamba KV self size wrongly displaying as f16 instead of f32 Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway.
This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output. Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models. * ggml: add ggml_ssm_scan to help with parallel selective scan If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though.
This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD.
Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size) Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba. * convert : fix wrong name for layer norm weight of offical Mamba models I was using Q-bert/Mamba-* models before, which have a slighlty different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers")
This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later.
It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation is pretty good
I'm still not convinced we need to introduce n_parallel
and llama_n_max_seq()
. I did some tests using just n_ctx
and things seems to work OK. Only the self attention input buffers (such as KQ_mask
and KQ_pos
) depend on n_ctx
(and now kv_size
), but these are not used for Mamba, so we won't be over-allocating. If in some places we expect the input to not be big bigger than n_ctx
(such as the context shift logic), we can try to fix these (simply disable context shift for Mamba models).
Even if the examples with default arguments are not suitable for Mamba (i.e. n_ctx = 512
), it's not a big problem. As long as it is just a matter of adjusting some of the CLI args, I think it is good
Either way, we can merge it as it is since the API change is quite small
Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case.
Thanks!
Imagine the following case: A user wants to use Mamba 3B to process a prompt with a length of... 1337 tokens. This user is only using a single sequence. Out of habit with how other models work, the user passes Now, the two ways to do this:
Okay that was unfair. Let's say the user is better-informed and passes
I don't really see from where else than Using The same reasoning also applies for examples like I hope this better explains why the context size and the number of sequences were made orthogonal for Mamba.
These checks are also used to avoid overflowing the buffer allocated with Currently, context shifting is faked for recurrent models to let
I agree.
|
Since the I think I should rename the GGUF key-value pairs I added for Mamba to make them more similar to their
This would break existing GGUF-converted Mamba models, though. (EDIT: the above change has been done. If there are any objections, I'd like to know) |
This breaks existing converted-to-GGUF models, but the metadata names are more "standard". mamba : support mamba-*-hf models These models share their token_embd.weight with their output.weight
This is purely a formatting change.
Thanks, I agree now. We should actually start using Feel free to merge this (squash in single commit) when you think it is ready. Maybe add a short notice in the "Recent API changes" section in the README.md to help 3rd party devs and consider updating the GGUF spec with the new keys |
Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly.
There might be a misunderstanding here. To be clear, What caused Unless an example really uses ALL available sequences on any single token in a batch,
Noted. |
A few tensors were also missing `struct` in front of `ggml_tensor`.
* mamba : begin working on support for Mamba SSM * mamba : begin figuring out how to (ab)use the kv cache for Mamba * mamba : recurrent inference almost works, but incoherent * mamba : recurrent inference WORKS!!! * convert : optionally use d_conv and d_state from config.json for Mamba * mamba : refactor recurrent conv, resulting in 20% perf increase It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions. * ggml : parallelize ggml_exp This results in 8% faster token generation for Mamba-130M. * mamba : simplify the conv step with a self-overlapping view Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size. Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time. Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size). * llama : fix Mamba KV self size wrongly displaying as f16 instead of f32 Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway. * mamba : fix self-overlapping view depth stride * mamba : handle batches of more than 1 token This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output. Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models. * ggml: add ggml_ssm_scan to help with parallel selective scan If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though. * ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD. * mamba : very basic quantization support Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size) Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba. * convert : fix wrong name for layer norm weight of offical Mamba models I was using Q-bert/Mamba-* models before, which have a slighlty different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers") * mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later. * convert : for Mamba, also consider the "MambaLMHeadModel" arch name It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json * mamba : fix vocab size problems with official models The perplexity was waaaay to high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata. Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded. * ggml : remove ggml_exp and ggml_soft_plus They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed. * mamba : remove some useless comments No code change. * convert : fix flake8 linter errors * mamba : apply suggestions from code review * mamba : remove unecessary branch for row-wise ssm_state and C multiplication It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof. * ggml : in ggml_ssm_scan, use more appropriate asserts * ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32 * mamba : multiple sequences, but one at a time This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though). The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps) Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step. * mamba : support llama_kv_cache_seq_cp This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index. * mamba : use a state mask It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)). * llama : replace the usage of n_ctx with kv_self.size in many places * mamba : use n_tokens directly instead of n_tok * mamba : in comments, properly refer to KV cells instead of slots * mamba : reduce memory usage of ggml_ssm_scan From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built. * mamba : simultaneous sequence processing A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. * ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row). * mamba : support llama_kv_cache_seq_cp copy chains * mamba : support shifting and dividing the kv cache pos * mamba : make the server and parallel examples work with whole sequences A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not * mamba : dedicate an input tensor for state copy indices This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers. * mamba : adapt perplexity, batched, and batched-bench examples * perplexity : limit the max number of sequences This adapts to what the loaded model can provide. * llama : add llama_n_max_seq to get the upper limit for seq_ids Used by the perplexity example. * batched : pass n_parallel to the model's context params This should have been there already, but it wasn't. * batched-bench : reserve sequences to support Mamba * batched-bench : fix tokens being put in wrong sequences Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions. * mamba : stop abusing attention metadata This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something else than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again) * gguf-py : add new KV metadata key-value pairs for Mamba * llama : add new metadata key-value pairs for Mamba * llama : guard against divisions by zero when n_head is 0 * mamba : rename "unlimited" KV cache property to "recurrent" * mamba : more correctly update the "used" field of the KV cache * ggml : in ggml_ssm_scan, use a threshold for soft_plus This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does. * convert : for Mamba, fallback to internal NeoX tokenizer The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there. * mamba : support state saving and restoring * ggml : implicitly pass src tensors through dst for Mamba-related ops * mamba : clarify some comments * server : fix cache_tokens not getting correctly resized Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case. * convert-hf : support new metadata keys for Mamba For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406 * mamba : rename metadata to be more similar to transformers library This breaks existing converted-to-GGUF models, but the metadata names are more "standard". * mamba : support mamba-*-hf models These models share their token_embd.weight with their output.weight * mamba : add missing spaces This is purely a formatting change. * convert-hf : omit output.weight when identical with token_embd.weight Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly. * readme : add Mamba to supported models, and add recent API changes * mamba : move state_seq and state_mask views outside layer loop A few tensors were also missing `struct` in front of `ggml_tensor`.
* mamba : begin working on support for Mamba SSM * mamba : begin figuring out how to (ab)use the kv cache for Mamba * mamba : recurrent inference almost works, but incoherent * mamba : recurrent inference WORKS!!! * convert : optionally use d_conv and d_state from config.json for Mamba * mamba : refactor recurrent conv, resulting in 20% perf increase It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions. * ggml : parallelize ggml_exp This results in 8% faster token generation for Mamba-130M. * mamba : simplify the conv step with a self-overlapping view Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size. Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time. Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size). * llama : fix Mamba KV self size wrongly displaying as f16 instead of f32 Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway. * mamba : fix self-overlapping view depth stride * mamba : handle batches of more than 1 token This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output. Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models. * ggml: add ggml_ssm_scan to help with parallel selective scan If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though. * ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD. * mamba : very basic quantization support Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size) Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba. * convert : fix wrong name for layer norm weight of offical Mamba models I was using Q-bert/Mamba-* models before, which have a slighlty different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers") * mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later. * convert : for Mamba, also consider the "MambaLMHeadModel" arch name It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json * mamba : fix vocab size problems with official models The perplexity was waaaay to high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata. Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded. * ggml : remove ggml_exp and ggml_soft_plus They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed. * mamba : remove some useless comments No code change. * convert : fix flake8 linter errors * mamba : apply suggestions from code review * mamba : remove unecessary branch for row-wise ssm_state and C multiplication It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof. * ggml : in ggml_ssm_scan, use more appropriate asserts * ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32 * mamba : multiple sequences, but one at a time This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though). The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps) Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step. * mamba : support llama_kv_cache_seq_cp This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index. * mamba : use a state mask It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)). * llama : replace the usage of n_ctx with kv_self.size in many places * mamba : use n_tokens directly instead of n_tok * mamba : in comments, properly refer to KV cells instead of slots * mamba : reduce memory usage of ggml_ssm_scan From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built. * mamba : simultaneous sequence processing A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. * ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row). * mamba : support llama_kv_cache_seq_cp copy chains * mamba : support shifting and dividing the kv cache pos * mamba : make the server and parallel examples work with whole sequences A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not * mamba : dedicate an input tensor for state copy indices This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers. * mamba : adapt perplexity, batched, and batched-bench examples * perplexity : limit the max number of sequences This adapts to what the loaded model can provide. * llama : add llama_n_max_seq to get the upper limit for seq_ids Used by the perplexity example. * batched : pass n_parallel to the model's context params This should have been there already, but it wasn't. * batched-bench : reserve sequences to support Mamba * batched-bench : fix tokens being put in wrong sequences Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions. * mamba : stop abusing attention metadata This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something else than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again) * gguf-py : add new KV metadata key-value pairs for Mamba * llama : add new metadata key-value pairs for Mamba * llama : guard against divisions by zero when n_head is 0 * mamba : rename "unlimited" KV cache property to "recurrent" * mamba : more correctly update the "used" field of the KV cache * ggml : in ggml_ssm_scan, use a threshold for soft_plus This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does. * convert : for Mamba, fallback to internal NeoX tokenizer The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there. * mamba : support state saving and restoring * ggml : implicitly pass src tensors through dst for Mamba-related ops * mamba : clarify some comments * server : fix cache_tokens not getting correctly resized Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case. * convert-hf : support new metadata keys for Mamba For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406 * mamba : rename metadata to be more similar to transformers library This breaks existing converted-to-GGUF models, but the metadata names are more "standard". * mamba : support mamba-*-hf models These models share their token_embd.weight with their output.weight * mamba : add missing spaces This is purely a formatting change. * convert-hf : omit output.weight when identical with token_embd.weight Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly. * readme : add Mamba to supported models, and add recent API changes * mamba : move state_seq and state_mask views outside layer loop A few tensors were also missing `struct` in front of `ggml_tensor`.
* mamba : begin working on support for Mamba SSM * mamba : begin figuring out how to (ab)use the kv cache for Mamba * mamba : recurrent inference almost works, but incoherent * mamba : recurrent inference WORKS!!! * convert : optionally use d_conv and d_state from config.json for Mamba * mamba : refactor recurrent conv, resulting in 20% perf increase It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions. * ggml : parallelize ggml_exp This results in 8% faster token generation for Mamba-130M. * mamba : simplify the conv step with a self-overlapping view Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size. Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time. Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size). * llama : fix Mamba KV self size wrongly displaying as f16 instead of f32 Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway. * mamba : fix self-overlapping view depth stride * mamba : handle batches of more than 1 token This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output. Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models. * ggml: add ggml_ssm_scan to help with parallel selective scan If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though. * ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD. * mamba : very basic quantization support Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size) Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba. * convert : fix wrong name for layer norm weight of offical Mamba models I was using Q-bert/Mamba-* models before, which have a slighlty different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers") * mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later. * convert : for Mamba, also consider the "MambaLMHeadModel" arch name It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json * mamba : fix vocab size problems with official models The perplexity was waaaay to high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata. Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded. * ggml : remove ggml_exp and ggml_soft_plus They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed. * mamba : remove some useless comments No code change. * convert : fix flake8 linter errors * mamba : apply suggestions from code review * mamba : remove unecessary branch for row-wise ssm_state and C multiplication It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof. * ggml : in ggml_ssm_scan, use more appropriate asserts * ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32 * mamba : multiple sequences, but one at a time This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though). The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps) Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step. * mamba : support llama_kv_cache_seq_cp This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index. * mamba : use a state mask It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)). * llama : replace the usage of n_ctx with kv_self.size in many places * mamba : use n_tokens directly instead of n_tok * mamba : in comments, properly refer to KV cells instead of slots * mamba : reduce memory usage of ggml_ssm_scan From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built. * mamba : simultaneous sequence processing A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. * ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row). * mamba : support llama_kv_cache_seq_cp copy chains * mamba : support shifting and dividing the kv cache pos * mamba : make the server and parallel examples work with whole sequences A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not * mamba : dedicate an input tensor for state copy indices This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers. * mamba : adapt perplexity, batched, and batched-bench examples * perplexity : limit the max number of sequences This adapts to what the loaded model can provide. * llama : add llama_n_max_seq to get the upper limit for seq_ids Used by the perplexity example. * batched : pass n_parallel to the model's context params This should have been there already, but it wasn't. * batched-bench : reserve sequences to support Mamba * batched-bench : fix tokens being put in wrong sequences Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions. * mamba : stop abusing attention metadata This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something else than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again) * gguf-py : add new KV metadata key-value pairs for Mamba * llama : add new metadata key-value pairs for Mamba * llama : guard against divisions by zero when n_head is 0 * mamba : rename "unlimited" KV cache property to "recurrent" * mamba : more correctly update the "used" field of the KV cache * ggml : in ggml_ssm_scan, use a threshold for soft_plus This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does. * convert : for Mamba, fallback to internal NeoX tokenizer The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there. * mamba : support state saving and restoring * ggml : implicitly pass src tensors through dst for Mamba-related ops * mamba : clarify some comments * server : fix cache_tokens not getting correctly resized Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case. * convert-hf : support new metadata keys for Mamba For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406 * mamba : rename metadata to be more similar to transformers library This breaks existing converted-to-GGUF models, but the metadata names are more "standard". * mamba : support mamba-*-hf models These models share their token_embd.weight with their output.weight * mamba : add missing spaces This is purely a formatting change. * convert-hf : omit output.weight when identical with token_embd.weight Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly. * readme : add Mamba to supported models, and add recent API changes * mamba : move state_seq and state_mask views outside layer loop A few tensors were also missing `struct` in front of `ggml_tensor`.
* mamba : begin working on support for Mamba SSM * mamba : begin figuring out how to (ab)use the kv cache for Mamba * mamba : recurrent inference almost works, but incoherent * mamba : recurrent inference WORKS!!! * convert : optionally use d_conv and d_state from config.json for Mamba * mamba : refactor recurrent conv, resulting in 20% perf increase It's still slower than I'd like, but I did not really optimize `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors with more than 2 dimensions. * ggml : parallelize ggml_exp This results in 8% faster token generation for Mamba-130M. * mamba : simplify the conv step with a self-overlapping view Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_value_length field is tied to the conv_state size. Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time. Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size). * llama : fix Mamba KV self size wrongly displaying as f16 instead of f32 Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway. * mamba : fix self-overlapping view depth stride * mamba : handle batches of more than 1 token This means running Mamba no longer crashes when using the default settings! And probably also slightly faster prompt processing. Both batched and non-batched processing yield the same output. Previously, the state was not cleared when starting a sequence. Next step is to make the KV cache API work as expected for Mamba models. * ggml: add ggml_ssm_scan to help with parallel selective scan If the selective scan was implemented without a custom operator, there would be waaay too many nodes in the graph. For example, for Mamba-130M, with a batch size of 512 (the default), a naive selective scan could add at least 24*512=12288 nodes, which is more than LLAMA_MAX_NODES (8192), and that's only for the smallest Mamba model. So it's much cleaner with a custom operator. Not sure about the name, though. * ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation This will help with performance on CPU if ggml_vec_mul_f32 and ggml_vec_add_f32 are ever optimized with SIMD. * mamba : very basic quantization support Mostly works, but there is currently no difference between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same). Most of the SSM-specific weights can be kept in f32 without affecting the size that much, since they are relatively small. (the linear projection weights are responsible for most of Mamba's size) Too much quantization seems to make the state degrade quite fast, and the model begins to output gibberish. It seems to affect bigger models to a lesser extent than small models, but I'm not sure by how much. Experimentation will be needed to figure out which weights are more important for the _M (and _L?) variants of k-quants for Mamba. * convert : fix wrong name for layer norm weight of offical Mamba models I was using Q-bert/Mamba-* models before, which have a slighlty different naming scheme for the weights. (they start with "model.layers" instead of "backbone.layers") * mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator This increases performance on CPU by around 30% for prompt processing, and by around 20% for text generation. However, it also makes the ggml_exp and ggml_soft_plus operators unused. Whether or not they should be kept will be decided later. * convert : for Mamba, also consider the "MambaLMHeadModel" arch name It's the name of the class of the official implementation, though they don't use it (yet) in the "architectures" field of config.json * mamba : fix vocab size problems with official models The perplexity was waaaay to high for models with a non-round vocab size. Not sure why, but it needed to be fixed in the metadata. Note that this breaks existing GGUF-converted Mamba models, but **only if** the vocab size was not already rounded. * ggml : remove ggml_exp and ggml_soft_plus They did not exist anyway outside of this branch, and since ggml_ssm_scan fused operations together, they are unused. It's always possible to bring them back if needed. * mamba : remove some useless comments No code change. * convert : fix flake8 linter errors * mamba : apply suggestions from code review * mamba : remove unecessary branch for row-wise ssm_state and C multiplication It was previously done to avoid permuting when only one token is processed at a time (like when generating text), but permuting is cheap, and dynamically changing the compute graph is not future-proof. * ggml : in ggml_ssm_scan, use more appropriate asserts * ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32 * mamba : multiple sequences, but one at a time This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though). The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used. (for this specific case, --parallel 0 helps) Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step. * mamba : support llama_kv_cache_seq_cp This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index. * mamba : use a state mask It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)). * llama : replace the usage of n_ctx with kv_self.size in many places * mamba : use n_tokens directly instead of n_tok * mamba : in comments, properly refer to KV cells instead of slots * mamba : reduce memory usage of ggml_ssm_scan From 290.37 MiB to 140.68 MiB of CPU compute buffer size with Mamba 3B with a batch size of 512. The result tensor of ggml_ssm_scan was previously a big part of the CPU compute buffer size. To make it smaller, it does not contain the intermediate ssm states anymore. Both y and the last ssm state are combined in the result tensor, because it seems only a single tensor can be returned by an operator with the way the graph is built. * mamba : simultaneous sequence processing A batch can now contain tokens from multiple sequences. This is necessary for at least the parallel example, the server example, and the HellaSwag test in the perplexity example. However, for this to be useful, uses of llama_kv_cache_seq_rm/cp will need to be changed to work on whole sequences. * ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba This operator makes it possible to use and update the correct states for each token of the batch in the same way as ggml_ssm_scan. Other solutions which use existing operators would need loops which would add too many nodes to the graph (at least the ones I thought of). Using this operator further reduces the size of the CPU compute buffer from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512. And (at least on CPU), it's a bit faster than before. Note that "ggml_ssm_conv" is probably not the most appropriate name, and it could be changed if a better one is found. * llama : add inp_s_seq as a new input tensor The most convenient implementation to select the correct state (for Mamba) for each token is to directly get the correct index from a tensor. This is why inp_s_seq is storing int32_t and not floats. The other, less convenient way to select the correct state would be to have inp_KQ_mask contain 1.0f for each state used by a token and 0.0f otherwise. This complicates quickly fetching the first used state of a token, and is also less efficient because a whole row of the mask would always need to be read for each token. Using indexes makes it easy to stop searching when there are no more sequences for a token, and the first sequence assigned is always very quickly available (it's the first element of each row). * mamba : support llama_kv_cache_seq_cp copy chains * mamba : support shifting and dividing the kv cache pos * mamba : make the server and parallel examples work with whole sequences A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not * mamba : dedicate an input tensor for state copy indices This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers. * mamba : adapt perplexity, batched, and batched-bench examples * perplexity : limit the max number of sequences This adapts to what the loaded model can provide. * llama : add llama_n_max_seq to get the upper limit for seq_ids Used by the perplexity example. * batched : pass n_parallel to the model's context params This should have been there already, but it wasn't. * batched-bench : reserve sequences to support Mamba * batched-bench : fix tokens being put in wrong sequences Generation quality isn't what's measured in there anyway, but at least using the correct sequences avoids using non-consecutive token positions. * mamba : stop abusing attention metadata This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future. (e.g. using something else than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again) * gguf-py : add new KV metadata key-value pairs for Mamba * llama : add new metadata key-value pairs for Mamba * llama : guard against divisions by zero when n_head is 0 * mamba : rename "unlimited" KV cache property to "recurrent" * mamba : more correctly update the "used" field of the KV cache * ggml : in ggml_ssm_scan, use a threshold for soft_plus This is how the official Mamba implementation does it, and it's also what torch.nn.Softplus does. * convert : for Mamba, fallback to internal NeoX tokenizer The resulting models are exactly the same as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there. * mamba : support state saving and restoring * ggml : implicitly pass src tensors through dst for Mamba-related ops * mamba : clarify some comments * server : fix cache_tokens not getting correctly resized Otherwise, when the "we have to evaluate at least 1 token" special case was triggered, an extra token was kept in cache_tokens even if it was removed from the KV cache. For Mamba, this caused useless prompt reprocessing when the previous request triggered the above case. * convert-hf : support new metadata keys for Mamba For the models available at https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406 * mamba : rename metadata to be more similar to transformers library This breaks existing converted-to-GGUF models, but the metadata names are more "standard". * mamba : support mamba-*-hf models These models share their token_embd.weight with their output.weight * mamba : add missing spaces This is purely a formatting change. * convert-hf : omit output.weight when identical with token_embd.weight Only for Mamba for now, but it might be relevant for other models eventually. Most Mamba models actually share these two tensors, albeit implicitly. * readme : add Mamba to supported models, and add recent API changes * mamba : move state_seq and state_mask views outside layer loop A few tensors were also missing `struct` in front of `ggml_tensor`.
Note
Some changes made between this was opened and when this was merged required re-converting previously-converted GGUF Mamba models.
transformers
library 17e4d6cThis should fix #4353
Implementing Mamba in
llama.cpp
took more time than I thought. But it's here! See the TODO section below for a glimpse of some of the challenges. CPU-only for now.I started working on this as an experiment and because I wanted to try Mamba models with
llama.cpp
(also, there have been quite a few finetunes already).Turns out that implementing support for a novel model architecture is quite fun (well, at least when it finally works).
The most powerful machine on which I try LLMs is a low-power laptop with 8GB of ram and an Intel CPU (no discrete GPU), so I can't try Mamba-3B in its full
f32
glory (the full weights take 11GB), but at least now it's possible to use it quantized.Constant memory usage is a big advantage of Mamba models, but this also means that previous states are not all kept in memory (at least in the current implementation, only the last one is kept), which means there might be more prompt re-processing than necessary in the
server
example, especially if your client trims the end of the output (it's also problematic that the stop token(s) are not included in the server's responses). Themain
example has no such problem.Currently, the initial text generation speed for Mamba is a bit slower than for Transformer-based models (with empty context), but unlike them, Mamba's speed does not degrade with the amount of tokens processed.
Also note that quantization may make the state unstable (making the output gibberish), but this needs more testing to figure out how much this happens (because I only saw it happen with very small models (130M), and not yet with bigger ones (3B)).
For testing, I recommend converting from https://huggingface.co/state-spaces/mamba-130m-hf since it's small, the
config.json
doesn't require modification, the tokenizer is already next to the model files, and thetoken_embd
weight is shared with the output weight, so the download is smaller.(EDIT: the following paragraph was written before the re-release of the Mamba models with more metadata (see the mamba-hf collection). Converting these should be more straightforward.)
The official models require modifying their
config.json
to add the line"architectures": ["MambaForCausalLM"],
or"architectures": ["MambaLMHeadModel"],
(either should work). The vocab will automatically come fromllama.cpp/models/ggml-vocab-gpt-neox.gguf
as there are no tokenizer files in the official Mamba model directories (at least, for the non-mamba-hf
repositories).Design decisions
I'd like to discuss a few things before this can be merged:
ssm_out
, I could probably have re-usedattn_output
, but then its relationship withssm_in
would have been less obvious. Conversely, it's not really attention, but I still re-used theattn_norm
type for the layer norms.Currently, the metadata forFixed by 709ea7dHEAD_COUNT
,KEY_LENGHT
andVALUE_LENGTH
are used purely for making the KV cache the right size, depending ond_conv
andd_state
sizes (usually 4 and 16, respectively). This is probably wrong, since changing anything about the cache size (like I did in 7016fe5) breaks existing converted-to-GGUF Mamba models.config.json
next to their model weights, and the effective context length is bigger than that anyway. Should I put a huge number likeconfig.json
, the official Mamba models don't have anarchitectures
field, which makes the model type hard to detect. For now, I've resorted to expecting"architectures": ["MambaForCausalLM"]
, in there, since the Q-bert/Mamba-* models are the only ones I've found which have an actual architecture defined in theconfig.json
. Another architecture name which I've come across isMambaLMHeadModel
, but it has not been used inconfig.json
of any Mamba models I've looked for (I might have missed some). It seems like the class name of the official Mamba implementation, and I first saw it in the description of their 3B model trained on SlimPajama.tokenizer.json
and usellama.cpp
'smodels/ggml-vocab-gpt-neox.gguf
, and I did exactly that._S
,_M
and_L
variants of k-quants are the same, because I don't yet know which weights are more (or less) important. This will require experimentation.mamba-130m
atQ4_K
). It would be nice to find a quant mix which alleviates this.ggml_ssm_scan
(and got a 25% perf boost on Mamba-3B compared to not fusing the operations), so I also removed my addition of theggml_exp
andggml_soft_plus
operators, since they are now unused.ggml_ssm_conv
because managing the states of simultaneous sequences was easier that way.TODO
Things that should (probably) be done before merging, from more important to less important:
llama_kv_cache_seq_rm
on parts of sequences to an equivalent way done using whole sequences (required for at least theserver
andparallel
examples)speculative
andlookahead
examples remain unsupported with Mamba models. This will probably be in a separate PR.seq_id
theperplexity
example usesparallel
example, the HellaSwag benchmark in theperplexity
example, and probably also theserver
example)d_conv
,d_inner
,d_state
, anddt_rank
)tokenizer.json
and usemodels/ggml-vocab-gpt-neox.gguf
when converting a Mamba model to GGUFtokenizer.json
andtokenizer_config.json
from GPT-NeoX were in the model directory.Out of scope for this PR
ggml_ssm_conv
andggml_ssm_scan
speculative
andlookahead
examplesf32
. So Q4_K_M takes 5.76 bits per weight with Mamba 2.8Bserver
exampleReferences
config.json
to add"architectures": ["MambaForCausalLM"],
then usepython3 convert-hf-to-gguf.py ../path/to/mamba-130m/
with the options you want (see--help
) and the correct path.token_embd.weight
andoutput.weight
, so the download is a bit smaller than from the other repositoriesconfig.json
and the presence oftokenizer.json
.