MTK Android Llama Runner #6208

Open: wants to merge 8 commits into base: main
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Contains values that are used by the mtk_llama_runner.cpp

#pragma once

#include <cstddef>
#include <string>

// LLMType is assumed to be declared in llm_types.h, which is expected to live
// alongside this header in llama_runner/llm_helper/include/.
#include "llm_types.h"

namespace mtk::vars {
using example::llm_helper::LLMType;

// Sizes
const size_t PROMPT_TOKEN_BATCH_SIZE = 128;
const size_t CACHE_SIZE = 512;
const size_t HIDDEN_SIZE = 4096;
const size_t NUM_HEAD = 32;
const size_t NUM_LAYER = 32;
const size_t MAX_TOKEN_LENGTH = 8192;
const double ROT_EMB_BASE = 500000;
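// (These sizes match the Llama 3 8B-Instruct model referenced by the paths
// below: hidden size 4096, 32 attention heads, 32 layers, RoPE base 500000.)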

// Types
const LLMType MODEL_INPUT_TYPE = LLMType::FP32;
const LLMType MODEL_OUTPUT_TYPE = LLMType::FP32;
const LLMType CACHE_TYPE = LLMType::FP32;
const LLMType MASK_TYPE = LLMType::FP32;
const LLMType ROT_EMB_TYPE = LLMType::FP32;

// Paths
const std::string TOKENIZER_PATH =
"/data/local/tmp/et-mtk/llama3/tokenizer.model";
const std::string TOKEN_EMBEDDING_PATH =
"/data/local/tmp/et-mtk/llama3/embedding_llama3-8B-instruct_fp32.bin";
Contributor review comment on the lines above:
Need to fix those

// Comma-Separated Paths
const std::string PROMPT_MODEL_PATHS =
"/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_3.pte,";

// Comma-Separated Paths
const std::string GEN_MODEL_PATHS =
"/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_3.pte,";

} // namespace mtk::vars
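
For context, the two *_MODEL_PATHS constants above are consumed as comma-separated lists, one .pte model chunk per entry; the runner below splits them with example::utils::split. The following is an illustrative, standalone sketch of that convention only (split_paths is a hypothetical stand-in, not the helper used in this PR):

#include <sstream>
#include <string>
#include <vector>

// Split a comma-separated path string into individual paths, dropping any
// empty segments.
static std::vector<std::string> split_paths(const std::string& csv, char sep = ',') {
  std::vector<std::string> out;
  std::stringstream ss(csv);
  std::string item;
  while (std::getline(ss, item, sep)) {
    if (!item.empty()) {
      out.push_back(item);
    }
  }
  return out;
}

// split_paths(PROMPT_MODEL_PATHS) yields the four prompt-model chunks
// llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte through ..._3.pte; the
// trailing comma produces no extra entry.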
338 changes: 338 additions & 0 deletions examples/mediatek/executor_runner/mtk_llama_runner.cpp
@@ -0,0 +1,338 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* Copyright (c) 2024 MediaTek Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/* Copyright Statement:
*
* This software/firmware and related documentation ("MediaTek Software") are
* protected under relevant copyright laws. The information contained herein
* is confidential and proprietary to MediaTek Inc. and/or its licensors.
* Without the prior written permission of MediaTek inc. and/or its licensors,
* any reproduction, modification, use or disclosure of MediaTek Software,
* and information contained herein, in whole or in part, shall be strictly
* prohibited.
*/
/* MediaTek Inc. (C) 2024. All rights reserved.
*
* BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES
* THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE")
* RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON
* AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT.
* NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE
* SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR
* SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH
* THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY
* ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY
* THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK
* SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO
* RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN
* FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND
* CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER
* WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT
* ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER
* TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE.
*
* The following software/firmware and/or related documentation ("MediaTek
* Software") have been modified by MediaTek Inc. All revisions are subject to
* any receiver's applicable license agreements with MediaTek Inc.
*/

#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h"

#include <ctime>
#include <iostream>
#include <memory>
#include <random>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>
// #include <executorch/util/util.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/result.h>

#include "llama_runner/ModelChunk.h"
#include "llama_runner/Utils.h"
#include "llama_runner/llm_helper/include/llama_runner_values.h"
#include "llama_runner/llm_helper/include/llm_types.h"

static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate.
// Global BOS and EOS option for tokenization (encoding)
static constexpr int8_t kAddBos = 1;
static constexpr int8_t kAddEos = 0;

using namespace example::llm_helper;
using example::utils::argmax;
using example::utils::split;
using example::utils::Timer;
using example::utils::to_string;
using namespace mtk::vars;

namespace llm = ::executorch::extension::llm;

MTKLlamaRunner::MTKLlamaRunner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature)
: modeloptions_(get_model_options()), modelpaths_(get_model_paths()) {
executorch::runtime::runtime_init();
ET_LOG(
Info,
"Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. Initiated runtime_init().");
}

Error MTKLlamaRunner::load() {
if (is_loaded()) {
return Error::Ok;
}

// Load tokenizer
ET_LOG(Info, "Loading tokenizer.");
tokenizer_ = load_tokenizer();
ET_LOG(Info, "Complete loading tokenizer.");

// Load prompt model
runtime_ = std::make_unique<LlamaRuntime>();
ET_LOG(Info, "Loading prompt model.");
runtime_->Initialize(modeloptions_, modelpaths_);
ET_LOG(Info, "Complete loading prompt model.");

return Error::Ok;
}

bool MTKLlamaRunner::is_loaded() const {
return tokenizer_ && runtime_;
}

Error MTKLlamaRunner::generate(
const std::string& prompt,
int32_t seq_len,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback,
bool,
bool) {
if (!is_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
}

// Wrap the token_callback with print function
std::function<void(const std::string&)> wrapped_callback =
[token_callback](const std::string& piece) {
llm::safe_printf(piece.c_str());
fflush(stdout);
if (token_callback) {
token_callback(piece);
}
};

ET_LOG(Info, "Starting inference from MTKLlamaRunner");
ET_CHECK_OK_OR_RETURN_ERROR(
    inference(*runtime_.get(), tokenizer_, prompt, wrapped_callback));
ET_LOG(Info, "Completed inference from MTKLlamaRunner");

return Error::Ok;
}
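
// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of this PR] A minimal hypothetical caller
// (main.cpp) for the generate() API above. It assumes the constructor and
// generate() signature declared in mtk_llama_runner.h, that the unqualified
// Error alias comes from that header, and that the model/tokenizer files are
// already at the device paths hard-coded in llama_runner_values.h.
// ---------------------------------------------------------------------------
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>

#include <string>

int main() {
  // The constructor arguments are currently unused; the runner self-loads the
  // paths defined in mtk::vars.
  MTKLlamaRunner runner(
      /*model_path=*/"", /*tokenizer_path=*/"", /*temperature=*/0.0f);

  const auto err = runner.generate(
      "What is the capital of France?",
      /*seq_len=*/128,
      /*token_callback=*/[](const std::string&) {}, // pieces also print internally
      /*stats_callback=*/nullptr,
      false,
      false);
  return err == Error::Ok ? 0 : 1;
}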

void MTKLlamaRunner::stop() {
if (is_loaded()) {
runtime_->Release();
} else {
ET_LOG(Error, "Llama Runtime is not loaded, cannot stop");
}
}

LlamaModelOptions MTKLlamaRunner::get_model_options() {
LlamaModelOptions options = {
// Sizes
.prompt_token_batch_size = PROMPT_TOKEN_BATCH_SIZE,
.cache_size = CACHE_SIZE,
.hidden_size = HIDDEN_SIZE,
.num_head = NUM_HEAD,
.num_layer = NUM_LAYER,
.max_token_length = MAX_TOKEN_LENGTH,
.rot_emb_base = ROT_EMB_BASE,

// Types
.model_input_type = MODEL_INPUT_TYPE,
.model_output_type = MODEL_OUTPUT_TYPE,
.cache_type = CACHE_TYPE,
.mask_type = MASK_TYPE,
.rot_emb_type = ROT_EMB_TYPE};
ET_LOG(Info, "Completed get_model_options");
return options;
}

LlamaModelPaths MTKLlamaRunner::get_model_paths() {
LlamaModelPaths model_paths = {
.tokenizer_path = TOKENIZER_PATH,
.token_embedding_path = TOKEN_EMBEDDING_PATH,
.prompt_model_paths = split(PROMPT_MODEL_PATHS, ','),
.gen_model_paths = split(GEN_MODEL_PATHS, ',')};
ET_LOG(Info, "Completed get_model_paths");
return model_paths;
}

Result<uint64_t> MTKLlamaRunner::digest_prompt(
LlamaRuntime& llama_runtime,
const std::unique_ptr<Tokenizer>& tokenizer,
const std::vector<uint64_t> input_tokens) {
const auto input_token_count = input_tokens.size();
const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize();
size_t cur_token_index = 0;

Timer timer_digest_prompt([=](const auto elapsed_sec) {
// Ideal prompt size is a multiple of prompt batch size
const size_t ideal_prompt_size =
std::ceil(float(input_token_count) / prompt_token_batch_size) *
prompt_token_batch_size;
ET_LOG(
Info,
"Done analyzing prompt in %f sec (%f tok/s)",
elapsed_sec,
(float)ideal_prompt_size / elapsed_sec);
});

// Feed any leftover (prompt length % batch size) tokens first so that every
// subsequent chunk is a full prompt_token_batch_size batch.
auto getNextTokens = [&]() {
const size_t num_tok_remain = input_token_count - cur_token_index;
const size_t remainder = num_tok_remain % prompt_token_batch_size;
const size_t num_new_tokens =
remainder ? remainder : prompt_token_batch_size;
const auto start = cur_token_index;
const auto end = start + num_new_tokens;
return std::vector(
input_tokens.begin() + start, input_tokens.begin() + end);
};

void* logits = nullptr;
timer_digest_prompt.Start();
while (cur_token_index < input_token_count) {
const auto next_tokens = getNextTokens();
ET_LOG(
Debug,
"Digest next tokens (size=%zu), 1st tok=%lu",
next_tokens.size(),
next_tokens[0]);
logits = llama_runtime.Run(next_tokens);
cur_token_index += next_tokens.size();
}
timer_digest_prompt.End();

const auto vocab_size = tokenizer->vocab_size();
const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
const auto first_output_token = argmax(logits_type, logits, vocab_size);
return first_output_token;
}
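
// ---------------------------------------------------------------------------
// [Illustrative sketch, not part of this PR] The getNextTokens() chunking
// above feeds the prompt remainder first, so every later Run() call sees a
// full prompt_token_batch_size batch. This standalone snippet (hypothetical
// chunk_sizes helper) reproduces just that arithmetic with the batch size 128
// from llama_runner_values.h.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <cstdio>
#include <vector>

// Same rule as digest_prompt(): the first chunk carries prompt_len % batch
// tokens (if any); every following chunk is a full batch.
static std::vector<size_t> chunk_sizes(size_t prompt_len, size_t batch) {
  std::vector<size_t> sizes;
  size_t consumed = 0;
  while (consumed < prompt_len) {
    const size_t remain = prompt_len - consumed;
    const size_t rem = remain % batch;
    const size_t n = rem ? rem : batch;
    sizes.push_back(n);
    consumed += n;
  }
  return sizes;
}

int main() {
  // A 300-token prompt with batch size 128 is digested as chunks of
  // 44, 128, 128 tokens.
  for (const size_t n : chunk_sizes(300, 128)) {
    std::printf("%zu ", n);
  }
  std::printf("\n");
  return 0;
}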

Error MTKLlamaRunner::gen_response(
LlamaRuntime& llama_runtime,
const std::unique_ptr<Tokenizer>& tokenizer,
const uint64_t input_token,
std::function<void(const std::string&)> token_callback) {
Timer timer_model_swap(
[](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); });

// Swap to gen mode
timer_model_swap.Start();
llama_runtime.SwapModel(1);
timer_model_swap.End();

size_t gen_tok_count = 0;
uint64_t prev_token = input_token;
uint64_t output_token = input_token;

auto decode_res = tokenizer->decode(prev_token, output_token);
ET_CHECK_OR_RETURN_ERROR(
decode_res.ok(),
InvalidState,
"Tokenizer failed to decode first generated token: %lu",
output_token);
std::string full_response = std::move(decode_res.get());
std::vector<uint64_t> full_response_tokens = {input_token};

const auto vocab_size = tokenizer->vocab_size();
const auto logits_type = llama_runtime.GetModelOptions().model_output_type;

double gen_total_time_sec = 0;
Timer timer_gen_token(
[&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; });

// Print first output token
token_callback(full_response);

while (gen_tok_count++ < MAX_RESPONSE &&
llama_runtime.GetTokenIndex() < modeloptions_.max_token_length) {
timer_gen_token.Start();
void* logits = llama_runtime.Run({output_token});
timer_gen_token.End();

prev_token = output_token;
output_token = argmax(logits_type, logits, vocab_size);
full_response_tokens.push_back(output_token);

// Stop when output is EOS
if (output_token == tokenizer->eos_tok()) {
token_callback("</eos>");
break;
}
auto decode_res = tokenizer->decode(prev_token, output_token);
ET_CHECK_OR_RETURN_ERROR(
decode_res.ok(),
InvalidState,
"Tokenizer failed to decode generated token %lu",
output_token);
const std::string tok_str = std::move(decode_res.get());
full_response += tok_str;
token_callback(tok_str);
}

std::cout << "\n\n[Generated Tokens]\n"
<< to_string(full_response_tokens) << std::endl;

ET_LOG(
Info,
"Token generation speed: %f tok/s",
gen_tok_count / gen_total_time_sec);

return Error::Ok;
}

Error MTKLlamaRunner::inference(
LlamaRuntime& llama_runtime,
const std::unique_ptr<Tokenizer>& tokenizer,
const std::string& prompt,
std::function<void(const std::string&)> token_callback) {
// Tokenize input prompt
auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos);
ET_CHECK_OR_RETURN_ERROR(
encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt");
const auto input_tokens = std::move(encode_res.get());

// Run prompt mode (pre-fill)
auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens);
ET_CHECK_OR_RETURN_ERROR(
prefill_res.ok(), InvalidState, "Failed to digest prompt");
const auto first_output_token = prefill_res.get();

// Run generation mode (decoding)
return gen_response(
llama_runtime, tokenizer, first_output_token, token_callback);
}

std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() {
std::unique_ptr<Tokenizer> tokenizer;
// Assumes that tokenizer type is Tiktoken
tokenizer = example::get_tiktoken_for_llama();
tokenizer->load(modelpaths_.tokenizer_path);
return tokenizer;
}