From 8bf19d111ded56d8baf8b9a2b2eded2c63088b19 Mon Sep 17 00:00:00 2001
From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com>
Date: Mon, 14 Oct 2024 17:27:12 -0700
Subject: [PATCH 1/8] MTK Android Llama Runner

---
 .../llm_helper/include/llama_runner_values.h  |  32 ++
 .../executor_runner/mtk_llama_runner.cpp      | 333 ++++++++++++++++++
 .../executor_runner/mtk_llama_runner.h        |  69 ++++
 3 files changed, 434 insertions(+)
 create mode 100644 examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h
 create mode 100644 examples/mediatek/executor_runner/mtk_llama_runner.cpp
 create mode 100644 examples/mediatek/executor_runner/mtk_llama_runner.h

diff --git a/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h b/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h
new file mode 100644
index 0000000000..98cd8ab394
--- /dev/null
+++ b/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h
@@ -0,0 +1,32 @@
+#pragma once
+
+namespace torch::executor {
+  using llm_helper::LLMType;
+
+  // Sizes
+  const size_t PROMPT_TOKEN_BATCH_SIZE = 128;
+  const size_t CACHE_SIZE = 512;
+  const size_t HIDDEN_SIZE = 4096;
+  const size_t NUM_HEAD = 32;
+  const size_t NUM_LAYER = 32;
+  const size_t MAX_TOKEN_LENGTH = 8192;
+  const double ROT_EMB_BASE = 500000;
+
+  // Types
+  const LLMType MODEL_INPUT_TYPE = LLMType::FP32;
+  const LLMType MODEL_OUTPUT_TYPE = LLMType::FP32;
+  const LLMType CACHE_TYPE = LLMType::FP32;
+  const LLMType MASK_TYPE = LLMType::FP32;
+  const LLMType ROT_EMB_TYPE = LLMType::FP32;
+
+  // Paths
+  const std::string TOKENIZER_PATH="/data/local/tmp/et-mtk/llama3/tokenizer.model";
+  const std::string TOKEN_EMBEDDING_PATH="/data/local/tmp/et-mtk/llama3/embedding_llama3-8B-instruct_fp32.bin";
+
+  // Comma-Separated Paths
+  const std::string PROMPT_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_3.pte,";
+
+  // Comma-Separated Paths
+  const std::string GEN_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_3.pte,";
+
+} // namespace torch::executor
diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.cpp b/examples/mediatek/executor_runner/mtk_llama_runner.cpp
new file mode 100644
index 0000000000..ea882cbb2f
--- /dev/null
+++ b/examples/mediatek/executor_runner/mtk_llama_runner.cpp
@@ -0,0 +1,333 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Copyright (c) 2024 MediaTek Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/* Copyright Statement:
+ *
+ * This software/firmware and related documentation ("MediaTek Software") are
+ * protected under relevant copyright laws. The information contained herein
+ * is confidential and proprietary to MediaTek Inc. and/or its licensors.
+ * Without the prior written permission of MediaTek inc.
and/or its licensors, + * any reproduction, modification, use or disclosure of MediaTek Software, + * and information contained herein, in whole or in part, shall be strictly + * prohibited. + */ +/* MediaTek Inc. (C) 2024. All rights reserved. + * + * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES + * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE") + * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON + * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT. + * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE + * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR + * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH + * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY + * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY + * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK + * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO + * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN + * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND + * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER + * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT + * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER + * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE. + * + * The following software/firmware and/or related documentation ("MediaTek + * Software") have been modified by MediaTek Inc. All revisions are subject to + * any receiver's applicable license agreements with MediaTek Inc. + */ + +#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h" +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +// #include +#include +#include + +#include "llama_runner/ModelChunk.h" +#include "llama_runner/Utils.h" +#include "llama_runner/llm_helper/include/llm_types.h" +#include "llama_runner/llm_helper/include/llama_runner_values.h" + +static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate. +// Global BOS and EOS option for tokenization (encoding) +static constexpr int8_t kAddBos = 1; +static constexpr int8_t kAddEos = 0; + +using namespace torch::executor; +using namespace torch::executor::llm_helper; +using torch::executor::utils::Timer; + +MTKLlamaRunner::MTKLlamaRunner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature) + : modeloptions_(get_model_options()), + modelpaths_(get_model_paths()) { + runtime_init(); + ET_LOG( + Info, + "Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. 
Initiated runtime_init()."); +} + +Error MTKLlamaRunner::load() { + if (is_loaded()) { + return Error::Ok; + } + + // Load tokenizer + ET_LOG(Info, "Loading tokenizer."); + tokenizer_ = load_tokenizer(); + ET_LOG(Info, "Complete loading tokenizer."); + + // Load prompt model + runtime_ = std::make_unique(); + ET_LOG(Info, "Loading prompt model."); + runtime_->Initialize(modeloptions_, modelpaths_); + ET_LOG(Info, "Complete loading prompt model."); + + return Error::Ok; +} + +bool MTKLlamaRunner::is_loaded() const { + return tokenizer_ && runtime_; +} + +Error MTKLlamaRunner::generate( + const std::string& prompt, + int32_t seq_len, + std::function token_callback, + std::function stats_callback) { + + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + + // Wrap the token_callback with print function + std::function wrapped_callback = + [token_callback](const std::string& piece) { + util::safe_printf(piece.c_str()); + fflush(stdout); + if (token_callback) { + token_callback(piece); + } + }; + + ET_LOG(Info, "Starting inference from MTKLlamaRunner"); + inference(*runtime_.get(), tokenizer_, prompt, wrapped_callback); + ET_LOG(Info, "Completed inference from MTKLlamaRunner"); + + return Error::Ok; +} + +void MTKLlamaRunner::stop() { + if (is_loaded()) { + runtime_->Release(); + } else { + ET_LOG(Error, "Llama Runtime is not loaded, cannot stop"); + } +} + +LlamaModelOptions MTKLlamaRunner::get_model_options() { + LlamaModelOptions options = { + // Sizes + .prompt_token_batch_size = PROMPT_TOKEN_BATCH_SIZE, + .cache_size = CACHE_SIZE, + .hidden_size = HIDDEN_SIZE, + .num_head = NUM_HEAD, + .num_layer = NUM_LAYER, + .max_token_length = MAX_TOKEN_LENGTH, + .rot_emb_base = ROT_EMB_BASE, + + // Types + .model_input_type = MODEL_INPUT_TYPE, + .model_output_type = MODEL_OUTPUT_TYPE, + .cache_type = CACHE_TYPE, + .mask_type = MASK_TYPE, + .rot_emb_type = ROT_EMB_TYPE}; + ET_LOG(Info, "Completed get_model_options"); + return options; +} + +LlamaModelPaths MTKLlamaRunner::get_model_paths() { + LlamaModelPaths model_paths = { + .tokenizer_path = TOKENIZER_PATH, + .token_embedding_path = TOKEN_EMBEDDING_PATH, + .prompt_model_paths = utils::split(PROMPT_MODEL_PATHS, ','), + .gen_model_paths = utils::split(GEN_MODEL_PATHS, ',')}; + ET_LOG(Info, "Completed get_model_paths"); + return model_paths; +} + +Result MTKLlamaRunner::digest_prompt( + LlamaRuntime& llama_runtime, + const std::unique_ptr& tokenizer, + const std::vector input_tokens) { + const auto input_token_count = input_tokens.size(); + const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize(); + size_t cur_token_index = 0; + + Timer timer_digest_prompt([=](const auto elapsed_sec) { + // Ideal prompt size is a multiple of prompt batch size + const size_t ideal_prompt_size = + std::ceil(float(input_token_count) / prompt_token_batch_size) * + prompt_token_batch_size; + ET_LOG( + Info, + "Done analyzing prompt in %f sec (%f tok/s)", + elapsed_sec, + (float)ideal_prompt_size / elapsed_sec); + }); + + auto getNextTokens = [&]() { + const size_t num_tok_remain = input_token_count - cur_token_index; + const size_t remainder = num_tok_remain % prompt_token_batch_size; + const size_t num_new_tokens = + remainder ? 
remainder : prompt_token_batch_size; + const auto start = cur_token_index; + const auto end = start + num_new_tokens; + return std::vector( + input_tokens.begin() + start, input_tokens.begin() + end); + }; + + void* logits; + timer_digest_prompt.Start(); + while (cur_token_index < input_token_count) { + const auto next_tokens = getNextTokens(); + ET_LOG( + Debug, + "Digest next tokens (size=%zu), 1st tok=%lu", + next_tokens.size(), + next_tokens[0]); + logits = llama_runtime.Run(next_tokens); + cur_token_index += next_tokens.size(); + } + timer_digest_prompt.End(); + + const auto vocab_size = tokenizer->vocab_size(); + const auto logits_type = llama_runtime.GetModelOptions().model_output_type; + const auto first_output_token = + utils::argmax(logits_type, logits, vocab_size); + return first_output_token; +} + +Error MTKLlamaRunner::gen_response( + LlamaRuntime& llama_runtime, + const std::unique_ptr& tokenizer, + const uint64_t input_token, + std::function token_callback) { + Timer timer_model_swap( + [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); }); + + // Swap to gen mode + timer_model_swap.Start(); + llama_runtime.SwapModel(1); + timer_model_swap.End(); + + size_t gen_tok_count = 0; + uint64_t prev_token = input_token; + uint64_t output_token = input_token; + + auto decode_res = tokenizer->decode(prev_token, output_token); + ET_CHECK_OR_RETURN_ERROR( + decode_res.ok(), + InvalidState, + "Tokenizer failed to decode first generated token: %lu", + output_token); + std::string full_response = std::move(decode_res.get()); + std::vector full_response_tokens = {input_token}; + + const auto vocab_size = tokenizer->vocab_size(); + const auto logits_type = llama_runtime.GetModelOptions().model_output_type; + + double gen_total_time_sec = 0; + Timer timer_gen_token( + [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; }); + + // Print first output token + token_callback(full_response); + + while (gen_tok_count++ < MAX_RESPONSE && + llama_runtime.GetTokenIndex() < modeloptions_.max_token_length) { + timer_gen_token.Start(); + void* logits = llama_runtime.Run({output_token}); + timer_gen_token.End(); + + prev_token = output_token; + output_token = utils::argmax(logits_type, logits, vocab_size); + full_response_tokens.push_back(output_token); + + // Stop when output is EOS + if (output_token == tokenizer->eos_tok()) { + token_callback(""); + break; + } + auto decode_res = tokenizer->decode(prev_token, output_token); + ET_CHECK_OR_RETURN_ERROR( + decode_res.ok(), + InvalidState, + "Tokenizer failed to decode generated token %lu", + output_token); + const std::string tok_str = std::move(decode_res.get()); + full_response += tok_str; + token_callback(tok_str); + } + + std::cout << "\n\n[Generated Tokens]\n" + << utils::to_string(full_response_tokens) << std::endl; + + ET_LOG( + Info, + "Token generation speed: %f tok/s", + gen_tok_count / gen_total_time_sec); + + return Error::Ok; +} + +Error MTKLlamaRunner::inference( + LlamaRuntime& llama_runtime, + const std::unique_ptr& tokenizer, + const std::string& prompt, + std::function token_callback) { + // Tokenize input prompt + auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos); + ET_CHECK_OR_RETURN_ERROR( + encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt"); + const auto input_tokens = std::move(encode_res.get()); + + // Run prompt mode (pre-fill) + auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens); + ET_CHECK_OR_RETURN_ERROR( + prefill_res.ok(), InvalidState, "Failed to 
digest prompt"); + const auto first_output_token = prefill_res.get(); + + // run generation mode (decoding) + return gen_response(llama_runtime, tokenizer, first_output_token, token_callback); +} + +std::unique_ptr MTKLlamaRunner::load_tokenizer() { + std::unique_ptr tokenizer; + // Assumes that tokenizer type is Tiktoken + tokenizer = torch::executor::get_tiktoken_for_llama(); + tokenizer->load(modelpaths_.tokenizer_path); + return tokenizer; +} diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.h b/examples/mediatek/executor_runner/mtk_llama_runner.h new file mode 100644 index 0000000000..d9f85c2025 --- /dev/null +++ b/examples/mediatek/executor_runner/mtk_llama_runner.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llama_runner/LlamaConfig.h" +#include "llama_runner/LlamaRuntime.h" +using namespace torch::executor; +using Stats = ::executorch::llm::Stats; + +class MTKLlamaRunner { + public: + explicit MTKLlamaRunner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature = 0.8f); + + bool is_loaded() const; + Error load(); + Error generate( + const std::string& prompt, + int32_t seq_len = 128, + std::function token_callback = {}, + std::function stats_callback = {}); + void stop(); + + LlamaModelOptions get_model_options(); + LlamaModelPaths get_model_paths(); + Result digest_prompt( + LlamaRuntime& llama_runtime, + const std::unique_ptr& tokenizer, + const std::vector input_tokens); + Error gen_response( + LlamaRuntime& llama_runtime, + const std::unique_ptr& tokenizer, + const uint64_t input_token, + std::function token_callback); + Error inference( + LlamaRuntime& llama_runtime, + const std::unique_ptr& tokenizer, + const std::string& prompt, + std::function token_callback); + std::unique_ptr load_tokenizer(); + + + private: + // model + const torch::executor::LlamaModelOptions modeloptions_; + const torch::executor::LlamaModelPaths modelpaths_; + std::unique_ptr tokenizer_; + std::unique_ptr runtime_; +}; From 826d59dfa5e01df94cc42d5c2934432b61073c18 Mon Sep 17 00:00:00 2001 From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com> Date: Tue, 15 Oct 2024 09:19:07 -0700 Subject: [PATCH 2/8] Enable JNI with MTK Llama Runner core functions --- extension/android/jni/jni_layer_llama.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 1049b9da30..e6b5807a08 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -113,6 +114,7 @@ class ExecuTorchLlamaJni int model_type_category_; std::unique_ptr runner_; std::unique_ptr multi_modal_runner_; + std::unique_ptr mtk_llama_runner_; public: constexpr static auto kJavaDescriptor = @@ -120,6 +122,7 @@ class ExecuTorchLlamaJni constexpr static int MODEL_TYPE_CATEGORY_LLM = 1; constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2; + constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3; static 
facebook::jni::local_ref initHybrid( facebook::jni::alias_ref, @@ -158,6 +161,11 @@ class ExecuTorchLlamaJni model_path->toStdString().c_str(), tokenizer_path->toStdString().c_str(), temperature); + } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { + mtk_llama_runner_ = std::make_unique( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + temperature); } } @@ -197,6 +205,12 @@ class ExecuTorchLlamaJni [callback](std::string result) { callback->onResult(result); }, [callback](const llm::Stats& result) { callback->onStats(result); }, echo); + } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) { + mtk_llama_runner_->generate( + prompt->toStdString(), + seq_len, + [callback](std::string result) { callback->onResult(result); }, + [callback](const Stats& result) { callback->onStats(result); }); } return 0; } @@ -286,6 +300,8 @@ class ExecuTorchLlamaJni multi_modal_runner_->stop(); } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { runner_->stop(); + } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) { + mtk_llama_runner_->stop(); } } @@ -294,6 +310,8 @@ class ExecuTorchLlamaJni return static_cast(multi_modal_runner_->load()); } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { return static_cast(runner_->load()); + } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) { + return static_cast(mtk_llama_runner_->load()); } return static_cast(Error::InvalidArgument); } From 38e88df3975c2072f00a98a394e8f61e9032ec7c Mon Sep 17 00:00:00 2001 From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com> Date: Tue, 15 Oct 2024 09:29:49 -0700 Subject: [PATCH 3/8] Cmake to include mtk target source --- extension/android/CMakeLists.txt | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 8f0e67900c..21c25e1c9b 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -158,6 +158,26 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) ${EXECUTORCH_ROOT}/examples/models/llama/runner ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner ) + + target_sources( + executorch_jni PRIVATE + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/mtk_llama_runner.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/MultiModelLoader.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/mask_builder.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/rotary_embedding.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/token_embedding.cpp + ) + target_include_directories( + executorch_jni PRIVATE + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/ + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner + ) + ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED) + SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB}/libneuron_buffer_allocator.so) + list(APPEND link_libraries neuron_backend libneuron_buffer_allocator) endif() target_include_directories( From 5f4e4a97a0eb12f56479d8d1628e9352785357ca Mon Sep 17 00:00:00 2001 From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com> Date: 
Wed, 16 Oct 2024 01:30:48 -0700 Subject: [PATCH 4/8] namespace changes to runner and jni layer --- .../llm_helper/include/llama_runner_values.h | 16 +++++++++-- .../executor_runner/mtk_llama_runner.cpp | 28 +++++++++++-------- .../executor_runner/mtk_llama_runner.h | 12 ++++++-- extension/android/CMakeLists.txt | 2 +- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h b/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h index 98cd8ab394..bef4335e8e 100644 --- a/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h +++ b/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h @@ -1,7 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Contains values that are used by the mtk_llama_runner.cpp + #pragma once -namespace torch::executor { - using llm_helper::LLMType; +namespace mtk::vars { + using example::llm_helper::LLMType; // Sizes const size_t PROMPT_TOKEN_BATCH_SIZE = 128; @@ -29,4 +39,4 @@ namespace torch::executor { // Comma-Separated Paths const std::string GEN_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_3.pte,"; -} // namespace torch::executor +} // namespace mtk:vars diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.cpp b/examples/mediatek/executor_runner/mtk_llama_runner.cpp index ea882cbb2f..695812eb30 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.cpp +++ b/examples/mediatek/executor_runner/mtk_llama_runner.cpp @@ -73,9 +73,14 @@ static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate. static constexpr int8_t kAddBos = 1; static constexpr int8_t kAddEos = 0; -using namespace torch::executor; -using namespace torch::executor::llm_helper; -using torch::executor::utils::Timer; +using namespace example::llm_helper; +using example::utils::argmax; +using example::utils::split; +using example::utils::Timer; +using example::utils::to_string; +using namespace mtk::vars; + +namespace llm = ::executorch::extension::llm; MTKLlamaRunner::MTKLlamaRunner( const std::string& model_path, @@ -83,7 +88,7 @@ MTKLlamaRunner::MTKLlamaRunner( const float temperature) : modeloptions_(get_model_options()), modelpaths_(get_model_paths()) { - runtime_init(); + executorch::runtime::runtime_init(); ET_LOG( Info, "Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. 
Initiated runtime_init()."); @@ -125,7 +130,7 @@ Error MTKLlamaRunner::generate( // Wrap the token_callback with print function std::function wrapped_callback = [token_callback](const std::string& piece) { - util::safe_printf(piece.c_str()); + llm::safe_printf(piece.c_str()); fflush(stdout); if (token_callback) { token_callback(piece); @@ -172,8 +177,8 @@ LlamaModelPaths MTKLlamaRunner::get_model_paths() { LlamaModelPaths model_paths = { .tokenizer_path = TOKENIZER_PATH, .token_embedding_path = TOKEN_EMBEDDING_PATH, - .prompt_model_paths = utils::split(PROMPT_MODEL_PATHS, ','), - .gen_model_paths = utils::split(GEN_MODEL_PATHS, ',')}; + .prompt_model_paths = split(PROMPT_MODEL_PATHS, ','), + .gen_model_paths = split(GEN_MODEL_PATHS, ',')}; ET_LOG(Info, "Completed get_model_paths"); return model_paths; } @@ -225,8 +230,7 @@ Result MTKLlamaRunner::digest_prompt( const auto vocab_size = tokenizer->vocab_size(); const auto logits_type = llama_runtime.GetModelOptions().model_output_type; - const auto first_output_token = - utils::argmax(logits_type, logits, vocab_size); + const auto first_output_token = argmax(logits_type, logits, vocab_size); return first_output_token; } @@ -273,7 +277,7 @@ Error MTKLlamaRunner::gen_response( timer_gen_token.End(); prev_token = output_token; - output_token = utils::argmax(logits_type, logits, vocab_size); + output_token = argmax(logits_type, logits, vocab_size); full_response_tokens.push_back(output_token); // Stop when output is EOS @@ -293,7 +297,7 @@ Error MTKLlamaRunner::gen_response( } std::cout << "\n\n[Generated Tokens]\n" - << utils::to_string(full_response_tokens) << std::endl; + << to_string(full_response_tokens) << std::endl; ET_LOG( Info, @@ -327,7 +331,7 @@ Error MTKLlamaRunner::inference( std::unique_ptr MTKLlamaRunner::load_tokenizer() { std::unique_ptr tokenizer; // Assumes that tokenizer type is Tiktoken - tokenizer = torch::executor::get_tiktoken_for_llama(); + tokenizer = example::get_tiktoken_for_llama(); tokenizer->load(modelpaths_.tokenizer_path); return tokenizer; } diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.h b/examples/mediatek/executor_runner/mtk_llama_runner.h index d9f85c2025..e79a3b02ad 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.h +++ b/examples/mediatek/executor_runner/mtk_llama_runner.h @@ -22,9 +22,15 @@ #include "llama_runner/LlamaConfig.h" #include "llama_runner/LlamaRuntime.h" -using namespace torch::executor; using Stats = ::executorch::llm::Stats; +using example::LlamaModelOptions; +using example::LlamaModelPaths; +using example::LlamaRuntime; +using executorch::extension::llm::Tokenizer; +using executorch::runtime::Error; +using executorch::runtime::Result; + class MTKLlamaRunner { public: explicit MTKLlamaRunner( @@ -62,8 +68,8 @@ class MTKLlamaRunner { private: // model - const torch::executor::LlamaModelOptions modeloptions_; - const torch::executor::LlamaModelPaths modelpaths_; + const LlamaModelOptions modeloptions_; + const LlamaModelPaths modelpaths_; std::unique_ptr tokenizer_; std::unique_ptr runtime_; }; diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 21c25e1c9b..c7f61ff59b 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -176,7 +176,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner ) ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED) - SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION 
${NEURON_BUFFER_ALLOCATOR_LIB}/libneuron_buffer_allocator.so) + SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB}) list(APPEND link_libraries neuron_backend libneuron_buffer_allocator) endif() From 26acfc29e41d6ec5064bb77ed506b3c22750e51c Mon Sep 17 00:00:00 2001 From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:38:04 -0700 Subject: [PATCH 5/8] lintrunner formatting --- .../llm_helper/include/llama_runner_values.h | 62 ++++++++++--------- .../executor_runner/mtk_llama_runner.cpp | 29 +++++---- .../executor_runner/mtk_llama_runner.h | 9 ++- extension/android/jni/jni_layer_llama.cpp | 8 +-- 4 files changed, 55 insertions(+), 53 deletions(-) diff --git a/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h b/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h index bef4335e8e..098898f5c2 100644 --- a/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h +++ b/examples/mediatek/executor_runner/llama_runner/llm_helper/include/llama_runner_values.h @@ -11,32 +11,36 @@ #pragma once namespace mtk::vars { - using example::llm_helper::LLMType; - - // Sizes - const size_t PROMPT_TOKEN_BATCH_SIZE = 128; - const size_t CACHE_SIZE = 512; - const size_t HIDDEN_SIZE = 4096; - const size_t NUM_HEAD = 32; - const size_t NUM_LAYER = 32; - const size_t MAX_TOKEN_LENGTH = 8192; - const double ROT_EMB_BASE = 500000; - - // Types - const LLMType MODEL_INPUT_TYPE = LLMType::FP32; - const LLMType MODEL_OUTPUT_TYPE = LLMType::FP32; - const LLMType CACHE_TYPE = LLMType::FP32; - const LLMType MASK_TYPE = LLMType::FP32; - const LLMType ROT_EMB_TYPE = LLMType::FP32; - - // Paths - const std::string TOKENIZER_PATH="/data/local/tmp/et-mtk/llama3/tokenizer.model"; - const std::string TOKEN_EMBEDDING_PATH="/data/local/tmp/et-mtk/llama3/embedding_llama3-8B-instruct_fp32.bin"; - - // Comma-Separated Paths - const std::string PROMPT_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_3.pte,"; - - // Comma-Separated Paths - const std::string GEN_MODEL_PATHS="/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_3.pte,"; - -} // namespace mtk:vars +using example::llm_helper::LLMType; + +// Sizes +const size_t PROMPT_TOKEN_BATCH_SIZE = 128; +const size_t CACHE_SIZE = 512; +const size_t HIDDEN_SIZE = 4096; +const size_t NUM_HEAD = 32; +const size_t NUM_LAYER = 32; +const size_t MAX_TOKEN_LENGTH = 8192; +const double ROT_EMB_BASE = 500000; + +// Types +const LLMType MODEL_INPUT_TYPE = LLMType::FP32; +const LLMType MODEL_OUTPUT_TYPE = LLMType::FP32; +const LLMType CACHE_TYPE = LLMType::FP32; +const LLMType MASK_TYPE = LLMType::FP32; +const LLMType ROT_EMB_TYPE = LLMType::FP32; + +// Paths +const std::string TOKENIZER_PATH = + "/data/local/tmp/et-mtk/llama3/tokenizer.model"; +const std::string TOKEN_EMBEDDING_PATH = + "/data/local/tmp/et-mtk/llama3/embedding_llama3-8B-instruct_fp32.bin"; + +// Comma-Separated Paths +const 
std::string PROMPT_MODEL_PATHS = + "/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_128t512c_3.pte,"; + +// Comma-Separated Paths +const std::string GEN_MODEL_PATHS = + "/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_0.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_1.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_2.pte,/data/local/tmp/et-mtk/llama3/llama3-8B-instruct_A16W4_4_chunks_1t512c_3.pte,"; + +} // namespace mtk::vars diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.cpp b/examples/mediatek/executor_runner/mtk_llama_runner.cpp index 695812eb30..824bd8f3c8 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.cpp +++ b/examples/mediatek/executor_runner/mtk_llama_runner.cpp @@ -44,8 +44,8 @@ * any receiver's applicable license agreements with MediaTek Inc. */ -#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h" #include +#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h" #include #include @@ -65,8 +65,8 @@ #include "llama_runner/ModelChunk.h" #include "llama_runner/Utils.h" -#include "llama_runner/llm_helper/include/llm_types.h" #include "llama_runner/llm_helper/include/llama_runner_values.h" +#include "llama_runner/llm_helper/include/llm_types.h" static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate. // Global BOS and EOS option for tokenization (encoding) @@ -83,15 +83,14 @@ using namespace mtk::vars; namespace llm = ::executorch::extension::llm; MTKLlamaRunner::MTKLlamaRunner( - const std::string& model_path, - const std::string& tokenizer_path, - const float temperature) - : modeloptions_(get_model_options()), - modelpaths_(get_model_paths()) { + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature) + : modeloptions_(get_model_options()), modelpaths_(get_model_paths()) { executorch::runtime::runtime_init(); ET_LOG( - Info, - "Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. Initiated runtime_init()."); + Info, + "Creating MTK Llama runner. Current it will self-load .pte, .bin, and .so files. 
Initiated runtime_init()."); } Error MTKLlamaRunner::load() { @@ -122,7 +121,6 @@ Error MTKLlamaRunner::generate( int32_t seq_len, std::function token_callback, std::function stats_callback) { - if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } @@ -137,9 +135,9 @@ Error MTKLlamaRunner::generate( } }; - ET_LOG(Info, "Starting inference from MTKLlamaRunner"); + ET_LOG(Info, "Starting inference from MTKLlamaRunner"); inference(*runtime_.get(), tokenizer_, prompt, wrapped_callback); - ET_LOG(Info, "Completed inference from MTKLlamaRunner"); + ET_LOG(Info, "Completed inference from MTKLlamaRunner"); return Error::Ok; } @@ -169,7 +167,7 @@ LlamaModelOptions MTKLlamaRunner::get_model_options() { .cache_type = CACHE_TYPE, .mask_type = MASK_TYPE, .rot_emb_type = ROT_EMB_TYPE}; - ET_LOG(Info, "Completed get_model_options"); + ET_LOG(Info, "Completed get_model_options"); return options; } @@ -179,7 +177,7 @@ LlamaModelPaths MTKLlamaRunner::get_model_paths() { .token_embedding_path = TOKEN_EMBEDDING_PATH, .prompt_model_paths = split(PROMPT_MODEL_PATHS, ','), .gen_model_paths = split(GEN_MODEL_PATHS, ',')}; - ET_LOG(Info, "Completed get_model_paths"); + ET_LOG(Info, "Completed get_model_paths"); return model_paths; } @@ -325,7 +323,8 @@ Error MTKLlamaRunner::inference( const auto first_output_token = prefill_res.get(); // run generation mode (decoding) - return gen_response(llama_runtime, tokenizer, first_output_token, token_callback); + return gen_response( + llama_runtime, tokenizer, first_output_token, token_callback); } std::unique_ptr MTKLlamaRunner::load_tokenizer() { diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.h b/examples/mediatek/executor_runner/mtk_llama_runner.h index e79a3b02ad..292a91fe87 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.h +++ b/examples/mediatek/executor_runner/mtk_llama_runner.h @@ -11,14 +11,14 @@ #pragma once +#include +#include +#include +#include #include #include #include #include -#include -#include -#include -#include #include "llama_runner/LlamaConfig.h" #include "llama_runner/LlamaRuntime.h" @@ -65,7 +65,6 @@ class MTKLlamaRunner { std::function token_callback); std::unique_ptr load_tokenizer(); - private: // model const LlamaModelOptions modeloptions_; diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index e6b5807a08..db3dbd89f2 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -13,9 +13,9 @@ #include #include +#include #include #include -#include #include #include #include @@ -163,9 +163,9 @@ class ExecuTorchLlamaJni temperature); } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { mtk_llama_runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - temperature); + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + temperature); } } From 8a29b1a81fc7bd5fe21b7670124aad5bff308a63 Mon Sep 17 00:00:00 2001 From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:06:22 -0700 Subject: [PATCH 6/8] protect cmakelist for extension under NEURON_BUFFER_ALLOCATOR_LIB flag --- extension/android/CMakeLists.txt | 40 +++++++++++++++++--------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index c7f61ff59b..9dd155db00 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -159,25 +159,27 @@ 
if(EXECUTORCH_BUILD_LLAMA_JNI) ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner ) - target_sources( - executorch_jni PRIVATE - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/mtk_llama_runner.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/MultiModelLoader.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/mask_builder.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/rotary_embedding.cpp - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/token_embedding.cpp - ) - target_include_directories( - executorch_jni PRIVATE - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/ - ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner - ) - ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED) - SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB}) - list(APPEND link_libraries neuron_backend libneuron_buffer_allocator) + if(NEURON_BUFFER_ALLOCATOR_LIB) + target_sources( + executorch_jni PRIVATE + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/mtk_llama_runner.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/ModelChunk.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/MultiModelLoader.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/mask_builder.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/rotary_embedding.cpp + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner/llm_helper/token_embedding.cpp + ) + target_include_directories( + executorch_jni PRIVATE + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/ + ${EXECUTORCH_ROOT}/examples/mediatek/executor_runner/llama_runner + ) + ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED) + SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB}) + list(APPEND link_libraries neuron_backend libneuron_buffer_allocator) + endif() endif() target_include_directories( From 66f550aff1150ff7e04c27cf86123bbc6342eb76 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 17 Oct 2024 20:54:00 -0700 Subject: [PATCH 7/8] llama2 -> llama --- examples/mediatek/executor_runner/mtk_llama_runner.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.h b/examples/mediatek/executor_runner/mtk_llama_runner.h index 292a91fe87..2123818f09 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.h +++ b/examples/mediatek/executor_runner/mtk_llama_runner.h @@ -11,7 +11,7 @@ #pragma once -#include +#include #include #include #include From 03c12c840f013d4710097cb5c5ca9f7f7238b271 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 18 Oct 2024 12:34:19 -0700 Subject: [PATCH 8/8] Use common LLM interface --- .../executor_runner/mtk_llama_runner.cpp | 4 ++- .../executor_runner/mtk_llama_runner.h | 8 ++++-- extension/android/CMakeLists.txt | 1 + extension/android/jni/jni_layer_llama.cpp | 25 
++++++++----------- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.cpp b/examples/mediatek/executor_runner/mtk_llama_runner.cpp index 824bd8f3c8..713e6679e4 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.cpp +++ b/examples/mediatek/executor_runner/mtk_llama_runner.cpp @@ -120,7 +120,9 @@ Error MTKLlamaRunner::generate( const std::string& prompt, int32_t seq_len, std::function token_callback, - std::function stats_callback) { + std::function stats_callback + bool, + bool) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } diff --git a/examples/mediatek/executor_runner/mtk_llama_runner.h b/examples/mediatek/executor_runner/mtk_llama_runner.h index 2123818f09..8240a6a45c 100644 --- a/examples/mediatek/executor_runner/mtk_llama_runner.h +++ b/examples/mediatek/executor_runner/mtk_llama_runner.h @@ -12,6 +12,7 @@ #pragma once #include +#include #include #include #include @@ -31,7 +32,8 @@ using executorch::extension::llm::Tokenizer; using executorch::runtime::Error; using executorch::runtime::Result; -class MTKLlamaRunner { +class MTKLlamaRunner + : public executorch::extension::llm::RunnerInterface { public: explicit MTKLlamaRunner( const std::string& model_path, @@ -44,7 +46,9 @@ class MTKLlamaRunner { const std::string& prompt, int32_t seq_len = 128, std::function token_callback = {}, - std::function stats_callback = {}); + std::function stats_callback = {}, + bool echo = true, + bool warming = false); void stop(); LlamaModelOptions get_model_options(); diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 9dd155db00..1fe2852c97 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -179,6 +179,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI) ADD_LIBRARY(libneuron_buffer_allocator SHARED IMPORTED) SET_PROPERTY(TARGET libneuron_buffer_allocator PROPERTY IMPORTED_LOCATION ${NEURON_BUFFER_ALLOCATOR_LIB}) list(APPEND link_libraries neuron_backend libneuron_buffer_allocator) + target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_MEDIATEK=1) endif() endif() diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index db3dbd89f2..54a2a5dba2 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -13,10 +13,10 @@ #include #include -#include #include #include #include +#include #include #include #include @@ -29,6 +29,10 @@ #include #include +#if defined(EXECUTORCH_BUILD_MEDIATEK) +#include +#endif + namespace llm = ::executorch::extension::llm; using ::executorch::runtime::Error; @@ -112,9 +116,8 @@ class ExecuTorchLlamaJni private: friend HybridBase; int model_type_category_; - std::unique_ptr runner_; + std::unique_ptr runner_; std::unique_ptr multi_modal_runner_; - std::unique_ptr mtk_llama_runner_; public: constexpr static auto kJavaDescriptor = @@ -161,11 +164,15 @@ class ExecuTorchLlamaJni model_path->toStdString().c_str(), tokenizer_path->toStdString().c_str(), temperature); +#if defined(EXECUTORCH_BUILD_MEDIATEK) } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { - mtk_llama_runner_ = std::make_unique( + runner_ = std::make_unique( model_path->toStdString().c_str(), tokenizer_path->toStdString().c_str(), temperature); + // Interpret the model type as LLM + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; +#endif } } @@ -205,12 +212,6 @@ class ExecuTorchLlamaJni [callback](std::string result) { callback->onResult(result); }, 
           [callback](const llm::Stats& result) { callback->onStats(result); },
           echo);
-    } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
-      mtk_llama_runner_->generate(
-          prompt->toStdString(),
-          seq_len,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const Stats& result) { callback->onStats(result); });
     }
     return 0;
   }
@@ -300,8 +301,6 @@ class ExecuTorchLlamaJni
       multi_modal_runner_->stop();
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       runner_->stop();
-    } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
-      mtk_llama_runner_->stop();
     }
   }
@@ -310,6 +309,8 @@ class ExecuTorchLlamaJni
       return static_cast(multi_modal_runner_->load());
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       return static_cast(runner_->load());
-    } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
-      return static_cast(mtk_llama_runner_->load());
     }
     return static_cast(Error::InvalidArgument);
   }
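
Usage sketch (illustrative only, not part of the patches above): a minimal native harness driving the MTKLlamaRunner introduced in PATCH 1/8 and reshaped by the common-interface change in PATCH 8/8. The prompt string and main() wrapper are placeholders, and the callback types are assumed to match the declarations above (whose template arguments were stripped in this rendering): std::function<void(const std::string&)> for tokens and std::function<void(const Stats&)> for stats.

// Illustrative sketch; drives the runner the same way the JNI layer does.
#include <cstdio>
#include <string>

#include "mtk_llama_runner.h"

int main() {
  // Constructor arguments are currently unused; the runner self-loads the
  // model/tokenizer paths hard-coded in llama_runner_values.h.
  MTKLlamaRunner runner(
      /*model_path=*/"", /*tokenizer_path=*/"", /*temperature=*/0.8f);

  if (runner.load() != Error::Ok) {
    return 1;
  }

  // Prompt and seq_len are illustrative; token pieces are printed as they
  // stream back, and generation stats are ignored here.
  runner.generate(
      "What is ExecuTorch?",
      /*seq_len=*/128,
      [](const std::string& piece) { std::printf("%s", piece.c_str()); },
      [](const Stats&) {});

  runner.stop();
  return 0;
}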