Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : support Mamba Selective State Space Models #5328

Merged
merged 43 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
8cd0a28
mamba : begin working on support for Mamba SSM
compilade Jan 26, 2024
5a69a26
mamba : begin figuring out how to (ab)use the kv cache for Mamba
compilade Jan 27, 2024
f680364
mamba : recurrent inference almost works, but incoherent
compilade Jan 28, 2024
54d3e48
mamba : recurrent inference WORKS!!!
compilade Jan 28, 2024
74eea85
convert : optionally use d_conv and d_state from config.json for Mamba
compilade Jan 29, 2024
9e77061
mamba : refactor recurrent conv, resulting in 20% perf increase
compilade Jan 29, 2024
3f7233b
ggml : parallelize ggml_exp
compilade Jan 29, 2024
e9cc45e
mamba : simplify the conv step with a self-overlapping view
compilade Jan 31, 2024
81b57bb
mamba : fix self-overlapping view depth stride
compilade Jan 31, 2024
ffc116f
mamba : handle batches of more than 1 token
compilade Feb 1, 2024
78a853b
ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation
compilade Feb 2, 2024
5816ae6
mamba : very basic quantization support
compilade Feb 2, 2024
a3f4a1c
mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
compilade Feb 3, 2024
9f55809
convert : for Mamba, also consider the "MambaLMHeadModel" arch name
compilade Feb 4, 2024
cd0f33f
mamba : fix vocab size problems with official models
compilade Feb 4, 2024
de92f15
ggml : remove ggml_exp and ggml_soft_plus
compilade Feb 4, 2024
766db75
mamba : remove some useless comments
compilade Feb 4, 2024
c52fb3c
convert : fix flake8 linter errors
compilade Feb 5, 2024
6ff34da
mamba : apply suggestions from code review
compilade Feb 5, 2024
8a43ffc
mamba : multiple sequences, but one at a time
compilade Feb 14, 2024
e73eaa7
mamba : in comments, properly refer to KV cells instead of slots
compilade Feb 14, 2024
de50c54
mamba : reduce memory usage of ggml_ssm_scan
compilade Feb 18, 2024
9473ec2
mamba : simultaneous sequence processing
compilade Feb 19, 2024
3dcf798
mamba : support llama_kv_cache_seq_cp copy chains
compilade Feb 25, 2024
34e2fca
mamba : make the server and parallel examples work with whole sequences
compilade Feb 25, 2024
79d636c
mamba : dedicate an input tensor for state copy indices
compilade Feb 25, 2024
8f605cf
mamba : adapt perplexity, batched, and batched-bench examples
compilade Feb 27, 2024
206e8ee
mamba : stop abusing attention metadata
compilade Feb 28, 2024
1af1000
mamba : more correctly update the "used" field of the KV cache
compilade Mar 2, 2024
d52dd50
ggml : in ggml_ssm_scan, use a threshold for soft_plus
compilade Mar 3, 2024
b83fbc9
convert : for Mamba, fallback to internal NeoX tokenizer
compilade Mar 3, 2024
eefb794
mamba : support state saving and restoring
compilade Mar 3, 2024
2a99d1b
ggml : implicitly pass src tensors through dst for Mamba-related ops
compilade Mar 4, 2024
93fd4b8
mamba : clarify some comments
compilade Mar 4, 2024
5544f52
Merge branch 'master' into support-mamba-ssm
compilade Mar 5, 2024
916b586
Merge branch 'master' into support-mamba-ssm
compilade Mar 7, 2024
7cd5a1f
server : fix cache_tokens not getting correctly resized
compilade Mar 7, 2024
d8024a4
convert-hf : support new metadata keys for Mamba
compilade Mar 8, 2024
17e4d6c
mamba : rename metadata to be more similar to transformers library
compilade Mar 8, 2024
1c8ea55
mamba : add missing spaces
compilade Mar 8, 2024
d0d32dc
convert-hf : omit output.weight when identical with token_embd.weight
compilade Mar 8, 2024
3e5685f
readme : add Mamba to supported models, and add recent API changes
compilade Mar 8, 2024
39579d3
mamba : move state_seq and state_mask views outside layer loop
compilade Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849

Expand Down Expand Up @@ -110,6 +111,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)

**Multimodal models:**

Expand Down
1 change: 1 addition & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_parallel = params.n_parallel;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
Expand Down
118 changes: 118 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,124 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size

if (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
neox_reader = gguf.GGUFReader(tokenizer_path, "r")

field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)

def write_tensors(self):
block_count = self.hparams["n_layer"]
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

tok_embd = None
tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight"
output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight"

for name, data_torch in self.get_tensors():
old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

if name.endswith(".A_log"):
print("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

# assuming token_embd.weight is seen before output.weight
if tok_embd is not None and new_name == output_name:
if torch.equal(tok_embd, data_torch):
print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
continue
if new_name == tok_embd_name:
tok_embd = data_torch

data = data_torch.squeeze().numpy()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert big float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


Expand Down
13 changes: 8 additions & 5 deletions examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

// ensure enough sequences are available
ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

if (ctx == NULL) {
Expand Down Expand Up @@ -174,10 +177,10 @@ int main(int argc, char ** argv) {

llama_batch_clear(batch);

const int n_tokens = is_pp_shared ? pp : pl*pp;

for (int i = 0; i < n_tokens; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
llama_batch_add(batch, 0, i, { j }, false);
}
}
batch.logits[batch.n_tokens - 1] = true;

Expand All @@ -192,7 +195,7 @@ int main(int argc, char ** argv) {

if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}

Expand Down
3 changes: 2 additions & 1 deletion examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_parallel = n_parallel;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

Expand Down Expand Up @@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (int32_t i = 1; i < n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

if (n_parallel > 1) {
Expand Down
20 changes: 13 additions & 7 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;

// dedicate one sequence to the system prompt
params.n_parallel += 1;

// requests to simulate
const int32_t n_seq = params.n_sequences;

Expand Down Expand Up @@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("\n");
Expand All @@ -221,15 +224,17 @@ int main(int argc, char ** argv) {

client.i_batch = batch.n_tokens;

llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);

client.n_decoded += 1;
}

if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("%s: clearing the KV cache\n", __func__);
Expand All @@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);

for (size_t i = 0; i < tokens_prompt.size(); ++i) {
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
}

// extract the logits only for the last token
Expand Down Expand Up @@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
}

// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);

const auto t_main_end = ggml_time_us();

Expand Down
9 changes: 6 additions & 3 deletions examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

Expand Down Expand Up @@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

Expand Down Expand Up @@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

Expand Down Expand Up @@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;

// ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = std::max(4, params.n_parallel);

// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
Expand Down
Loading
Loading