Commit
Fix format of Ollama setting keys.
lgrammel committed Dec 14, 2023
1 parent 7681666 commit abe43bd
Showing 2 changed files with 62 additions and 99 deletions.
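This commit renames the Ollama-specific setting keys from snake_case to camelCase; the snake_case names are still used on the wire, as the diff below shows. A minimal sketch of caller code after the rename — the import path and constructor shape are assumed for illustration, not taken from this diff:

```ts
// Hypothetical usage after this commit: settings use camelCase keys.
// Import path assumed; the class lives in src/model-provider/ollama/.
import { OllamaTextGenerationModel } from "modelfusion";

const model = new OllamaTextGenerationModel({
  model: "llama2",
  temperature: 0.7,
  repeatPenalty: 1.1, // was repeat_penalty
  topK: 40, // was top_k
  topP: 0.9, // was top_p
  mirostatEta: 0.1, // was mirostat_eta
});
```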
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -19,6 +19,7 @@
"Lmnt",
"logit",
"Millicents",
"mirostat",
"modelfusion",
"Ollama",
"openai",
160 changes: 61 additions & 99 deletions src/model-provider/ollama/OllamaTextGenerationModel.ts
@@ -65,45 +65,45 @@ export interface OllamaTextGenerationModelSettings<
* A lower learning rate will result in slower adjustments,
* while a higher learning rate will make the algorithm more responsive. (Default: 0.1)
*/
- mirostat_eta?: number;
+ mirostatEta?: number;

/**
* Controls the balance between coherence and diversity of the output.
* A lower value will result in more focused and coherent text. (Default: 5.0)
*/
- mirostat_tau?: number;
+ mirostatTau?: number;

/**
* The number of GQA groups in the transformer layer. Required for some models;
* for example, it is 8 for llama2:70b.
*/
- num_gqa?: number;
+ numGqa?: number;

/**
* The number of layers to send to the GPU(s). On macOS it defaults to 1 to
* enable metal support, 0 to disable.
*/
- num_gpu?: number;
+ numGpu?: number;

/**
* Sets the number of threads to use during computation. By default, Ollama will
* detect this for optimal performance. It is recommended to set this value to the
* number of physical CPU cores your system has (as opposed to the logical number of cores).
*/
- num_threads?: number;
+ numThreads?: number;

/**
* Sets how far back the model looks to prevent repetition.
* (Default: 64, 0 = disabled, -1 = num_ctx)
*/
- repeat_last_n?: number;
+ repeatLastN?: number;

/**
* Sets how strongly to penalize repetitions. A higher value (e.g., 1.5)
* will penalize repetitions more strongly, while a lower value (e.g., 0.9)
* will be more lenient. (Default: 1.1)
*/
- repeat_penalty?: number;
+ repeatPenalty?: number;

/**
* Sets the random number seed to use for generation. Setting this to a
@@ -117,21 +117,21 @@ export interface OllamaTextGenerationModelSettings<
* from the output. A higher value (e.g., 2.0) will reduce the impact more,
* while a value of 1.0 disables this setting. (default: 1)
*/
- tfs_z?: number;
+ tfsZ?: number;

/**
* Reduces the probability of generating nonsense. A higher value (e.g. 100)
* will give more diverse answers, while a lower value (e.g. 10) will be more
* conservative. (Default: 40)
*/
- top_k?: number;
+ topK?: number;

/**
* Works together with top-k. A higher value (e.g., 0.95) will lead to more
* diverse text, while a lower value (e.g., 0.5) will generate more focused
* and conservative text. (Default: 0.9)
*/
- top_p?: number;
+ topP?: number;

/**
* When set to true, no formatting will be applied to the prompt and no context
@@ -196,17 +196,49 @@ export class OllamaTextGenerationModel<
responseFormat: OllamaTextGenerationResponseFormatType<RESPONSE>;
} & FunctionOptions
): Promise<RESPONSE> {
+ const { responseFormat } = options;
+ const api = this.settings.api ?? new OllamaApiConfiguration();
+ const abortSignal = options.run?.abortSignal;

return callWithRetryAndThrottle({
- retry: this.settings.api?.retry,
- throttle: this.settings.api?.throttle,
+ retry: api.retry,
+ throttle: api.throttle,
call: async () =>
- callOllamaTextGenerationAPI({
- ...this.settings,
-
- // other
- abortSignal: options.run?.abortSignal,
- prompt,
- responseFormat: options.responseFormat,
+ postJsonToApi({
+ url: api.assembleUrl(`/api/generate`),
+ headers: api.headers,
+ body: {
+ stream: responseFormat.stream,
+ model: this.settings.model,
+ prompt: prompt.prompt,
+ images: prompt.images,
+ format: this.settings.format,
+ options: {
+ mirostat: this.settings.mirostat,
+ mirostat_eta: this.settings.mirostatEta,
+ mirostat_tau: this.settings.mirostatTau,
+ num_ctx: this.settings.contextWindowSize,
+ num_gpu: this.settings.numGpu,
+ num_gqa: this.settings.numGqa,
+ num_predict: this.settings.maxCompletionTokens,
+ num_threads: this.settings.numThreads,
+ repeat_last_n: this.settings.repeatLastN,
+ repeat_penalty: this.settings.repeatPenalty,
+ seed: this.settings.seed,
+ stop: this.settings.stopSequences,
+ temperature: this.settings.temperature,
+ tfs_z: this.settings.tfsZ,
+ top_k: this.settings.topK,
+ top_p: this.settings.topP,
+ },
+ system: this.settings.system,
+ template: this.settings.template,
+ context: this.settings.context,
+ raw: this.settings.raw,
+ },
+ failedResponseHandler: failedOllamaCallResponseHandler,
+ successfulResponseHandler: responseFormat.handler,
+ abortSignal,
+ }),
});
}
@@ -220,17 +252,17 @@ export class OllamaTextGenerationModel<
"contextWindowSize",
"temperature",
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_gqa",
"num_gpu",
"num_threads",
"repeat_last_n",
"repeat_penalty",
"mirostatEta",
"mirostatTau",
"numGqa",
"numGpu",
"numThreads",
"repeatLastN",
"repeatPenalty",
"seed",
"tfs_z",
"top_k",
"top_p",
"tfsZ",
"topK",
"topP",
"system",
"template",
"context",
@@ -379,76 +411,6 @@ const ollamaTextStreamingResponseSchema = new ZodSchema(
])
);

- async function callOllamaTextGenerationAPI<RESPONSE>({
- api = new OllamaApiConfiguration(),
- abortSignal,
- responseFormat,
- prompt,
- model,
- format,
- contextWindowSize,
- maxCompletionTokens,
- mirostat,
- mirostat_eta,
- mirostat_tau,
- num_gpu,
- num_gqa,
- num_threads,
- repeat_last_n,
- repeat_penalty,
- seed,
- stopSequences,
- temperature,
- tfs_z,
- top_k,
- top_p,
- system,
- template,
- context,
- raw,
- }: OllamaTextGenerationModelSettings<number> & {
- abortSignal?: AbortSignal;
- responseFormat: OllamaTextGenerationResponseFormatType<RESPONSE>;
- prompt: OllamaTextGenerationPrompt;
- }): Promise<RESPONSE> {
- return postJsonToApi({
- url: api.assembleUrl(`/api/generate`),
- headers: api.headers,
- body: {
- stream: responseFormat.stream,
- model,
- prompt: prompt.prompt,
- images: prompt.images,
- format,
- options: {
- mirostat,
- mirostat_eta,
- mirostat_tau,
- num_ctx: contextWindowSize,
- num_gpu,
- num_gqa,
- num_predict: maxCompletionTokens,
- num_threads,
- repeat_last_n,
- repeat_penalty,
- seed,
- stop: stopSequences,
- temperature,
- tfs_z,
- top_k,
- top_p,
- },
- system,
- template,
- context,
- raw,
- },
- failedResponseHandler: failedOllamaCallResponseHandler,
- successfulResponseHandler: responseFormat.handler,
- abortSignal,
- });
- }

export type OllamaTextGenerationDelta = {
content: string;
isComplete: boolean;

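A closing note on the structure of the change: the deleted callOllamaTextGenerationAPI helper was the last place that accepted snake_case keys directly. After this commit, settings stay camelCase in TypeScript and are translated to Ollama's snake_case option names only where the request body is assembled. A standalone sketch of that mapping — the helper name and the trimmed-down settings type are hypothetical; the key pairs come from the diff above:

```ts
// Hypothetical helper mirroring the body construction inlined by this commit:
// camelCase model settings in, Ollama's snake_case `options` object out.
interface OllamaSamplingSettings {
  mirostat?: number;
  mirostatEta?: number;
  mirostatTau?: number;
  repeatPenalty?: number;
  topK?: number;
  topP?: number;
}

function toOllamaOptions(settings: OllamaSamplingSettings) {
  return {
    mirostat: settings.mirostat,
    mirostat_eta: settings.mirostatEta,
    mirostat_tau: settings.mirostatTau,
    repeat_penalty: settings.repeatPenalty,
    top_k: settings.topK,
    top_p: settings.topP,
  };
}

// Example: toOllamaOptions({ topK: 40, mirostatEta: 0.1 })
// yields { top_k: 40, mirostat_eta: 0.1 } with the remaining keys undefined,
// which JSON serialization drops from the request body.
```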