From 81a20551857a05775853cf03f881320ca41a7e79 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Fri, 18 Oct 2024 12:22:33 -0700
Subject: [PATCH] Add an interface for LLM runner

In case we have custom LLM runners other than llama runner, we want to
have a uniform interface
---
 examples/models/llama/runner/runner.h   |  4 +-
 extension/llm/runner/runner_interface.h | 50 +++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 1 deletion(-)
 create mode 100644 extension/llm/runner/runner_interface.h

diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h
index 4524aa81aa..cf4ed97c59 100644
--- a/examples/models/llama/runner/runner.h
+++ b/examples/models/llama/runner/runner.h
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/extension/llm/runner/runner_interface.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
@@ -26,7 +27,8 @@
 
 namespace example {
 
-class ET_EXPERIMENTAL Runner {
+class ET_EXPERIMENTAL Runner
+    : public executorch::extension::llm::RunnerInterface {
  public:
   explicit Runner(
       const std::string& model_path,
diff --git a/extension/llm/runner/runner_interface.h b/extension/llm/runner/runner_interface.h
new file mode 100644
index 0000000000..f28d9a34a4
--- /dev/null
+++ b/extension/llm/runner/runner_interface.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// An interface for LLM runners. Developers can create their own runner that
+// implements their own load and generation logic to run the model.
+
+#pragma once
+
+#include <functional>
+#include <string>
+
+#include <executorch/extension/llm/runner/stats.h>
+#include <executorch/runtime/core/error.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+class ET_EXPERIMENTAL RunnerInterface {
+ public:
+  virtual ~RunnerInterface() = default;
+
+  // Checks if the model is loaded.
+  virtual bool is_loaded() const = 0;
+
+  // Load the model and tokenizer.
+  virtual ::executorch::runtime::Error load() = 0;
+
+  // Generate the output tokens.
+  virtual ::executorch::runtime::Error generate(
+      const std::string& prompt,
+      int32_t seq_len,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {},
+      bool echo = true,
+      bool warming = false) = 0;
+
+  // Stop the generation.
+  virtual void stop() = 0;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch