Add an interface for LLM runner (#6356)
Summary:
When we have custom LLM runners other than the llama runner, we want them all to share a uniform interface.

Pull Request resolved: #6356

Reviewed By: cccclai

Differential Revision: D64629696

Pulled By: kirklandsign

fbshipit-source-id: b9a670e47c4a73ae1180c85e9f11f0b23e3e4ed6
kirklandsign authored and facebook-github-bot committed Oct 18, 2024
1 parent 8209bc1 commit 4d7b294
Showing 4 changed files with 63 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/models/llama/runner/runner.h
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
@@ -26,7 +27,7 @@
 
 namespace example {
 
-class ET_EXPERIMENTAL Runner {
+class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
  public:
   explicit Runner(
       const std::string& model_path,
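
Note (not part of this commit): with Runner implementing the new IRunner interface (defined in irunner.h below), call sites can be written against the interface rather than a concrete runner. A minimal sketch, assuming a hypothetical helper run_prompt and an arbitrary sequence length of 128:

#include <iostream>
#include <string>

#include <executorch/extension/llm/runner/irunner.h>

using executorch::extension::llm::IRunner;

// Hypothetical helper: any IRunner implementation, whether the llama example
// Runner or a custom one, can be driven through the same code path.
executorch::runtime::Error run_prompt(
    IRunner& runner,
    const std::string& prompt) {
  if (!runner.is_loaded()) {
    auto err = runner.load();
    if (err != executorch::runtime::Error::Ok) {
      return err;
    }
  }
  // Stream each decoded token to stdout as it arrives; the remaining
  // generate() arguments fall back to the defaults declared on IRunner.
  return runner.generate(
      prompt, /*seq_len=*/128, [](const std::string& token) {
        std::cout << token << std::flush;
      });
}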
1 change: 1 addition & 0 deletions examples/models/llama/runner/targets.bzl
@@ -39,6 +39,7 @@ def define_common_targets():
         ],
         exported_deps = [
             "//executorch/backends/xnnpack:xnnpack_backend",
+            "//executorch/extension/llm/runner:irunner",
             "//executorch/extension/llm/runner:stats",
             "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
             "//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
50 changes: 50 additions & 0 deletions extension/llm/runner/irunner.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// An interface for LLM runners. Developers can create their own runner that
+// implements their own load and generation logic to run the model.
+
+#pragma once
+
+#include <functional>
+#include <string>
+
+#include <executorch/extension/llm/runner/stats.h>
+#include <executorch/extension/module/module.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+class ET_EXPERIMENTAL IRunner {
+ public:
+  virtual ~IRunner() = default;
+
+  // Checks if the model is loaded.
+  virtual bool is_loaded() const = 0;
+
+  // Load the model and tokenizer.
+  virtual ::executorch::runtime::Error load() = 0;
+
+  // Generate the output tokens.
+  virtual ::executorch::runtime::Error generate(
+      const std::string& prompt,
+      int32_t seq_len,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {},
+      bool echo = true,
+      bool warming = false) = 0;
+
+  // Stop the generation.
+  virtual void stop() = 0;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
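
Note (not part of this commit): to make the extension point concrete, here is a minimal sketch of a custom runner built on this interface. EchoRunner and its echo-the-prompt behavior are illustrative assumptions, not code from the ExecuTorch tree:

#include <functional>
#include <string>

#include <executorch/extension/llm/runner/irunner.h>

// Hypothetical custom runner: "generates" by echoing the prompt back.
// A real implementation would load a model and run a decode loop.
class EchoRunner : public executorch::extension::llm::IRunner {
 public:
  bool is_loaded() const override {
    return loaded_;
  }

  ::executorch::runtime::Error load() override {
    loaded_ = true; // A real runner would load its model and tokenizer here.
    return ::executorch::runtime::Error::Ok;
  }

  ::executorch::runtime::Error generate(
      const std::string& prompt,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback,
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback,
      bool echo,
      bool warming) override {
    (void)seq_len;
    (void)echo;
    (void)warming;
    if (!loaded_) {
      auto err = load();
      if (err != ::executorch::runtime::Error::Ok) {
        return err;
      }
    }
    if (token_callback) {
      token_callback(prompt); // Stand-in for streaming decoded tokens.
    }
    if (stats_callback) {
      stats_callback(stats_);
    }
    return ::executorch::runtime::Error::Ok;
  }

  void stop() override {
    // A real runner would signal its decode loop to exit here.
  }

 private:
  bool loaded_ = false;
  ::executorch::extension::llm::Stats stats_;
};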
10 changes: 10 additions & 0 deletions extension/llm/runner/targets.bzl
@@ -1,6 +1,16 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 def define_common_targets():
+    runtime.cxx_library(
+        name = "irunner",
+        exported_headers = [
+            "irunner.h",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
     runtime.cxx_library(
         name = "stats",
         exported_headers = [
