Caching previous prompts for later reuse, part 1 #3073

Draft · wants to merge 6 commits into main
22 changes: 15 additions & 7 deletions gpt4all-backend/include/gpt4all-backend/llmodel.h
@@ -124,9 +124,7 @@ class LLModel {
};

struct PromptContext {
std::vector<int32_t> tokens; // current tokens in the context window
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_ctx = 0; // number of tokens possible in context window
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@@ -151,8 +149,8 @@ class LLModel {
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const = 0;
virtual size_t saveState(std::span<uint8_t> dest) const = 0;
virtual size_t restoreState(std::span<const uint8_t> src) = 0;
virtual size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const = 0;
virtual size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) = 0;

// This method requires the model to return true from supportsCompletion; otherwise it will throw
// an error
@@ -210,17 +208,23 @@ class LLModel {

void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

virtual int32_t contextLength() const = 0;

protected:
// These are pure virtual because subclasses need to implement them, as the default implementation of
// 'prompt' above calls these functions
virtual std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special = false) = 0;
virtual std::vector<Token> tokenize(std::string_view str, bool special = false) = 0;
virtual bool isSpecialToken(Token id) const = 0;
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;

Expand Down Expand Up @@ -257,6 +261,10 @@ class LLModel {
bool allowContextShift,
PromptContext &promptCtx);

// Use only for testing and debugging. Throws if the backend was built with -DNDEBUG.
virtual std::span<const Token> inputTokens() const = 0;

protected:
Token m_tokenize_last_token = -1; // not serialized

friend class LLMImplementation;
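
For context, a minimal sketch of how an embedder might drive the reworked C++ state API above. The helper names and error handling here are illustrative, not part of this PR, and a loaded LLModel instance is assumed:

#include <cstdint>
#include <span>
#include <vector>

#include "gpt4all-backend/llmodel.h"

// Snapshot the model: the serialized KV-cache state and the token cache now
// travel together, so they can be restored as a pair later.
bool snapshotModel(const LLModel &model,
                   std::vector<uint8_t> &stateOut,
                   std::vector<LLModel::Token> &tokensOut)
{
    stateOut.resize(model.stateSize());
    return model.saveState(stateOut, tokensOut) != 0; // zero indicates failure
}

// Rewind the model to a previous snapshot before reusing the cached prompt.
bool rewindModel(LLModel &model,
                 std::span<const uint8_t> state,
                 std::span<const LLModel::Token> tokens)
{
    return model.restoreState(state, tokens) != 0;
}
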
40 changes: 28 additions & 12 deletions gpt4all-backend/include/gpt4all-backend/llmodel_c.h
@@ -23,17 +23,19 @@ extern "C" {
*/
typedef void *llmodel_model;

/**
* A token.
*/
typedef int32_t token_t;

/**
* llmodel_prompt_context structure for holding the prompt context.
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
* raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined
* behavior.
*/
struct llmodel_prompt_context {
int32_t *tokens; // current tokens in the context window
size_t tokens_size; // the size of the raw tokens vector
int32_t n_past; // number of tokens in past conversation
int32_t n_ctx; // number of tokens possible in context window
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
@@ -141,27 +143,41 @@ bool llmodel_isModelLoaded(llmodel_model model);
* @param model A pointer to the llmodel_model instance.
* @return the size in bytes of the internal state of the model
*/
uint64_t llmodel_get_state_size(llmodel_model model);
uint64_t llmodel_state_get_size(llmodel_model model);

/**
* Saves the internal state of the model to the specified destination address.
* Saves the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @param size The size of the destination buffer.
* @return the number of bytes copied, or zero on error.
* @param state Where to store the state. This must be a buffer of at least llmodel_state_get_size() bytes.
* @param state_size The size of the destination for the state.
* @param input_tokens_out Where to store the address of the token cache state. This is dynamically allocated and must
* be freed with llmodel_state_free_input_tokens.
* @param n_input_tokens Where to store the size of the token cache state.
* @return The number of bytes copied. On error, zero is returned, the token cache is set to NULL, and the token cache
* size is set to zero.
*/
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens);

/**
* Frees the temporary token cache buffer created by a call to llmodel_state_get_data().
* @param input_tokens The token cache buffer.
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
void llmodel_state_free_input_tokens(token_t *input_tokens);

/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the state data.
* @param size The size of the source data.
* @param state A pointer to the state data.
* @param state_size The size of the state data.
* @param input_tokens The token cache associated with the saved state.
* @param n_input_tokens The number of tokens in input_tokens.
* @return The number of bytes read, or zero on error.
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens);

/**
* Generate a response using the model.
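
A hedged usage sketch of the renamed C state functions declared above. It assumes an already-loaded llmodel_model handle; the round-trip helper itself is illustrative:

#include <cstdint>
#include <vector>

#include "gpt4all-backend/llmodel_c.h"

// Save the model state plus its token cache, then restore both later.
bool round_trip_state(llmodel_model model)
{
    std::vector<uint8_t> state(llmodel_state_get_size(model));

    token_t *tokens = nullptr;
    uint64_t n_tokens = 0;
    uint64_t written = llmodel_state_get_data(model, state.data(), state.size(),
                                              &tokens, &n_tokens);
    if (!written)
        return false; // on error the token cache is NULL and its size is zero

    // ... run further prompts here, then rewind to the snapshot ...

    uint64_t read = llmodel_state_set_data(model, state.data(), state.size(),
                                           tokens, n_tokens);
    llmodel_state_free_input_tokens(tokens); // buffer was allocated by llmodel_state_get_data
    return read != 0;
}
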
77 changes: 68 additions & 9 deletions gpt4all-backend/src/llamamodel.cpp
@@ -218,6 +218,7 @@ struct LLamaPrivate {
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
std::vector<LLModel::Token> inputTokens;

llama_model *model = nullptr;
llama_context *ctx = nullptr;
@@ -501,17 +502,23 @@ size_t LLamaModel::stateSize() const
return llama_state_get_size(d_ptr->ctx);
}

size_t LLamaModel::saveState(std::span<uint8_t> dest) const
size_t LLamaModel::saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const
{
return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
size_t bytesWritten = llama_state_get_data(d_ptr->ctx, stateOut.data(), stateOut.size());
if (bytesWritten)
inputTokensOut.assign(d_ptr->inputTokens.begin(), d_ptr->inputTokens.end());
return bytesWritten;
}

size_t LLamaModel::restoreState(std::span<const uint8_t> src)
size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens)
{
return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
size_t bytesRead = llama_state_set_data(d_ptr->ctx, state.data(), state.size());
if (bytesRead)
d_ptr->inputTokens.assign(inputTokens.begin(), inputTokens.end());
return bytesRead;
}

std::vector<LLModel::Token> LLamaModel::tokenize(PromptContext &ctx, std::string_view str, bool special)
std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
{
bool atStart = m_tokenize_last_token == -1;
bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token);
@@ -594,7 +601,7 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}

bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);

@@ -625,7 +632,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));

assert(n_discard > 0);
if (n_discard <= 0)
@@ -638,15 +645,67 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);

promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
}

int32_t LLamaModel::contextLength() const
{
return llama_n_ctx(d_ptr->ctx);
}

int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}

void LLamaModel::setTokenizeInputPosition(int32_t pos)
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}

auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}

// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
}
// truncate token cache to end at n_past
if (pos < d_ptr->inputTokens.size())
d_ptr->inputTokens.resize(pos);
// tell the caller to ignore the tokens between [begin, inputIt)
ctx.n_past = int32_t(pos);
return inputIt;
}

void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}

auto LLamaModel::inputTokens() const -> std::span<const Token>
{
#ifdef NDEBUG
throw std::runtime_error("attempt to call debug-only inputTokens() in release build");
#else
return d_ptr->inputTokens;
#endif
}

const std::vector<LLModel::Token> &LLamaModel::endTokens() const
{
return d_ptr->end_tokens;
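
To make the caching behavior easier to follow, here is a simplified standalone sketch of the common-prefix reuse that computeModelInputPosition and the token cache implement. The names are illustrative, and the real code also accounts for n_past bookkeeping and the KV cache:

#include <cstddef>
#include <cstdint>
#include <vector>

using Token = int32_t;

// Count how many leading tokens of the new prompt are already in the cache,
// and drop any cached tokens that diverge. Only tokens from `pos` onward need
// to be re-evaluated by the model.
size_t reusable_prefix(std::vector<Token> &cache, const std::vector<Token> &input)
{
    size_t pos = 0;
    while (pos < cache.size() && pos < input.size() && cache[pos] == input[pos])
        ++pos;
    cache.resize(pos); // truncate the cache where the prompts diverge
    return pos;
}

With this, a follow-up prompt that shares a prefix with the previous conversation skips re-evaluating that prefix, which is the point of caching previous prompts.
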
27 changes: 18 additions & 9 deletions gpt4all-backend/src/llamamodel_impl.h
@@ -28,8 +28,8 @@ class LLamaModel : public LLModel {
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override;
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
Expand All @@ -48,20 +48,21 @@ class LLamaModel : public LLModel {
void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
int32_t contextLength() const override;

protected:
std::vector<Token> tokenize(PromptContext &ctx, std::string_view str, bool special) override;
std::vector<Token> tokenize(std::string_view str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void appendInputToken(PromptContext &ctx, Token tok) override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
@@ -70,6 +71,14 @@ class LLamaModel : public LLModel {
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
const EmbModelSpec *spec);

// Use only for testing and debugging.
std::span<const Token> inputTokens() const override;

private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
};

#endif // LLAMAMODEL_H
41 changes: 29 additions & 12 deletions gpt4all-backend/src/llmodel_c.cpp
@@ -11,9 +11,15 @@
#include <iostream>
#include <memory>
#include <optional>
#include <ranges>
#include <span>
#include <string>
#include <string_view>
#include <vector>

namespace ranges = std::ranges;

static_assert(sizeof(token_t) == sizeof(LLModel::Token));

struct LLModelWrapper {
LLModel *llModel = nullptr;
@@ -85,22 +91,40 @@ bool llmodel_isModelLoaded(llmodel_model model)
return wrapper->llModel->isModelLoaded();
}

uint64_t llmodel_get_state_size(llmodel_model model)
uint64_t llmodel_state_get_size(llmodel_model model)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->stateSize();
}

uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size)
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->saveState({dest, size_t(size)});
std::vector<LLModel::Token> inputTokens;
auto bytesWritten = wrapper->llModel->saveState({state_out, size_t(state_size)}, inputTokens);
if (bytesWritten) {
auto *buf = new LLModel::Token[inputTokens.size()];
ranges::copy(inputTokens, buf);
*input_tokens_out = buf;
*n_input_tokens = uint64_t(inputTokens.size());
} else {
*input_tokens_out = nullptr;
*n_input_tokens = 0;
}
return bytesWritten;
}

uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, uint64_t size)
void llmodel_state_free_input_tokens(LLModel::Token *input_tokens)
{
delete[] input_tokens;
}

uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens)
{
auto *wrapper = static_cast<LLModelWrapper *>(model);
return wrapper->llModel->restoreState({src, size_t(size)});
return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
}

void llmodel_prompt(llmodel_model model, const char *prompt,
@@ -120,7 +144,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,

// Copy the C prompt context
wrapper->promptContext.n_past = ctx->n_past;
wrapper->promptContext.n_ctx = ctx->n_ctx;
wrapper->promptContext.n_predict = ctx->n_predict;
wrapper->promptContext.top_k = ctx->top_k;
wrapper->promptContext.top_p = ctx->top_p;
@@ -136,14 +159,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
wrapper->promptContext, special,
fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);

// Update the C context by giving access to the wrappers raw pointers to std::vector data
// which involves no copies
ctx->tokens = wrapper->promptContext.tokens.data();
ctx->tokens_size = wrapper->promptContext.tokens.size();

// Update the rest of the C prompt context
ctx->n_past = wrapper->promptContext.n_past;
ctx->n_ctx = wrapper->promptContext.n_ctx;
ctx->n_predict = wrapper->promptContext.n_predict;
ctx->top_k = wrapper->promptContext.top_k;
ctx->top_p = wrapper->promptContext.top_p;