WIP: [MPT] Support MPT-7b-instruct model #460

Closed · wants to merge 114 commits

Commits (114)
9ce118c
add file for mpt-7b-instruct model
Jun 6, 2023
f922562
update build.py for mpt
Jun 6, 2023
14eb7e1
update utils
Jun 6, 2023
255ec66
add MPTConfig
Jun 6, 2023
eb09dc2
add get_model function like implementation for t5
Jun 6, 2023
5b21113
add MPTMLP, get Linear from Llama for a moment
Jun 6, 2023
a7371b8
add low-precision layer norm, need to correct it further
Jun 6, 2023
2502c2d
MPTBlock was implemented
Jun 6, 2023
cabd1a8
update MPTConfig by dtype
Jun 7, 2023
adb4e39
draft for attentions layers of MPT. some updates
Jun 7, 2023
bd7edfa
_reset_is_causal and attn_bias_shape methods were added. __init__ of …
Jun 7, 2023
feccf59
MPTForCausalLM was implemented on Relax, some TODOs were added
Jun 8, 2023
b9e5adb
rearrange from einops was replaced by relax ops
Jun 8, 2023
f86c101
reimplement scaled_multihead_dot_product_attention by relax
Jun 8, 2023
81f8712
replace torch.finfo
Jun 8, 2023
86e849b
finish scaled_multihead_dot_product_attention, some TODOs are still t…
Jun 8, 2023
0df2254
replace torch from flash_attn_fn
Jun 8, 2023
0c1c737
replace torch from triton_flash_attn_fn
Jun 9, 2023
893c51d
update MPTModel forward, replace all torch operations
Jun 9, 2023
db4aada
implement masked_fill by relax, replace torch masked_fill by it. remo…
Jun 9, 2023
dbf143a
fix max on dynamic values
Jun 9, 2023
1fab0ba
implement build_attn_bias with dependencies
Jun 9, 2023
6a40c7b
transfer of code for the sake of convenience
Jun 9, 2023
01d4b07
_attn_bias of MPTModel was implemented
Jun 9, 2023
b256f38
_apply_prefix_mask of MPTModel was implemented on relax
Jun 9, 2023
7d5b1d2
_apply_sequence_id of MPTModel was implemented on relax
Jun 9, 2023
f7b604f
fix layer norm
Jun 9, 2023
9a8a331
remove device
Jun 9, 2023
589af32
unroll flash_attn implementation using sources
Jun 9, 2023
4f5ef3f
slicings were reimplemented from python style to relax. corresponding…
Jun 10, 2023
1b58d3a
add draft for create_decoding_func. Fix two TODOs related to rank
Jun 12, 2023
3bbd62c
replace handmade masked_filled by relax op implemented in mlc-relax
Jun 12, 2023
cc3b128
replace torch logical_or by relax op implemented in mlc-relax
Jun 12, 2023
ce46093
replace torch logical_not by relax op implemented in mlc-relax
Jun 13, 2023
9daf6ac
fix TODO with index_select
Jun 13, 2023
8294423
small fixes
Jun 13, 2023
498bdb8
remove backwards
Jun 13, 2023
0f5f64f
zone different types of flash attention implementation
Jun 13, 2023
03426f6
commented flash attention implementations with types flash and triton
Jun 13, 2023
2b8bbbf
some fixes
Jun 14, 2023
b1e34bf
fix dtype in Linear layers
Jun 14, 2023
10c7cd6
fix dtype in layer norm
Jun 14, 2023
fcd8d7d
fix config using
Jun 14, 2023
30ce7b5
small fixes
Jun 14, 2023
983f523
tir.Cast was replaced by relax.op.astype
Jun 14, 2023
62f4a39
update downcast workaround for lplayernorm
Jun 14, 2023
7ad3f4f
more torch group were replaced by relax ops
Jun 14, 2023
aa187bb
correct rearrange
Jun 15, 2023
fb1138b
switch on model_path in get_model method. need to redo due to update …
Jun 15, 2023
4860be4
small fixes
Jun 15, 2023
b888dd4
replace matmul by linear for weight transposition from the box
Jun 15, 2023
be73cd9
fixes
Jun 15, 2023
f173728
check decode only for mpt models
Jun 15, 2023
9885b54
add desc for mpt-7b-instruct
Jun 16, 2023
29800ea
upstream weights mapping
Jun 16, 2023
5693108
fix assert check
Jun 16, 2023
a16753b
add custom f_convert_pname_fwd
Jun 16, 2023
2c091f6
once more update of f_convert_pname_fwd
Jun 16, 2023
fb0fc91
skip bias from weights
Jun 16, 2023
0b46f6f
add f_convert_param_bkwd
Jun 16, 2023
ffed1d5
try to fix bfloat16
Jun 16, 2023
7b61387
file structure for mpt model was refactored
Jun 19, 2023
fcdc915
add script to convert model from bfloat16 to float16
Jun 19, 2023
4126a8e
fix f_convert_param_bkwd
Jun 19, 2023
8467179
clean code for conversion script, add desc
Jun 19, 2023
896cdc9
workaround for lookup func
Jun 19, 2023
eb8ea82
add dummy create_kv_cache_func method to support mlc_chat_cli
Jun 20, 2023
4b8f4e8
debug log
Jun 20, 2023
fa6db2d
add create_softmax_func for MPT
Jun 20, 2023
8a72a9e
update transform
Jun 20, 2023
20ab52d
remove debug log
Jun 20, 2023
7499a90
add conversation for MPT
Jun 20, 2023
be4977f
remove kv_cache_func
Jun 21, 2023
9e58577
remove kv_cache from the list of funcs for mpt
Jun 21, 2023
89b216c
cast logits to float32 before softmax with temperature
Jun 21, 2023
e7d1ef8
debug log: check contiguous weight
Jun 21, 2023
bee5ffb
debug log for logits
Jun 21, 2023
c96d074
update readme and prepare_inputs_for_generation based on mpt model sp…
Jun 22, 2023
6b6fdf1
flash attn implementation was transferred to outside mpt model implem…
Jun 27, 2023
0be3602
correct import after rename file
Jun 28, 2023
51ce1ad
set temperature to zero by default for mpt model
Jul 3, 2023
a870ba1
transfer attention_mask preprocessing to decoder forward. update prep…
Jul 4, 2023
f44c4db
prepare_inputs_for_generation was fully transferred to forward pass
Jul 4, 2023
2188310
remove excess methods
Jul 4, 2023
d56c11d
remove excess methods once more
Jul 4, 2023
e811c34
init kv_cache only if need
Jun 21, 2023
e7c9350
do not use kv_cache during Forward if it is empty
Jun 21, 2023
399bce4
remove input_data from Decode_step due to it is not used and recalcul…
Jul 4, 2023
991f965
test fixes
Jul 4, 2023
f940301
unroll method from generate in README
Jul 4, 2023
6a54257
remove test logs. add PrintLogits method
Jul 5, 2023
6755d86
print logits after copy to cpu
Jul 5, 2023
aea1d83
print shape together with logits
Jul 6, 2023
9cc2494
test log
Jul 6, 2023
8880ed5
print decode input
Jul 6, 2023
8abc8c2
test log in prefill step
Jul 6, 2023
1b1458f
print intermediate tensors in topology to catch nan generation
Jul 6, 2023
bff07c3
print intermediate tensors in topology to catch nan generation: attn_…
Jul 6, 2023
69c344c
revert some debug logs
Jul 6, 2023
e116601
artificial error
Jul 6, 2023
c66a5a2
reimplement all remaining tir funcs to relax
Jul 6, 2023
63ae7c9
print only 10 values from logits
Jul 6, 2023
e68f4fb
revert test logits transform
Jul 6, 2023
0ed398b
remove debug logs and workaround
Jul 6, 2023
24b9a9c
return correct output
Jul 6, 2023
fb7870e
continue debug
Jul 6, 2023
418aadb
remove unneccessary parts from mpt topology. calculate query-key matm…
Jul 6, 2023
6ee82c7
create comparator
Jul 24, 2023
16392c9
update README
Jul 24, 2023
ddcea18
update mpt model file: fix layernorm, remove some TODOs, remove exces…
Jul 24, 2023
7b52b99
remove debug prints
Jul 24, 2023
a857583
update PrintNDArray method
Jul 24, 2023
81f92a5
support mlc-llm chat using with or without kv cache
Aug 25, 2023
bb00d1d
strong refactor based on vc/dev of mpt-like relax model to support us…
Aug 25, 2023
build.py: 28 changes (27 additions, 1 deletion)
@@ -11,7 +11,7 @@

import mlc_llm
from mlc_llm import utils
from mlc_llm.relax_model import gpt_bigcode, gpt_neox, llama, moss, rwkv
from mlc_llm.relax_model import gpt_bigcode, gpt_neox, llama, moss, rwkv, mpt


def _parse_args():
@@ -57,6 +57,12 @@ def _parse_args():
default=1,
help="Whether to use previously pickled IRModule and skip trace.",
)
args.add_argument(
"--use-kv-cache",
action="store_false",
default=True,
help="Forcely replace use_cache hyperparameter in model config",
)
args.add_argument("--debug-dump", action="store_true", default=False)
args.add_argument("--debug-load-script", action="store_true", default=False)
args.add_argument(
@@ -274,6 +280,20 @@ def mod_transform_before_build(
"get_metadata",
"reset_kv_cache",
]
elif ARGS.model.startswith("mpt-"):
if ARGS.use_kv_cache:
model_names = [
"decode",
"create_kv_cache",
"softmax_with_temperature",
"get_metadata",
]
else:
model_names = [
"decode",
"softmax_with_temperature",
"get_metadata",
]
else:
model_names = [
"prefill",
@@ -337,6 +357,10 @@ def dump_default_mlc_chat_config(args):
config["shift_fill_factor"] = 0.3
config["tokenizer_files"] = utils.get_tokenizer_files(params_path)

# TODO(vchernov): create a mechanism that produces a model-specific default config and covers this case
if args.model_category == "mpt":
config["temperature"] = 0.0

dump_path = os.path.join(params_path, "mlc-chat-config.json")
with open(dump_path, "w", encoding="utf-8") as outfile:
json.dump(config, outfile, indent=4)
@@ -407,6 +431,8 @@ def main():
mod, params = moss.get_model(ARGS, config)
elif ARGS.model_category == "rwkv":
mod, params = rwkv.get_model(ARGS, config)
elif ARGS.model_category == "mpt":
mod, params = mpt.get_model(ARGS, config)
else:
raise ValueError(f"Model {ARGS.model} not supported")
mod = mod_transform_before_build(mod, params, ARGS)
cpp/conv_templates.cc: 20 changes (20 additions, 0 deletions)
@@ -295,6 +295,25 @@ Conversation CodeGPT() {
return conv;
}

Conversation MPT() {
Conversation conv;
conv.name = "mpt";
conv.system = "";
conv.roles = {"", ""};
conv.messages = {};
conv.separator_style = SeparatorStyle::kSepRoleMsg;
conv.offset = 0;
conv.seps = {"\n"};
conv.role_msg_sep = "";
conv.role_empty_sep = "";
// TODO(mlc-team): add eos to mlc-chat-config
// and remove eos from stop token setting.
conv.stop_tokens = {0};
conv.stop_str = "<|endoftext|>";
conv.add_bos = false;
return conv;
}

} // namespace

using ConvFactory = Conversation (*)();
@@ -312,6 +331,7 @@ Conversation Conversation::FromTemplate(const std::string& name) {
{"moss", MOSS},
{"LM", VanillaLM},
{"code_gpt", CodeGPT},
{"mpt", MPT},
};
auto it = factory.find(name);
if (it == factory.end()) {
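For orientation (not part of the PR diff), a minimal sketch of how the template registered above could be retrieved at runtime; it assumes only the Conversation API visible in this diff, and the field values in the comments come from the MPT() factory added here.

// Minimal sketch (not PR code): fetch the MPT conversation template registered above.
Conversation conv = Conversation::FromTemplate("mpt");
// Per the MPT() factory: conv.name == "mpt", conv.stop_str == "<|endoftext|>",
// conv.add_bos == false, and token id 0 is the only stop token.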
cpp/llm_chat.cc: 116 changes (97 additions, 19 deletions)
@@ -128,7 +128,7 @@ class LLMChat {
friend class LLMChatModule;

public:
explicit LLMChat(DLDevice device) : device_(device) {}
explicit LLMChat(DLDevice device) : device_(device), debug_index_(0) {}

/*!
* \return Text describing runtime stats.
@@ -289,8 +289,11 @@ class LLMChat {
<< "Cannot find env function vm.builtin.attention_kv_cache_array_popn";
fkvcache_array_popn_ = *fkvcache_array_popn;

// Step 4. KV cache creation.
kv_cache_ = vm_->GetFunction("create_kv_cache")();
// Step 4. KV cache creation, if needed.
auto kv_cache_func = vm_->GetFunction("create_kv_cache");
if (kv_cache_func.defined()) {
kv_cache_ = kv_cache_func();
}

// Step 5. KV cache reset.
reset_kv_cache_func_ = vm_->GetFunction("reset_kv_cache");
@@ -508,6 +511,9 @@
}

std::vector<int32_t> prompt_tokens = this->GetInputTokens();
if (kv_cache_.empty()) {
full_output_ids_.insert(full_output_ids_.end(), prompt_tokens.begin(), prompt_tokens.end());
}
int64_t token_len = static_cast<int64_t>(prompt_tokens.size());
if (token_len == 0) return;

@@ -527,14 +533,18 @@
}

void DecodeStep() {
ICHECK(!output_ids_.empty());
int32_t last_token = output_ids_.back();
tvm::runtime::NDArray input_data = GetInputTokenNDArray({last_token});
std::vector<int32_t> input_tokens;
if (kv_cache_.empty()) {
ICHECK(!full_output_ids_.empty());
input_tokens = full_output_ids_;
} else {
ICHECK(!output_ids_.empty());
input_tokens = {output_ids_.back()};
}

auto tstart = std::chrono::high_resolution_clock::now();

NDArray logits_on_device = this->Forward({last_token}, total_seq_len_ + 1);
total_seq_len_ += 1;
NDArray logits_on_device = this->Forward(input_tokens, ++total_seq_len_);

int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_);

@@ -588,12 +598,7 @@
auto decoding_end = std::chrono::high_resolution_clock::now();

// print first few logits for eyeballs
std::ostringstream os;
for (int i = 0; i < 10; ++i) {
if (i != 0) os << ", ";
os << static_cast<float*>(logits_on_cpu_->data)[i];
}
LOG(INFO) << "logits[:10] =[" << os.str() << "]";
PrintNDArray(logits_on_cpu_, 10, "Logits");

double encoding_ms = static_cast<double>((decoding_start - encoding_start).count()) / 1e6;
double decoding_ms = static_cast<double>((decoding_end - decoding_start).count()) / 1e6;
@@ -602,6 +607,62 @@
<< "decoding-time=" << decoding_ms << "ms.";
}

NDArray getArrayToPrint(NDArray array) const {
ICHECK(array->data != nullptr) << "Array data is nullptr";
// Check that the data is on CPU and copy it if needed
if (array->device.device_type != kDLCPU) {
NDArray array_cpu;
array_cpu = array.CopyTo(DLDevice{kDLCPU, 0});
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
return array_cpu;
} else {
return array;
}
}

void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor", bool to_save = false) {
NDArray array_cpu = getArrayToPrint(array);

size_t ndim = array_cpu->ndim;
int64_t numel = 1;
// Print shape and calculate numel
std::ostringstream os_shape;
for (size_t i = 0; i < ndim; ++i) {
if (i != 0) os_shape << ", ";
numel *= array_cpu->shape[i];
os_shape << array_cpu->shape[i];
}

std::string num_tag = std::to_string(num);
if (num == -1 || num >= numel) {
num = numel;
num_tag = "";
}
// TODO(vchernov): switch back to LOG(INFO) after testing
std::cout << tensor_tag << " shape = [" << os_shape.str() << "]" << std::endl;
// LOG(INFO) << tensor_tag << " shape = [" << os_shape.str() << "]";

// Print specified number of values from tensor
std::ostringstream os;
const float* p_data = static_cast<float*>(array_cpu->data);
for (int64_t i = 0; i < num; ++i) {
if (i != 0) os << ", ";
os << p_data[i];
}
// TODO(vchernov): switch back to LOG(INFO) after testing
std::cout << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]" << std::endl;
// LOG(INFO) << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]";

// Save to binary file
if (to_save) {
std::string file_name = "tensor_" + std::to_string(debug_index_++) + ".bin";
std::cout << tensor_tag << " is saved in " << file_name << std::endl;
std::ofstream fs(file_name, std::ios::out | std::ios::binary | std::ios::app);
fs.write(reinterpret_cast<const char*>(p_data), 4 * numel);
fs.close();
}
}

private:
picojson::value SerializeConfigToJSONValue() const {
picojson::object config;
@@ -656,6 +717,9 @@

if (!stop_triggered_) {
output_ids_.push_back(next_token);
if (kv_cache_.empty()) {
full_output_ids_.push_back(next_token);
}
appeared_token_ids_.insert(next_token);
}

@@ -699,10 +763,16 @@
ret = prefill_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_);
} else {
// running decode function when prefill is not available
for (int i = 0; i < input_tokens.size(); ++i) {
NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]});
int64_t pos = cur_pos + i + 1 - input_tokens.size();
ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_);
if (kv_cache_.empty()) {
// Without kv_cache, the full sequence of tokens is used
NDArray input_data = this->GetInputTokenNDArray(input_tokens);
ret = decode_func_(input_data, params_);
} else {
for (int i = 0; i < input_tokens.size(); ++i) {
NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]});
int64_t pos = cur_pos + i + 1 - input_tokens.size();
ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_);
}
}
}
return Downcast<NDArray>(ret[0]);
@@ -763,7 +833,10 @@
// Clear kv cache
void ResetKVCache() { reset_kv_cache_func_(kv_cache_); }

void ProcessSystemPrompts() { this->PrefillStep(/*inp=*/"", /*append_conversation=*/false); }
void ProcessSystemPrompts() {
full_output_ids_.clear();
this->PrefillStep(/*inp=*/"", /*append_conversation=*/false);
}

// Utils
static double GetRandomNumber() {
Expand All @@ -783,6 +856,7 @@ class LLMChat {
ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined";
ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D";
ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch";

return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber());
}

Expand Down Expand Up @@ -816,6 +890,8 @@ class LLMChat {
double top_p_{0.95};
// output ids till now (refresh after encoding step)
std::vector<int32_t> output_ids_;
// output ids till now (sys and client prompt + generated by decoder)
std::vector<int32_t> full_output_ids_;
// appeared token ids till now (refresh after encoding step)
std::unordered_set<int32_t> appeared_token_ids_;
// output message till now (refresh after encoding step)
@@ -866,6 +942,8 @@
Array<ObjectRef> kv_cache_;
// Temp logits on cpu
NDArray logits_on_cpu_{nullptr};
// Debug index
int32_t debug_index_;
};

/*!
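As a usage note (the call site below is hypothetical, not taken from the PR), the debug helper added above can be invoked with the signature shown in this diff to print the first few logits and optionally dump the tensor to disk.

// Hypothetical call site for the PrintNDArray helper added above:
// prints the tensor shape, the first 10 values, and appends the raw
// float data to a file named tensor_<debug_index_>.bin.
PrintNDArray(logits_on_cpu_, /*num=*/10, /*tensor_tag=*/"Logits", /*to_save=*/true);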
mlc_llm/dispatch/dispatch_tir_operator.py: 3 changes (3 additions, 0 deletions)
@@ -19,6 +19,9 @@ def __init__(self, model: str):
elif model == "rwkv":
lookup = None

elif model == "mpt":
lookup = None

else:
raise ValueError(f"Model {model} not supported")
self.lookup = lookup
mlc_llm/relax_model/__init__.py: 1 change (1 addition, 0 deletions)
@@ -1 +1,2 @@
from . import llama
from .mpt import mpt