diff --git a/build.py b/build.py index 315733e561..4f75a99b5d 100644 --- a/build.py +++ b/build.py @@ -11,7 +11,7 @@ import mlc_llm from mlc_llm import utils -from mlc_llm.relax_model import gpt_bigcode, gpt_neox, llama, moss, rwkv +from mlc_llm.relax_model import gpt_bigcode, gpt_neox, llama, moss, rwkv, mpt def _parse_args(): @@ -57,6 +57,12 @@ def _parse_args(): default=1, help="Whether to use previously pickled IRModule and skip trace.", ) + args.add_argument( + "--use-kv-cache", + action="store_false", + default=True, + help="Forcely replace use_cache hyperparameter in model config", + ) args.add_argument("--debug-dump", action="store_true", default=False) args.add_argument("--debug-load-script", action="store_true", default=False) args.add_argument( @@ -274,6 +280,20 @@ def mod_transform_before_build( "get_metadata", "reset_kv_cache", ] + elif ARGS.model.startswith("mpt-"): + if ARGS.use_kv_cache: + model_names = [ + "decode", + "create_kv_cache", + "softmax_with_temperature", + "get_metadata", + ] + else: + model_names = [ + "decode", + "softmax_with_temperature", + "get_metadata", + ] else: model_names = [ "prefill", @@ -337,6 +357,10 @@ def dump_default_mlc_chat_config(args): config["shift_fill_factor"] = 0.3 config["tokenizer_files"] = utils.get_tokenizer_files(params_path) + # TODO(vchernov): create mechanism which gets default config prepared for specific model and covers this one + if args.model_category == "mpt": + config["temperature"] = 0.0 + dump_path = os.path.join(params_path, "mlc-chat-config.json") with open(dump_path, "w", encoding="utf-8") as outfile: json.dump(config, outfile, indent=4) @@ -407,6 +431,8 @@ def main(): mod, params = moss.get_model(ARGS, config) elif ARGS.model_category == "rwkv": mod, params = rwkv.get_model(ARGS, config) + elif ARGS.model_category == "mpt": + mod, params = mpt.get_model(ARGS, config) else: raise ValueError(f"Model {ARGS.model} not supported") mod = mod_transform_before_build(mod, params, ARGS) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 91b6893b46..924573fa1d 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -295,6 +295,25 @@ Conversation CodeGPT() { return conv; } +Conversation MPT() { + Conversation conv; + conv.name = "mpt"; + conv.system = ""; + conv.roles = {"", ""}; + conv.messages = {}; + conv.separator_style = SeparatorStyle::kSepRoleMsg; + conv.offset = 0; + conv.seps = {"\n"}; + conv.role_msg_sep = ""; + conv.role_empty_sep = ""; + // TODO(mlc-team): add eos to mlc-chat-config + // and remove eos from stop token setting. + conv.stop_tokens = {0}; + conv.stop_str = "<|endoftext|>"; + conv.add_bos = false; + return conv; +} + } // namespace using ConvFactory = Conversation (*)(); @@ -312,6 +331,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"moss", MOSS}, {"LM", VanillaLM}, {"code_gpt", CodeGPT}, + {"mpt", MPT}, }; auto it = factory.find(name); if (it == factory.end()) { diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index b56fd88c68..7e397fd407 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -128,7 +128,7 @@ class LLMChat { friend class LLMChatModule; public: - explicit LLMChat(DLDevice device) : device_(device) {} + explicit LLMChat(DLDevice device) : device_(device), debug_index_(0) {} /*! * \return Text describing runtime stats. @@ -289,8 +289,11 @@ class LLMChat { << "Cannot find env function vm.builtin.attention_kv_cache_array_popn"; fkvcache_array_popn_ = *fkvcache_array_popn; - // Step 4. KV cache creation. 
- kv_cache_ = vm_->GetFunction("create_kv_cache")(); + // Step 4. KV cache creation if need. + auto kv_cache_func = vm_->GetFunction("create_kv_cache"); + if (kv_cache_func.defined()) { + kv_cache_ = kv_cache_func(); + } // Step 5. KV cache reset. reset_kv_cache_func_ = vm_->GetFunction("reset_kv_cache"); @@ -508,6 +511,9 @@ class LLMChat { } std::vector prompt_tokens = this->GetInputTokens(); + if (kv_cache_.empty()) { + full_output_ids_.insert(full_output_ids_.end(), prompt_tokens.begin(), prompt_tokens.end()); + } int64_t token_len = static_cast(prompt_tokens.size()); if (token_len == 0) return; @@ -527,14 +533,18 @@ class LLMChat { } void DecodeStep() { - ICHECK(!output_ids_.empty()); - int32_t last_token = output_ids_.back(); - tvm::runtime::NDArray input_data = GetInputTokenNDArray({last_token}); + std::vector input_tokens; + if (kv_cache_.empty()) { + ICHECK(!full_output_ids_.empty()); + input_tokens = full_output_ids_; + } else { + ICHECK(!output_ids_.empty()); + input_tokens = {output_ids_.back()}; + } auto tstart = std::chrono::high_resolution_clock::now(); - NDArray logits_on_device = this->Forward({last_token}, total_seq_len_ + 1); - total_seq_len_ += 1; + NDArray logits_on_device = this->Forward(input_tokens, ++total_seq_len_); int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); @@ -588,12 +598,7 @@ class LLMChat { auto decoding_end = std::chrono::high_resolution_clock::now(); // print first few logits for eyeballs - std::ostringstream os; - for (int i = 0; i < 10; ++i) { - if (i != 0) os << ", "; - os << static_cast(logits_on_cpu_->data)[i]; - } - LOG(INFO) << "logits[:10] =[" << os.str() << "]"; + PrintNDArray(logits_on_cpu_, 10, "Logits"); double encoding_ms = static_cast((decoding_start - encoding_start).count()) / 1e6; double decoding_ms = static_cast((decoding_end - decoding_start).count()) / 1e6; @@ -602,6 +607,62 @@ class LLMChat { << "decoding-time=" << decoding_ms << "ms."; } + NDArray getArrayToPrint(NDArray array) const { + ICHECK(array->data != nullptr) << "Array data is nullptr"; + // Check that the data on CPU and copy if need + if (array->device.device_type != kDLCPU) { + NDArray array_cpu; + array_cpu = array.CopyTo(DLDevice{kDLCPU, 0}); + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + return array_cpu; + } else { + return array; + } + } + + void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor", bool to_save = false) { + NDArray array_cpu = getArrayToPrint(array); + + size_t ndim = array_cpu->ndim; + int64_t numel = 1; + // Print shape and calculate numel + std::ostringstream os_shape; + for (size_t i = 0; i < ndim; ++i) { + if (i != 0) os_shape << ", "; + numel *= array_cpu->shape[i]; + os_shape << array_cpu->shape[i]; + } + + std::string num_tag = std::to_string(num); + if (num == -1 || num >= numel) { + num = numel; + num_tag = ""; + } + // TODO(vchernov): after test return LOG(INFO) + std::cout << tensor_tag << " shape = [" << os_shape.str() << "]" << std::endl; + // LOG(INFO) << tensor_tag << " shape = [" << os_shape.str() << "]"; + + // Print specified number of values from tensor + std::ostringstream os; + const float* p_data = static_cast(array_cpu->data); + for (int64_t i = 0; i < num; ++i) { + if (i != 0) os << ", "; + os << p_data[i]; + } + // TODO(vchernov): after test return LOG(INFO) + std::cout << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]" << std::endl; + // LOG(INFO) << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]"; + + // 
Save to binary file + if (to_save) { + std::string file_name = "tensor_" + std::to_string(debug_index_++) + ".bin"; + std::cout << tensor_tag << " is saved in " << file_name << std::endl; + std::ofstream fs(file_name, std::ios::out | std::ios::binary | std::ios::app); + fs.write(reinterpret_cast(p_data), 4 * numel); + fs.close(); + } + } + private: picojson::value SerializeConfigToJSONValue() const { picojson::object config; @@ -656,6 +717,9 @@ class LLMChat { if (!stop_triggered_) { output_ids_.push_back(next_token); + if (kv_cache_.empty()) { + full_output_ids_.push_back(next_token); + } appeared_token_ids_.insert(next_token); } @@ -699,10 +763,16 @@ class LLMChat { ret = prefill_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_); } else { // running decode function when prefill is not available - for (int i = 0; i < input_tokens.size(); ++i) { - NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]}); - int64_t pos = cur_pos + i + 1 - input_tokens.size(); - ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_); + if (kv_cache_.empty()){ + // Without kv_cache full sequence of tokens is used + NDArray input_data = this->GetInputTokenNDArray(input_tokens); + ret = decode_func_(input_data, params_); + } else { + for (int i = 0; i < input_tokens.size(); ++i) { + NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]}); + int64_t pos = cur_pos + i + 1 - input_tokens.size(); + ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_); + } } } return Downcast(ret[0]); @@ -763,7 +833,10 @@ class LLMChat { // Clear kv cache void ResetKVCache() { reset_kv_cache_func_(kv_cache_); } - void ProcessSystemPrompts() { this->PrefillStep(/*inp=*/"", /*append_conversation=*/false); } + void ProcessSystemPrompts() { + full_output_ids_.clear(); + this->PrefillStep(/*inp=*/"", /*append_conversation=*/false); + } // Utils static double GetRandomNumber() { @@ -783,6 +856,7 @@ class LLMChat { ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined"; ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; + return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber()); } @@ -816,6 +890,8 @@ class LLMChat { double top_p_{0.95}; // output ids till now (refresh after encoding step) std::vector output_ids_; + // output ids till now (sys and client prompt + generated by decoder) + std::vector full_output_ids_; // appeared token ids till now (refresh after encoding step) std::unordered_set appeared_token_ids_; // output message till now (refresh after encoding step) @@ -866,6 +942,8 @@ class LLMChat { Array kv_cache_; // Temp logits on cpu NDArray logits_on_cpu_{nullptr}; + // Debug index + int32_t debug_index_; }; /*! diff --git a/mlc_llm/dispatch/dispatch_tir_operator.py b/mlc_llm/dispatch/dispatch_tir_operator.py index 93b72256c2..5aafa6d8fd 100644 --- a/mlc_llm/dispatch/dispatch_tir_operator.py +++ b/mlc_llm/dispatch/dispatch_tir_operator.py @@ -19,6 +19,9 @@ def __init__(self, model: str): elif model == "rwkv": lookup = None + elif model == "mpt": + lookup = None + else: raise ValueError(f"Model {model} not supported") self.lookup = lookup diff --git a/mlc_llm/relax_model/__init__.py b/mlc_llm/relax_model/__init__.py index 9ee3d0db52..d50967ec9c 100644 --- a/mlc_llm/relax_model/__init__.py +++ b/mlc_llm/relax_model/__init__.py @@ -1 +1,2 @@ from . 
import llama
+from .mpt import mpt
diff --git a/mlc_llm/relax_model/mpt/README.md b/mlc_llm/relax_model/mpt/README.md
new file mode 100644
index 0000000000..b0ccf83ed1
--- /dev/null
+++ b/mlc_llm/relax_model/mpt/README.md
@@ -0,0 +1,241 @@
+# MPT-7b-instruct
+
+This is a brief description of the mpt-7b-instruct model. It is needed for a correct Relax implementation of the model and for the weights mapping.
+MPT-7b-instruct is a decoder-only, kv_cache-free model that uses flash attention.
+The data type is brain float16 (bfloat16) by default, but numpy (used inside the scripts) and TVM do not support this type. Therefore, to compile an MPT-like model, first convert it with the following script:
+```bash
+python3 bfloat16_to_float16.py
+```
+The converted model is saved in `dist/models/`, in a directory named after the original one with a `-float16` suffix.
+**Note:** After conversion to float16, only the weights and the config are saved. Transfer the other files (such as the tokenizer vocab) from the original directory.
+
+The list of tensor name - tensor size pairs for the original (pytorch) model can be found in the mpt_topology.txt file.
+The original config for the model:
{
  "architectures": [
    "MPTForCausalLM"
  ],
  "attn_config": {
    "alibi": true,
    "alibi_bias_max": 8,
    "attn_impl": "torch",
    "attn_pdrop": 0,
    "attn_type": "multihead_attention",
    "attn_uses_sequence_id": false,
    "clip_qkv": null,
    "prefix_lm": false,
    "qk_ln": false,
    "softmax_scale": null
  },
  "auto_map": {
    "AutoConfig": "configuration_mpt.MPTConfig",
    "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
  },
  "d_model": 4096,
  "emb_pdrop": 0,
  "embedding_fraction": 1.0,
  "expansion_ratio": 4,
  "init_config": {
    "emb_init_std": null,
    "emb_init_uniform_lim": null,
    "fan_mode": "fan_in",
    "init_div_is_residual": true,
    "init_gain": 0,
    "init_nonlinearity": "relu",
    "init_std": 0.02,
    "name": "kaiming_normal_",
    "verbose": 0
  },
  "init_device": "cpu",
  "learned_pos_emb": true,
  "logit_scale": null,
  "max_seq_len": 2048,
  "model_type": "mpt",
  "n_heads": 32,
  "n_layers": 32,
  "no_bias": true,
  "norm_type": "low_precision_layernorm",
  "resid_pdrop": 0,
  "tokenizer_name": "EleutherAI/gpt-neox-20b",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.28.1",
  **"use_cache": false,**
  "verbose": 0,
  "vocab_size": 50432
}
+
+This config wraps the default one (see below).
It should highlight two defaults parameters: +"is_encoder_decoder": false, +"use_cache": false, + +Default config parameters (PretrainedConfig): +"return_dict": True +"output_hidden_states": False +"output_attentions": False +"torchscript": False +"torch_dtype": None +"use_bfloat16": False +"tf_legacy_loss": False +"pruned_heads": {} +"tie_word_embeddings": True + +**"is_encoder_decoder": False** +"is_decoder": False +"cross_attention_hidden_size": None +"add_cross_attention": False +"tie_encoder_decoder": False + +"max_length": 20 +"min_length": 0 +"do_sample": False +"early_stopping": False +"num_beams": 1 +"num_beam_groups": 1 +"diversity_penalty": 0.0 +"temperature": 1.0 +"top_k": 50 +"top_p": 1.0 +"typical_p": 1.0 +"repetition_penalty": 1.0 +"length_penalty": 1.0 +"no_repeat_ngram_size": 0 +"encoder_no_repeat_ngram_size": 0 +"bad_words_ids": None +"num_return_sequences": 1 +"chunk_size_feed_forward": 0 +"output_scores": False +"return_dict_in_generate": False +"forced_bos_token_id": None +"forced_eos_token_id": None +"remove_invalid_values": False +"exponential_decay_length_penalty": None +"suppress_tokens": None +"begin_suppress_tokens": None + +"architectures": None +"finetuning_task": None +"id2label": None +"label2id": None +if self.id2label is not None: + "num_labels": None + id2label = dict((int(key), value) for key, value in id2label.items()) +else: + "num_labels": 2 + +"tokenizer_class": None +"prefix": None +"bos_token_id": None +"pad_token_id": None +"eos_token_id": None +"sep_token_id": None + +"decoder_start_token_id": None + +"task_specific_params": None + +Some parameters from generate() function from transformers: +```python +is_greedy_gen_mode = True +``` + +Start greedy_search method in generate() from transformers: +```python +self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, +) +``` +Where parameters for MPT-7b-instruct: +```python +logits_processor? +stopping_criteria? 
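+# Judging by the refactored greedy_search below, logits_processor resolves to an empty
+# LogitsProcessorList and stopping_criteria to the max_length/max_time criteria built by
+# generate(), so neither of them alters the greedy token choice here.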
+pad_token_id = None +eos_token_id None +output_scores = False +return_dict_in_generate = False +synced_gpus = False +streamer = None +model_kwargs = { + 'output_attentions': False, + 'output_hidden_states': False, + 'use_cache': False, + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0') +} +``` + +Refactored greedy_search method for MPT-7b-instruct: +```python +def greedy_search(...): + # init values + logits_processor = LogitsProcessorList() + stopping_criteria = stopping_criteria # max_length and max_time criteria + pad_token_id = None + eos_token_id = None + eos_token_id_tensor = None + output_scores = False + output_attentions = False + output_hidden_states = False + return_dict_in_generate = False + + # init attention / hidden states / scores tuples + scores = None + decoder_attentions = None + cross_attentions = None + decoder_hidden_states = None + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + while True: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # model_inputs = { + # 'input_ids': tensor([[...]], device='cuda:0'), + # 'attention_mask': tensor([[True, ..., True]], device='cuda:0'), + # 'prefix_mask': None, + # 'sequence_id': None, + # 'past_key_values': None, + # 'use_cache': False} + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution. Due to logits_processor is empty next_tokens_scores = next_token_logits + next_tokens_scores = logits_processor(input_ids, next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + # START model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + + model_kwargs["past"] = None + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + # FINISH self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break +``` \ No newline at end of file diff --git a/mlc_llm/relax_model/mpt/bfloat16_to_float16.py b/mlc_llm/relax_model/mpt/bfloat16_to_float16.py new file mode 100644 index 0000000000..380e7887f0 --- /dev/null +++ b/mlc_llm/relax_model/mpt/bfloat16_to_float16.py @@ -0,0 +1,56 @@ +from pathlib import Path +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import WEIGHTS_NAME, CONFIG_NAME + + +def load_bf16_model(dir_path, tokenizer_name): + model = AutoModelForCausalLM.from_pretrained( + dir_path, + trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + return model, tokenizer + + +def save_fp16_model(dir_path, model, tokenizer, manually=False): + new_name = dir_path.name + "-float16" + out_path = dir_path.parent.joinpath(new_name) + + if manually: + # Manual saving + output_model_file = 
Path.joinpath(out_path, WEIGHTS_NAME) + output_config_file = Path.joinpath(out_path, CONFIG_NAME) + + model_to_save = model.module if hasattr(model, 'module') else model + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(out_path) + else: + # Use transformer API + model.save_pretrained(out_path, from_pt=True) + + +def main(args): + model_root_dir = Path(args.model_path) + + # Load original model (bfloat16) + model, tokenizer = load_bf16_model(model_root_dir, args.tokenizer) + # Convert data type to float 16 + model.to(dtype=torch.float16) + # Save converted model + save_fp16_model(model_root_dir, model, tokenizer) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--model_path', type=str, default="../../../dist/models/mpt-7b-instruct", + help="The path to directory with bfloat16 model") + parser.add_argument('-t', '--tokenizer', type=str, default="EleutherAI/gpt-neox-20b", + help="Tag for transformers to upload correct tokenizer") + + args = parser.parse_args() + main(args) diff --git a/mlc_llm/relax_model/mpt/compare.py b/mlc_llm/relax_model/mpt/compare.py new file mode 100644 index 0000000000..21bdbda2b9 --- /dev/null +++ b/mlc_llm/relax_model/mpt/compare.py @@ -0,0 +1,79 @@ +from pathlib import Path +import argparse + +import torch +import numpy as np + +# std::ofstream fs("tensor.bin", std::ios::out | std::ios::binary | std::ios::app); +# fs.write(reinterpret_cast(&tensor), sizeof tensor); +# fs.close(); + +def save_torch_tensor(t: torch.tensor, path=Path("./orig_input.pt")): + torch.save(t, path) + +def load_torch_tensor(path=Path("./orig_input.pt")): + return torch.load(path) + +def advanced_compare(lft, rht, atol=1e-5, rtol=1e-5): + if len(lft.shape) > 1: + lft = lft.flatten() + if len(rht.shape) > 1: + lft = rht.flatten() + numel = lft.shape[0] + assert numel == rht.shape[0] + counter = 0 + rtols=[rtol] + for i in range(numel): + diff = np.abs(lft[i]-rht[i]) + exp_diff = atol + rtol*np.abs(rht[i]) + if diff > exp_diff: + new_rtol = (diff - atol)/np.abs(rht[i]) + rtols.append(new_rtol) + print("Elements with index", i, " are not the same left:", lft[i], " right:", rht[i]) + counter = counter + 1 + print("Number of diverged values:", counter, " Percent is", 100*float(counter)/numel,"%") + max_rtol = np.max(rtols) + print("Current rtol:", rtol, "Maximum rtol:", max_rtol) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-r', '--rtol', type=float, default=5e-3, + help="Relative tolerance") + parser.add_argument('-a', '--atol', type=float, default=1e-6, + help="Absolute tolerance") + parser.add_argument('-w', '--check_weight', default=False, action="store_true", + help="Compare weights. 
Corresponding files are required") + + args = parser.parse_args() + + check_num = 10 + # Load data from Relax model + np_input = np.fromfile(Path("./relax_input.bin"), dtype="float32") + np_weight = np.fromfile(Path("./relax_weight.bin"), dtype="float32") + print("RELAX INPUT TYPE:", np_input.dtype, "SHAPE:", np_input.shape) + print("RELAX WEIGHT TYPE:", np_weight.dtype, "SHAPE:", np_weight.shape) + + # Load data from original model + orig_input = load_torch_tensor() + orig_weight = load_torch_tensor(Path("./orig_weight.pt")) + + orig_np_input = orig_input.numpy() + orig_np_weight = orig_weight.numpy() + print("ORIG INPUT TYPE:", orig_np_input.dtype, "SHAPE:", orig_np_input.shape) + print("ORIG WEIGHT TYPE:", orig_np_weight.dtype, "SHAPE:", orig_np_weight.shape) + + print("Compare inputs") + print("ORIG INPUT:", orig_np_input[:check_num]) + print("RELAX INPUT:", np_input[:check_num]) + # np.testing.assert_allclose(orig_np_input, np_input, rtol=rtol, atol=atol, verbose=True) + advanced_compare(orig_np_input, np_input, rtol=args.rtol, atol=args.atol) + + if args.check_weight: + print("Compare weights") + orig_np_line = orig_np_weight[0,:] + print("ORIG WEIGHT:", orig_np_line[:check_num]) + print("RELAX WEIGHT:", np_weight[:check_num]) + np.testing.assert_allclose(orig_np_line, np_weight, rtol=args.rtol, atol=args.atol, verbose=True) + +if __name__ == "__main__": + main() diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py new file mode 100644 index 0000000000..08c27c508e --- /dev/null +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -0,0 +1,927 @@ +import math +import warnings +from typing import Optional, Tuple, List, Dict + +import tvm +from tvm import relax, tir, te +from tvm.relax.testing import nn +from tvm.script import relax as R + +from .mpt_config import MPTConfig, attn_config_defaults +from ...utils import load_torch_pname2binname_map +from ..commons import create_metadata_func +from ..modules import ( + Embedding, + LayerNorm, + Linear, + ModuleList, + named_parameters, +) + + +def _cast_if_autocast_enabled(tensor: relax.Expr, dtype="float32"): + # # TODO: how to check device? 
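+    # Note: in this Relax port the target dtype is simply passed in by the caller
+    # (LPLayerNormWOBias forwards the dtype from the model config), so no runtime device
+    # check is performed; the original torch logic is kept below for reference.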
+ # if tensor.device.type == 'cuda': + # dtype = "float16" + # elif tensor.device.type == 'cpu': + # dtype = "bfloat16" + # else: + # raise NotImplementedError() + return nn.emit(relax.op.astype(tensor, dtype)) + +# Low-precision layer norm for mpt-7b-instruct, where are no biases expected +class LPLayerNormWOBias(nn.Module): + def __init__(self, normalized_shape, dtype, eps=1e-05): + self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") + # TODO(vchernov): need to set something to layer_norm, but not use + self.dummy_bias = relax.op.zeros((normalized_shape,), dtype) + self.eps = eps + + self.dtype = dtype + + def forward(self, x): + dtype = self.dtype # TODO: temporal workaround + downcast_x = _cast_if_autocast_enabled(x, dtype) + downcast_weight = _cast_if_autocast_enabled(self.weight, dtype) + return nn.emit(relax.op.nn.layer_norm(downcast_x, downcast_weight, self.dummy_bias, axes=-1, epsilon=self.eps, center=False)) + +NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNormWOBias} + + +def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): + if original_is_causal and num_query_tokens != num_key_tokens: + if num_query_tokens != 1: + raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') + else: + return False + return original_is_causal + + +######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (BEGIN) ########################## + +def scaled_multihead_dot_product_attention( + query: relax.Expr, + key: relax.Expr, + value: relax.Expr, + n_heads: int, + d_model: int, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_value: Optional[Tuple[relax.Expr]]=None, + softmax_scale: Optional[float]=None, + attn_bias: Optional[relax.Expr]=None, + key_padding_mask: Optional[relax.Expr]=None, + is_causal: bool=False, + needs_weights: bool=False, +): + head_dim = d_model // n_heads + dtype = query.struct_info.dtype + + b, s_q, _ = query.struct_info.shape + assert b == 1, "Only support batch size 1 at this moment." 
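+    # q, k and v arrive as (b, seq_len, d_model); the reshapes below split the last
+    # dimension into (n_heads, head_dim) with head_dim = d_model // n_heads.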
+ + q = nn.emit(relax.op.reshape(query, (b, s_q, n_heads, head_dim))) + k = nn.emit(relax.op.reshape(key, (b, -1, n_heads, head_dim))) + v = nn.emit(relax.op.reshape(value, (b, -1, n_heads, head_dim))) + + if past_key_value is not None: + kv_seq_len = all_seq_len_shape.struct_info.values[0] + + kv_shape = k.struct_info.shape + kv_dtype = k.struct_info.dtype + assert kv_shape[0] == 1 # batch size + kv_shape = R.shape( + [kv_shape[0], kv_seq_len, kv_shape[2], kv_shape[3]] + ) + kv_cache_shape = R.shape([kv_seq_len, kv_shape[2], kv_shape[3]]) + + # There is requirement b == 1 used + squeezed_key = nn.emit(relax.op.squeeze(k, axis=0)) + squeezed_value = nn.emit(relax.op.squeeze(v, axis=0)) + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[k_cache, squeezed_key], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[v_cache, squeezed_value], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_dtype)], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_dtype)], + ) + ) + k = nn.emit(relax.op.reshape(k_cache, kv_shape)) + v = nn.emit(relax.op.reshape(v_cache, kv_shape)) + s_k = k.struct_info.shape[1] + if softmax_scale is None: + softmax_scale = 1 / math.sqrt(head_dim) + # TODO(vchernov): matmul(q, k) generates inf when float16 is used. There is workaround + if dtype != "float32": + q = nn.emit(relax.op.astype(q, "float32")) + k = nn.emit(relax.op.astype(k, "float32")) + softmax_scale = relax.op.astype(relax.const(softmax_scale), q.struct_info.dtype) + + q = nn.emit(relax.op.permute_dims(q, [0, 2, 1, 3])) + k = nn.emit(relax.op.permute_dims(k, [0, 2, 1, 3])) + v = nn.emit(relax.op.permute_dims(v, [0, 2, 1, 3])) + + attn_weight = nn.emit(relax.op.matmul(q, relax.op.permute_dims(k, [0, 1, 3, 2])) * softmax_scale) + # TODO(vchernov): attn_bias.shape is None due to it is not calculated in strided_slice with dynamic input + # _, _, s_q_end, s_k_end = attn_bias.struct_info.shape # shape = [1, 32, 1, seq_len] + if attn_bias is not None: + # s_q = 1 for use_cache = True and = seq_len otherwise + # s_k = seq_len always + # TODO(vchernov): _s_q, _s_k can not be calculated due to reason above, but + # Trivial symbolic arithmetic shows that: + # _s_q = 0 always (s_q_end - s_q <= 0) + # _s_k = 0 + # _s_q = relax.op.maximum(0, s_q_end - s_q) + # _s_k = relax.op.maximum(0, s_k_end - s_k) + # TODO(vchernov): due to _s_q = 0 and _s_k = 0 the below slicing can be skipped + # slicing attn_bias[:, :, _s_q:, _s_k:] + # attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) + # TODO(vchernov): matmul(q, k) generates inf when float16 is used. + if dtype != "float32": + attn_bias = nn.emit(relax.op.astype(attn_bias, "float32")) + attn_weight = nn.emit(attn_weight + attn_bias) + min_val = get_type_min_val(q) + if key_padding_mask is not None: + if attn_bias is not None: + warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + key_mask = nn.emit(relax.op.logical_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) + attn_weight = nn.emit(relax.op.masked_fill(attn_weight, key_mask, min_val)) + if is_causal and (not s_q == 1): + # It is the case where is no kv cache, thus s_q == s_k + # s = relax.op.maximum(s_q, s_k) + s = s_q + causal_mask = nn.emit(relax.op.ones((s, s,), dtype="bool")) + causal_mask = nn.emit(relax.op.triu(causal_mask, 1)) + # Due to the case the slicing below can be skipped + # slicing causal_mask[-s_q:, -s_k:] + # s_q_end, s_k_end = causal_mask.struct_info.shape + # causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) + causal_mask = nn.emit(relax.op.broadcast_to(causal_mask, (b, n_heads, s, s))) + attn_weight = nn.emit(relax.op.masked_fill(attn_weight, causal_mask, min_val)) + # TODO(vchernov): matmul(q, k) generates inf when float16 is used. + # There is uncast after workaround with float calculation due to softmax range = [0, 1] + attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) + if dtype != "float32": + attn_weight = nn.emit(relax.op.astype(attn_weight, dtype)) + out = nn.emit(relax.op.matmul(attn_weight, v)) + + out = nn.emit(relax.op.permute_dims(out, [0, 2, 1, 3])) + out = nn.emit(relax.op.reshape(out, (b, tir.const(1, dtype="int64"), tir.const(d_model, dtype="int64")))) + + return out, past_key_value + +######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (END) ########################## + + +def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): + for tensor in tensors: + if tensor.struct_info.dtype not in valid_dtypes: + raise TypeError(f'tensor.dtype={tensor.struct_info.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') + # TODO: check on relax that CUDA is used + # if not tensor.is_cuda: + # raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') + + +def flash_attn_fn( + query, + key, + value, + n_heads, + d_model, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + needs_weights=False, + multiquery=False +): + from ..mha_flash_attn import bert_padding_unpad_input, bert_padding_pad_input, flash_attn_unpadded_func + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) + value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) + past_key_value = (key, value) + batch_size = query.struct_info.shape[0] + seqlen = query.struct_info.shape[1] + key_shape_d1 = key.struct_info.shape[1] + if attn_bias is not None: + _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - seqlen) + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key_shape_d1) + # slicing attn_bias[:, :, _s_q:, _s_k:] + s_q_end = attn_bias.struct_info.shape[2] + s_k_end = attn_bias.struct_info.shape[3] + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) + if attn_bias is not None: + raise NotImplementedError(f'attn_bias not implemented for flash attn.') + if key_padding_mask is None: + key_shape_d0 = key.struct_info.shape[0] + key_padding_mask = nn.emit(relax.op.ones(Tuple(key_shape_d0, key_shape_d1, 1), dtype="bool")) + # slicing key_padding_mask[:, -query.struct_info.shape[1]:] + dim1_length = 
key_padding_mask.struct_info.shape[1] + query_padding_mask = nn.emit(relax.op.strided_slice(key_padding_mask, [1], [dim1_length - seqlen], [dim1_length])) + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) + + qnnz = query_unpad.struct_info.shape[0] + query_unpad = nn.emit(relax.op.reshape( + query_unpad, + (qnnz, n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + + kv_nnz = key_unpad.struct_info.shape[0] + kv_n_heads = 1 if multiquery else n_heads + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding_unpad_input(key, key_padding_mask) + key_unpad = nn.emit(relax.op.reshape( + key_unpad, + (kv_nnz, kv_n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + (value_unpad, _, _, _) = bert_padding_unpad_input(value, key_padding_mask) + value_unpad = nn.emit(relax.op.reshape( + value_unpad, + (kv_nnz, kv_n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + + if multiquery: + key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) + value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) + reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key_shape_d1, is_causal) + output_unpad = flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) + nnz = output_unpad.struct_info.shape[0] + output_unpad = nn.emit(relax.op.reshape( + output_unpad, + (nnz, n_heads*d_model), + )) # (nnz, h, d)) -> (nnz, (h d)) + output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) + return (output, None, past_key_value) + + +######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (BEGIN) ########################## + +# def triton_flash_attn_fn( +# query, +# key, +# value, +# n_heads, +# d_model, +# past_key_value=None, +# softmax_scale=None, +# attn_bias=None, +# key_padding_mask=None, +# is_causal=False, +# needs_weights=False, +# multiquery=False): +# try: +# from .flash_attn_triton import flash_attn_func +# except: +# _installed = False +# if version.parse(torch.__version__) < version.parse('2.0.0'): +# _installed = True +# try: +# from flash_attn.flash_attn_triton import flash_attn_func +# except: +# _installed = False +# if not _installed: +# raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). 
Note: (1) requires you have CMake and PyTorch already installed.') +# check_valid_inputs(query, key, value) +# if past_key_value is not None: +# if len(past_key_value) != 0: +# key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) +# value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) +# past_key_value = (key, value) +# if attn_bias is not None: +# _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) +# _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) +# # slicing attn_bias[:, :, _s_q:, _s_k:] +# s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] +# attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) +# if needs_weights: +# raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') +# if key_padding_mask is not None: +# warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') +# (b_size, s_k) = key_padding_mask.struct_info.shape[:2] +# if attn_bias is None: +# attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) +# key_mask = nn.emit(relax.op.logical_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) +# attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, get_type_min_val(query))) + +# batch_size, seq_len, _ = query.struct_info.shape +# query = nn.emit(relax.op.reshape( +# query, +# (batch_size, seq_len, n_heads, d_model), +# )) # b s (h d) -> b s h d + +# batch_size, seq_len, _ = key.struct_info.shape +# kv_n_heads = 1 if multiquery else n_heads +# key = nn.emit(relax.op.reshape( +# key, +# (batch_size, seq_len, kv_n_heads, d_model), +# )) # b s (h d) -> b s h d +# value = nn.emit(relax.op.reshape( +# value, +# (batch_size, seq_len, kv_n_heads, d_model), +# )) # b s (h d) -> b s h d +# if multiquery: +# key = relax.op.broadcast_to(key, (batch_size, seq_len, n_heads, d_model)) +# value = relax.op.broadcast_to(value, (batch_size, seq_len, n_heads, d_model)) +# reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) +# attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) +# batch_size, seq_len, _, _ = attn_output.struct_info.shape +# output = nn.emit(relax.op.reshape( +# attn_output, +# (batch_size, seq_len, n_heads*d_model), +# )) # (b, s, h, d)) -> (b, s, (h d)) +# return (output, None, past_key_value) + +######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (END) ########################## + + +class MultiheadAttention(nn.Module): + """Multi-head self attention. + Using torch or triton attention implemetation enables user to also use + additive bias. 
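+    In this Relax port attn_impl='torch' maps to scaled_multihead_dot_product_attention and
+    attn_impl='flash' to flash_attn_fn, while attn_impl='triton' is not implemented yet and
+    raises an error in __init__.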
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + dtype: str, + attn_impl: str='triton', + clip_qkv: Optional[float]=None, + qk_ln: bool=False, + softmax_scale: Optional[float]=None + ): + # Init fields + self.d_model = d_model + self.n_heads = n_heads + self.attn_impl = attn_impl + self.clip_qkv = clip_qkv + self.qk_ln = qk_ln + self.softmax_scale = softmax_scale + + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) + self.Wqkv = Linear(self.d_model, 3 * self.d_model, dtype, bias=False) + if self.qk_ln: + self.q_ln = LayerNorm(self.d_model, dtype) + self.k_ln = LayerNorm(self.d_model, dtype) + if self.attn_impl == 'flash': + self.attn_fn = flash_attn_fn + elif self.attn_impl == 'triton': + # While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. + # When training larger models this can trigger alloc retries which hurts performance. + # If encountered, we recommend using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`. + raise NotImplemented("Triton type of flash attention has not been implemented yet") + # self.attn_fn = triton_flash_attn_fn + elif self.attn_impl == 'torch': + # Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` + # otherwise we recommend using `attn_impl: triton`. + self.attn_fn = scaled_multihead_dot_product_attention + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + self.out_proj = Linear(self.d_model, self.d_model, dtype, bias=False) + + def forward( + self, + x: relax.Expr, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_value: Optional[Tuple[relax.Expr]]=None, + attn_bias: Optional[relax.Expr]=None, + attention_mask: Optional[relax.Expr] = None, + is_causal: bool=True, + ): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv = nn.emit(relax.op.clip(qkv, min=relax.const(-self.clip_qkv), max=relax.const(self.clip_qkv))) + qkv_out = relax.op.split(qkv, 3, axis=2) + query = nn.emit(qkv_out[0]) + key = nn.emit(qkv_out[1]) + value = nn.emit(qkv_out[2]) + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.struct_info.dtype + query = nn.emit(relax.op.astype(self.q_ln(query), dtype)) + key = nn.emit(relax.op.astype(self.k_ln(key), dtype)) + attn_out, past_key_value = self.attn_fn( + query, + key, + value, + self.n_heads, + self.d_model, + all_seq_len_shape=all_seq_len_shape, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + needs_weights=False, + ) + return self.out_proj(attn_out), past_key_value + +ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} + + +class MPTMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + + def forward(self, x): + return self.down_proj(relax.op.nn.gelu(self.up_proj(x))) + + +class MPTBlock(nn.Module): + def __init__(self, config: MPTConfig): + # Get values from config or defaults + attn_config = config.attn_config if config.attn_config is not None else attn_config_defaults + norm_type = config.norm_type if config.norm_type is not None else 'low_precision_layernorm' + # Define layer norm and attention classes + norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] + attn_class = 
ATTN_CLASS_REGISTRY[attn_config['attn_type']] + + self.hidden_size = config.d_model + # Init layers + self.attn = attn_class( + d_model=self.hidden_size, + n_heads=config.n_heads, + dtype=config.dtype, + attn_impl=attn_config['attn_impl'], + clip_qkv=attn_config['clip_qkv'], + qk_ln=attn_config['qk_ln'], + softmax_scale=attn_config['softmax_scale'], + ) + self.ffn = MPTMLP( + hidden_size=self.hidden_size, + intermediate_size=config.expansion_ratio*self.hidden_size, + dtype=config.dtype, + ) + self.norm_1 = norm_class(self.hidden_size, config.dtype) + self.norm_2 = norm_class(self.hidden_size, config.dtype) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_value: Optional[Tuple[relax.Expr]]=None, + attn_bias: Optional[relax.Expr] = None, + attention_mask: Optional[relax.Expr] = None, + is_causal: bool=True, + ) -> Tuple[relax.Expr, relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + hidden_states = self.norm_1(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.attn( + hidden_states, + all_seq_len_shape=all_seq_len_shape, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=is_causal + ) + residual = nn.emit(residual + hidden_states) + + # Fully Connected + hidden_states = self.norm_2(residual) + hidden_states = self.ffn(hidden_states) + hidden_states = nn.emit(residual + hidden_states) + + return hidden_states, present_key_value + + +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + if (prefix_lm or not causal) or use_sequence_id: + return (1, n_heads, seq_len, seq_len) + return (1, n_heads, 1, seq_len) + elif prefix_lm or use_sequence_id: + return (1, 1, seq_len, seq_len) + return None + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + + +def gen_slopes(n_heads, alibi_bias_max=8): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = nn.emit(relax.op.arange(1, _n_heads + 1, dtype="float32")) + m = nn.emit(m * relax.const(alibi_bias_max / _n_heads)) + slopes = nn.emit(relax.op.divide(relax.const(1.0), relax.op.power(m, relax.const(2.0)))) + + if _n_heads != n_heads: + slopes_len = slopes.struct_info.shape[0] + slopes = nn.emit(relax.op.strided_slice( + relax.op.concat( + [relax.op.strided_slice(slopes, [0], [relax.const(1, dtype="int64")], [slopes_len], [relax.const(2)]), # [1::2] + relax.op.strided_slice(slopes, [0], [relax.const(0, dtype="int64")], [slopes_len], [relax.const(2)])] # [::2] + ), [0], [relax.const(0, dtype="int64")], [relax.const(n_heads, dtype="int64")]) # slicing [:n_heads] + ) + return nn.emit(relax.op.reshape(slopes, (1, n_heads, 1, 1))) + + +def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, dtype=None): + alibi_bias = nn.emit(relax.op.reshape(relax.op.arange(1 - seq_len, 1, dtype="int32"), (1, 1, 1, seq_len))) + if full: + alibi_bias = nn.emit(alibi_bias - relax.op.reshape(relax.op.arange(1 - seq_len, 1, dtype="int32"), (1, 1, seq_len, 1))) + alibi_bias = nn.emit(relax.op.negative(relax.op.abs(alibi_bias))) + slopes = gen_slopes(n_heads, alibi_bias_max) + alibi_bias = nn.emit(relax.op.astype(alibi_bias, slopes.struct_info.dtype)) + alibi_bias = nn.emit(alibi_bias * slopes) + if dtype is not None: + alibi_bias = nn.emit(relax.op.astype(alibi_bias, dtype)) + return alibi_bias + + +def build_attn_bias(attn_impl, attn_bias, 
n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + attn_bias = nn.emit(relax.op.add(attn_bias, build_alibi_bias( + n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, dtype=attn_bias.struct_info.dtype + ))) + return attn_bias + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + + +def get_type_min_val(tensor): + return relax.const( + tir.min_value(tensor.struct_info.dtype).value, + tensor.struct_info.dtype, + ) + + +class MPTModel(nn.Module): + def __init__(self, config: MPTConfig): + config._validate_config() + # Init fields from config + self.attn_impl = config.attn_config['attn_impl'] + self.prefix_lm = config.attn_config['prefix_lm'] + self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id'] + self.alibi = config.attn_config['alibi'] + self.alibi_bias_max = config.attn_config['alibi_bias_max'] + self.is_causal = not self.prefix_lm + + self.n_heads = config.n_heads + self.n_layers = config.n_layers + self.max_seq_len = config.max_seq_len + self.use_cache = config.use_cache + + self._attn_bias_initialized = False + self.attn_bias = None + self.attn_bias_shape = attn_bias_shape( + self.attn_impl, + self.n_heads, + self.max_seq_len, + self.alibi, + prefix_lm=self.prefix_lm, + causal=self.is_causal, + use_sequence_id=self.attn_uses_sequence_id + ) + + # Define layer norm type + if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): + norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) + raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).') + norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] + + # Init layers + self.wte = Embedding(config.vocab_size, config.d_model, dtype=config.dtype) + if not self.alibi: + self.wpe = Embedding(config.max_seq_len, config.d_model, dtype=config.dtype) + self.blocks = ModuleList([MPTBlock(config) for _ in range(config.n_layers)]) + self.norm_f = norm_class(config.d_model, dtype=config.dtype) + + def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None): + if not self._attn_bias_initialized: + if self.attn_bias_shape: + self.attn_bias = nn.emit(relax.op.zeros(self.attn_bias_shape, dtype=dtype)) + self.attn_bias = build_attn_bias( + self.attn_impl, self.attn_bias, self.n_heads, self.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max + ) + self._attn_bias_initialized = True + if self.attn_impl == 'flash': + return (self.attn_bias, attention_mask) + if self.attn_bias is not None: + self.attn_bias = nn.emit(relax.op.astype(self.attn_bias, dtype)) + attn_bias = self.attn_bias + if attention_mask is not None: + s_k = attention_mask.struct_info.shape[1] # seq_len + if attn_bias is None: + attn_bias = nn.emit(relax.op.zeros((1, 1, 1, s_k), dtype=dtype)) + else: + def attn_bias_te_slicing(x: te.Tensor, seq_len: tvm.tir.Var): + return te.compute( + shape=(x.shape[0], x.shape[1], x.shape[2], seq_len), + fcompute=lambda i, j, k, m: x[i, j, k, x.shape[3] - seq_len + m], + name="attn_bias_slice", + ) + + s_k_end = attn_bias.struct_info.shape[3] # config.max_seq_len = 2048 + # TODO(vchernov): it can not be calculated in relax + # _s_k = relax.op.maximum(relax.const(0), s_k_end - s_k) + # slicing attn_bias[:, :, :, _s_k:] + # Need to use _s_k instead of s_k_end - s_k (attn_bias.shape = [1, 32, 1, seq_len]) + # attn_bias = nn.emit(relax.op.strided_slice(attn_bias, 
[3], [s_k_end - s_k], [s_k_end])) + attn_bias = nn.emit_te(attn_bias_te_slicing, attn_bias, s_k, primfunc_name_hint="attn_bias_slice") + min_val = get_type_min_val(attn_bias) + attn_mask = nn.emit(relax.op.logical_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) + attn_bias = nn.emit(relax.op.masked_fill(attn_bias, attn_mask, min_val)) + return (attn_bias, None) + + def forward( + self, + input_ids: relax.Expr, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_values: Optional[relax.Expr]=None, + attention_mask: Optional[relax.Expr]=None, + use_cache: Optional[bool]=None + ): + tok_emb = self.wte(input_ids) + if self.alibi: + x = tok_emb + # else: + # past_position = 0 + # if past_key_values is not None: + # if len(past_key_values) != self.n_layers: + # raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.n_layers!r}).') + # past_position = past_key_values[0][0].struct_info.shape[1] + # if self.attn_impl == 'torch': + # past_position = past_key_values[0][0].struct_info.shape[3] + # if S + past_position > self.max_seq_len: + # raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.max_seq_len}.') + # pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) + # if attention_mask is not None: + # pos_diff_to_slice = nn.emit(relax.op.cumsum(relax.op.astype(relax.op.bitwise_not(attention_mask), "int32"), axis=1)) + # dim1_len = pos_diff_to_slice.struct_info.shape[1] + # # slicing [:, past_position:] + # pos_diff = nn.emit(relax.op.strided_slice(pos_diff_to_slice, [1], [past_position], [dim1_len])) + # pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) + # pos_emb = self.wpe(pos) + # x = tok_emb + pos_emb + (attn_bias, attention_mask) = self._attn_bias(dtype=x.struct_info.dtype, attention_mask=attention_mask) + + # decoder layers + if past_key_values is not None: + next_decoder_cache = () + else: + next_decoder_cache = None + + for (b_idx, block) in enumerate(self.blocks): + past_key_value = (past_key_values[b_idx * 2], past_key_values[b_idx * 2 + 1]) if past_key_values is not None else None + x, key_value_cache = block( + x, + all_seq_len_shape=all_seq_len_shape, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=self.is_causal + ) + if past_key_values is not None: + next_decoder_cache += key_value_cache + x = self.norm_f(x) + if past_key_values is not None: + assert len(next_decoder_cache) == len(self.blocks) * 2 + return x, next_decoder_cache + + +class MPTForCausalLM(nn.Module): + def __init__(self, config: MPTConfig): + if not config.tie_word_embeddings: + raise ValueError('MPTForCausalLM only supports tied word embeddings') + self.transformer = MPTModel(config) + self.dtype = config.dtype + + self.use_cache = config.use_cache + + def prepare_attention_mask_for_generation(self, input_ids=None, src_len=None): + if src_len is not None: + seq_len = src_len.struct_info.values[0] + shape = R.shape([1, seq_len]) + return nn.emit(relax.op.ones(shape, dtype="bool")) + else: + return nn.emit(relax.op.astype(relax.op.ones_like(input_ids), dtype="bool")) + + def forward( + self, + input_ids: relax.Expr, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_values: Optional[relax.Expr]=None, + ): + attention_mask = 
self.prepare_attention_mask_for_generation(input_ids, all_seq_len_shape) + + logits, key_value_cache = self.transformer( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache = self.use_cache, + ) + + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = nn.emit_te(te_slicing, logits, primfunc_name_hint="slice") + + logits = nn.emit(relax.op.linear(logits, self.transformer.wte.weight)) + + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: + init_shape = relax.ShapeExpr( + ( + config.max_seq_len, + config.n_heads, + config.d_model // config.n_heads, + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.n_layers * 2): + caches.append( + bb.emit( + relax.Call( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_decoding_func_with_kv_cache(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: + pidx2pname: Dict[int, str] = {} + bsz = 1 + all_seq_len = tvm.tir.Var("n", "int64") + + with bb.function("decode"): + model = MPTForCausalLM(config) + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var( + "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) + ) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.n_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + + named_params = named_parameters(model) + for i, (name, param) in enumerate(named_params.items()): + pidx2pname[i] = name + assert param.same_as(params[i + 3]) + + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var("decode") + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + return pidx2pname + + +def create_decoding_func_wo_kv_cache(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: + pidx2pname: Dict[int, str] = {} + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + + with bb.function("decode"): + model = MPTForCausalLM(config) + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + + with bb.dataflow(): + logits, states = model(input_ids) + params = [ + input_ids, + ] + model.parameters() + + named_params = named_parameters(model) + for i, (name, param) in enumerate(named_params.items()): + pidx2pname[i] = name + assert param.same_as(params[i + 1]) + if states is None: + states = () + gv = bb.emit_output((logits, relax.Tuple(states))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var("decode") + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + return pidx2pname + + +def create_softmax_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: + with 
bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, config.vocab_size), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args, hf_config): + model_name = args.model + assert model_name.startswith("mpt-") , f"Unsupported model name: {model_name}" + + model_path = args.model_path + dtype = args.quantization.model_dtype + + if args.max_seq_len is not None and args.max_seq_len > 0: + max_seq_len = args.max_seq_len + elif hf_config["max_seq_len"] > 0: + max_seq_len = hf_config["max_seq_len"] + else: + # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct + max_seq_len = 4096 + + hf_config.update({"max_seq_len": max_seq_len}) + hf_config.update({"use_cache": args.use_kv_cache}) + + config = MPTConfig(**hf_config, dtype=dtype) + + bb = relax.BlockBuilder() + pidx2pname = None + if config.use_cache: + create_kv_cache_func(bb, config) + pidx2pname = create_decoding_func_with_kv_cache(bb, config) + else: + pidx2pname = create_decoding_func_wo_kv_cache(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=-1, + stop_tokens=[0], + add_prefix_space=False, + ) + + mod = bb.get() + + pname2binname = load_torch_pname2binname_map( + model_path, set(pidx2pname.values()) + ) + + args.pidx2pname = pidx2pname + args.pname2binname = pname2binname + args.f_convert_pname_fwd = lambda pname: pname + args.f_convert_param_bkwd = lambda torch_pname, raw_param: [ + (torch_pname, raw_param.astype(dtype)) + ] + + return mod, [None] * len(pidx2pname) diff --git a/mlc_llm/relax_model/mpt/mpt_config.py b/mlc_llm/relax_model/mpt/mpt_config.py new file mode 100644 index 0000000000..18a5d7e6d5 --- /dev/null +++ b/mlc_llm/relax_model/mpt/mpt_config.py @@ -0,0 +1,167 @@ +""" +It is practicaly copy from https://huggingface.co/mosaicml/mpt-7b-instruct/blob/main/configuration_mpt.py +but `dtype` field is added +A HuggingFace-style model configuration. +""" +from typing import Dict, Optional, Union +from transformers import PretrainedConfig + +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8 +} +init_config_defaults: Dict = { + 'name': 'kaiming_normal_', + 'fan_mode': 'fan_in', + 'init_nonlinearity': 'relu', + 'init_div_is_residual': True, + 'emb_init_std': None, + 'emb_init_uniform_lim': None, + 'init_std': None, + 'init_gain': 0.0 +} + + +class MPTConfig(PretrainedConfig): + model_type = 'mpt' + + def __init__( + self, + d_model: int=2048, + n_heads: int=16, + n_layers: int=24, + expansion_ratio: int=4, + max_seq_len: int=2048, + vocab_size: int=50368, + resid_pdrop: float=0.0, + emb_pdrop: float=0.0, + learned_pos_emb: bool=True, + attn_config: Dict=attn_config_defaults, + init_device: str='cpu', + logit_scale: Optional[Union[float, str]]=None, + no_bias: bool=False, + verbose: int=0, + embedding_fraction: float=1.0, + norm_type: str='low_precision_layernorm', + use_cache: bool=False, + init_config: Dict=init_config_defaults, + dtype=None, + **kwargs + ): + """The MPT configuration class. 
+ + Args: + d_model (int): The size of the embedding dimension of the model. + n_heads (int): The number of attention heads. + n_layers (int): The number of layers in the model. + expansion_ratio (int): The ratio of the up/down scale in the MLP. + max_seq_len (int): The maximum sequence length of the model. + vocab_size (int): The size of the vocabulary. + resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. + emb_pdrop (float): The dropout probability for the embedding layer. + learned_pos_emb (bool): Whether to use learned positional embeddings + attn_config (Dict): A dictionary used to configure the model's attention module: + attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention + attn_pdrop (float): The dropout probability for the attention layers. + attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. + qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. + clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to + this value. + softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, + use the default scale of ``1/sqrt(d_keys)``. + prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an + extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix + can attend to one another bi-directionally. Tokens outside the prefix use causal attention. + attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. + When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates + which sub-sequence each token belongs to. + Defaults to ``False`` meaning any provided `sequence_id` will be ignored. + alibi (bool): Whether to use the alibi bias instead of position embeddings. + alibi_bias_max (int): The maximum value of the alibi bias. + init_device (str): The device to use for parameter initialization. + logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. + no_bias (bool): Whether to use bias in all layers. + verbose (int): The verbosity level. 0 is silent. + embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. + norm_type (str): choose type of norm to use + multiquery_attention (bool): Whether to use multiquery attention implementation. + use_cache (bool): Whether or not the model should return the last key/values attentions + init_config (Dict): A dictionary used to configure the model initialization: + init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', + 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or + 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. + init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. + emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. + emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution + used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 
+ init_std (float): The standard deviation of the normal distribution used to initialize the model, + if using the baseline_ parameter initialization scheme. + init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. + fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. + init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. + --- + See llmfoundry.models.utils.param_init_fns.py for info on other param init config options + """ + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.attn_config = attn_config + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.verbose = verbose + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.use_cache = use_cache + self.init_config = init_config + self.dtype = dtype + if 'name' in kwargs: + del kwargs['name'] + if 'loss_fn' in kwargs: + del kwargs['loss_fn'] + super().__init__(**kwargs) + self._validate_config() + + def _set_config_defaults(self, config, config_defaults): + for (k, v) in config_defaults.items(): + if k not in config: + config[k] = v + return config + + def _validate_config(self): + self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) + self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) + if self.d_model % self.n_heads != 0: + raise ValueError('d_model must be divisible by n_heads') + if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): + raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") + if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') + if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('alibi only implemented with torch and triton attention.') + if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') + if self.embedding_fraction > 1 or self.embedding_fraction <= 0: + raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') + if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': + raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") + if self.init_config.get('name', None) is None: + raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") + if not self.learned_pos_emb and (not self.attn_config['alibi']): + raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') diff --git a/mlc_llm/relax_model/mpt/mpt_topology.txt b/mlc_llm/relax_model/mpt/mpt_topology.txt new 
file mode 100644 index 0000000000..e2d911d9c0 --- /dev/null +++ b/mlc_llm/relax_model/mpt/mpt_topology.txt @@ -0,0 +1,198 @@ +transformer.wte.weight torch.Size([50432, 4096]) + +transformer.blocks.0.norm_1.weight torch.Size([4096]) +transformer.blocks.0.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.0.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.0.norm_2.weight torch.Size([4096]) +transformer.blocks.0.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.0.ffn.down_proj.weight torch.Size([4096, 16384]) + +transformer.blocks.1.norm_1.weight torch.Size([4096]) +transformer.blocks.1.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.1.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.1.norm_2.weight torch.Size([4096]) +transformer.blocks.1.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.1.ffn.down_proj.weight torch.Size([4096, 16384]) + +transformer.blocks.2.norm_1.weight torch.Size([4096]) +transformer.blocks.2.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.2.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.2.norm_2.weight torch.Size([4096]) +transformer.blocks.2.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.2.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.3.norm_1.weight torch.Size([4096]) +transformer.blocks.3.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.3.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.3.norm_2.weight torch.Size([4096]) +transformer.blocks.3.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.3.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.4.norm_1.weight torch.Size([4096]) +transformer.blocks.4.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.4.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.4.norm_2.weight torch.Size([4096]) +transformer.blocks.4.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.4.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.5.norm_1.weight torch.Size([4096]) +transformer.blocks.5.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.5.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.5.norm_2.weight torch.Size([4096]) +transformer.blocks.5.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.5.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.6.norm_1.weight torch.Size([4096]) +transformer.blocks.6.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.6.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.6.norm_2.weight torch.Size([4096]) +transformer.blocks.6.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.6.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.7.norm_1.weight torch.Size([4096]) +transformer.blocks.7.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.7.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.7.norm_2.weight torch.Size([4096]) +transformer.blocks.7.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.7.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.8.norm_1.weight torch.Size([4096]) +transformer.blocks.8.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.8.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.8.norm_2.weight torch.Size([4096]) +transformer.blocks.8.ffn.up_proj.weight torch.Size([16384, 4096]) 
+transformer.blocks.8.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.9.norm_1.weight torch.Size([4096]) +transformer.blocks.9.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.9.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.9.norm_2.weight torch.Size([4096]) +transformer.blocks.9.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.9.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.10.norm_1.weight torch.Size([4096]) +transformer.blocks.10.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.10.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.10.norm_2.weight torch.Size([4096]) +transformer.blocks.10.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.10.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.11.norm_1.weight torch.Size([4096]) +transformer.blocks.11.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.11.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.11.norm_2.weight torch.Size([4096]) +transformer.blocks.11.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.11.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.12.norm_1.weight torch.Size([4096]) +transformer.blocks.12.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.12.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.12.norm_2.weight torch.Size([4096]) +transformer.blocks.12.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.12.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.13.norm_1.weight torch.Size([4096]) +transformer.blocks.13.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.13.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.13.norm_2.weight torch.Size([4096]) +transformer.blocks.13.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.13.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.14.norm_1.weight torch.Size([4096]) +transformer.blocks.14.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.14.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.14.norm_2.weight torch.Size([4096]) +transformer.blocks.14.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.14.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.15.norm_1.weight torch.Size([4096]) +transformer.blocks.15.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.15.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.15.norm_2.weight torch.Size([4096]) +transformer.blocks.15.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.15.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.16.norm_1.weight torch.Size([4096]) +transformer.blocks.16.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.16.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.16.norm_2.weight torch.Size([4096]) +transformer.blocks.16.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.16.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.17.norm_1.weight torch.Size([4096]) +transformer.blocks.17.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.17.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.17.norm_2.weight torch.Size([4096]) +transformer.blocks.17.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.17.ffn.down_proj.weight torch.Size([4096, 16384]) 
+transformer.blocks.18.norm_1.weight torch.Size([4096]) +transformer.blocks.18.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.18.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.18.norm_2.weight torch.Size([4096]) +transformer.blocks.18.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.18.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.19.norm_1.weight torch.Size([4096]) +transformer.blocks.19.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.19.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.19.norm_2.weight torch.Size([4096]) +transformer.blocks.19.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.19.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.20.norm_1.weight torch.Size([4096]) +transformer.blocks.20.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.20.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.20.norm_2.weight torch.Size([4096]) +transformer.blocks.20.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.20.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.21.norm_1.weight torch.Size([4096]) +transformer.blocks.21.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.21.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.21.norm_2.weight torch.Size([4096]) +transformer.blocks.21.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.21.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.22.norm_1.weight torch.Size([4096]) +transformer.blocks.22.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.22.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.22.norm_2.weight torch.Size([4096]) +transformer.blocks.22.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.22.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.23.norm_1.weight torch.Size([4096]) +transformer.blocks.23.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.23.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.23.norm_2.weight torch.Size([4096]) +transformer.blocks.23.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.23.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.24.norm_1.weight torch.Size([4096]) +transformer.blocks.24.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.24.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.24.norm_2.weight torch.Size([4096]) +transformer.blocks.24.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.24.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.25.norm_1.weight torch.Size([4096]) +transformer.blocks.25.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.25.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.25.norm_2.weight torch.Size([4096]) +transformer.blocks.25.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.25.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.26.norm_1.weight torch.Size([4096]) +transformer.blocks.26.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.26.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.26.norm_2.weight torch.Size([4096]) +transformer.blocks.26.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.26.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.27.norm_1.weight torch.Size([4096]) 
+transformer.blocks.27.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.27.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.27.norm_2.weight torch.Size([4096]) +transformer.blocks.27.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.27.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.28.norm_1.weight torch.Size([4096]) +transformer.blocks.28.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.28.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.28.norm_2.weight torch.Size([4096]) +transformer.blocks.28.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.28.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.29.norm_1.weight torch.Size([4096]) +transformer.blocks.29.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.29.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.29.norm_2.weight torch.Size([4096]) +transformer.blocks.29.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.29.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.30.norm_1.weight torch.Size([4096]) +transformer.blocks.30.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.30.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.30.norm_2.weight torch.Size([4096]) +transformer.blocks.30.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.30.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.31.norm_1.weight torch.Size([4096]) +transformer.blocks.31.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.31.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.31.norm_2.weight torch.Size([4096]) +transformer.blocks.31.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.31.ffn.down_proj.weight torch.Size([4096, 16384]) + +transformer.norm_f.weight torch.Size([4096]) \ No newline at end of file diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 7ae411015e..2508de68c9 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -51,7 +51,7 @@ class Quantization: ), } -supported_model_types = set(["llama", "gpt_neox", "gpt_bigcode", "moss", "rwkv"]) +supported_model_types = set(["llama", "gpt_neox", "gpt_bigcode", "moss", "rwkv", "mpt"]) def argparse_postproc_common(args: argparse.Namespace) -> None: @@ -78,6 +78,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "gorilla-": ("gorilla", "llama"), "starcoder": ("code_gpt", "gpt_bigcode"), "wizardcoder-": ("code_gpt", "gpt_bigcode"), + "mpt-": ("mpt", "mpt"), } model = args.model.lower() for prefix, (conv_template, model_category) in supported_model_prefix.items(): @@ -220,10 +221,15 @@ def get_item(i): torch_param_names = list(torch_params.keys()) for torch_param_name in torch_param_names: if str(torch_params[torch_param_name].dtype) == "torch.bfloat16": - # Convert to float32 first. - raw_param = ( - torch_params[torch_param_name].detach().cpu().float().numpy() - ) + if args.quantization.mode == "no" and args.quantization.model_dtype == "float16": + raw_param = ( + torch_params[torch_param_name].detach().cpu().to(dtype=torch.float16).numpy() + ) + else: + # Convert to float32 first. + raw_param = ( + torch_params[torch_param_name].detach().cpu().float().numpy() + ) else: raw_param = torch_params[torch_param_name].detach().cpu().numpy() del torch_params[torch_param_name]
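
Note on the mlc_llm/utils.py hunk above: it changes how torch.bfloat16 checkpoint tensors are converted. When quantization is disabled and the build dtype is float16, the tensor is now cast straight to float16 instead of taking the old detour through float32. A minimal standalone sketch of that rule follows; the helper name and arguments are illustrative and not part of the patch.

import numpy as np
import torch

def to_numpy_param(param: torch.Tensor, quant_mode: str, model_dtype: str) -> np.ndarray:
    # Mirrors the bfloat16 handling added in mlc_llm/utils.py (illustrative helper, not repo API).
    if param.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so some cast is required before .numpy().
        if quant_mode == "no" and model_dtype == "float16":
            # Unquantized float16 build: cast directly, skipping the float32 round trip.
            return param.detach().cpu().to(dtype=torch.float16).numpy()
        # Any other mode: widen to float32 first, as before the patch.
        return param.detach().cpu().float().numpy()
    return param.detach().cpu().numpy()

# Example: a bfloat16 weight destined for an unquantized float16 build.
w = torch.randn(8, 8, dtype=torch.bfloat16)
print(to_numpy_param(w, quant_mode="no", model_dtype="float16").dtype)  # float16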
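
For reference, a short usage sketch of the MPTConfig class introduced in mlc_llm/relax_model/mpt/mpt_config.py: attn_config or init_config keys that are omitted are filled from the module-level defaults, and _validate_config rejects inconsistent settings. The values below are made up for illustration and assume transformers is installed.

from mlc_llm.relax_model.mpt.mpt_config import MPTConfig

# Omitted attn_config keys are filled from attn_config_defaults.
cfg = MPTConfig(
    d_model=4096,
    n_heads=32,
    n_layers=32,
    max_seq_len=4096,
    use_cache=True,
    dtype="float16",
    attn_config={"attn_impl": "torch"},
)
assert cfg.attn_config["alibi"] is False       # default from attn_config_defaults
assert cfg.init_config["name"] == "kaiming_normal_"

# _validate_config raises on inconsistent settings.
try:
    MPTConfig(d_model=4096, n_heads=30)        # 4096 is not divisible by 30
except ValueError as err:
    print(err)                                 # d_model must be divisible by n_heads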