Add an interface for LLM runner
In case we have custom LLM runners other than the llama runner, we want
to have a uniform interface.
kirklandsign committed Oct 18, 2024
1 parent 2c43190 commit 81a2055
Showing 2 changed files with 53 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/models/llama/runner/runner.h
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/extension/llm/runner/runner_interface.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
@@ -26,7 +27,8 @@

 namespace example {
 
-class ET_EXPERIMENTAL Runner {
+class ET_EXPERIMENTAL Runner
+    : public executorch::extension::llm::RunnerInterface {
  public:
   explicit Runner(
       const std::string& model_path,
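
To illustrate the payoff, here is a minimal sketch (not part of the commit) of calling code written against the interface instead of the concrete llama Runner. The helper name run_prompt and the seq_len value of 128 are illustrative assumptions.

#include <executorch/extension/llm/runner/runner_interface.h>

#include <iostream>
#include <string>

using executorch::extension::llm::RunnerInterface;
using executorch::runtime::Error;

// Hypothetical helper: drives any runner through the shared interface,
// so the caller no longer depends on the concrete runner type.
Error run_prompt(RunnerInterface& runner, const std::string& prompt) {
  if (!runner.is_loaded()) {
    const Error err = runner.load();
    if (err != Error::Ok) {
      return err;
    }
  }
  // Stream each generated token to stdout as it arrives; the remaining
  // generate() arguments keep their interface defaults.
  return runner.generate(
      prompt,
      /*seq_len=*/128,
      [](const std::string& token) { std::cout << token << std::flush; });
}

Any class that implements RunnerInterface, added below, can be passed to such a helper.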
50 changes: 50 additions & 0 deletions extension/llm/runner/runner_interface.h
@@ -0,0 +1,50 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// An interface for LLM runners. Developers can create a custom runner that
// implements its own load and generation logic to run the model.

#pragma once

#include <functional>
#include <string>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/module/module.h>

namespace executorch {
namespace extension {
namespace llm {

class ET_EXPERIMENTAL RunnerInterface {
 public:
  virtual ~RunnerInterface() = default;

  // Checks if the model is loaded.
  virtual bool is_loaded() const = 0;

  // Loads the model and tokenizer.
  virtual ::executorch::runtime::Error load() = 0;

  // Generates the output tokens.
  virtual ::executorch::runtime::Error generate(
      const std::string& prompt,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
      bool echo = true,
      bool warming = false) = 0;

  // Stops the generation.
  virtual void stop() = 0;
};

} // namespace llm
} // namespace extension
} // namespace executorch
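
To show what a custom runner looks like against this interface, here is a minimal sketch. EchoRunner is hypothetical and performs no real inference; it exists only to show the methods a custom runner must override.

#include <executorch/extension/llm/runner/runner_interface.h>

#include <functional>
#include <string>

// Hypothetical custom runner: echoes the prompt back through the token
// callback instead of running a model.
class EchoRunner : public executorch::extension::llm::RunnerInterface {
 public:
  bool is_loaded() const override {
    return loaded_;
  }

  ::executorch::runtime::Error load() override {
    // A real runner would load its model and tokenizer here.
    loaded_ = true;
    return ::executorch::runtime::Error::Ok;
  }

  ::executorch::runtime::Error generate(
      const std::string& prompt,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback,
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback,
      bool echo,
      bool warming) override {
    (void)seq_len;
    (void)stats_callback;
    (void)warming;
    // Emit the prompt itself as the only "generated" output.
    if (echo && token_callback) {
      token_callback(prompt);
    }
    return ::executorch::runtime::Error::Ok;
  }

  void stop() override {
    stopped_ = true;
  }

 private:
  bool loaded_ = false;
  bool stopped_ = false;
};

Because the llama Runner in examples/models/llama now derives from the same base, EchoRunner and the llama Runner are interchangeable wherever a RunnerInterface is expected.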
