diff --git a/examples/README.md b/examples/README.md
index c36589f7..ca27bebf 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -25,6 +25,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](js.langchain.com)
+- [multi-models](multi-models): demonstrates loading multiple models concurrently in a single engine and directing each request to a specific model
#### Advanced OpenAI API Capabilities
diff --git a/examples/multi-models/README.md b/examples/multi-models/README.md
new file mode 100644
index 00000000..7450aad8
--- /dev/null
+++ b/examples/multi-models/README.md
@@ -0,0 +1,14 @@
+# WebLLM Multi-Models Example
+
+This folder provides a minimal demo that loads multiple models into a single WebLLM engine and runs
+chat completions against each of them in a webapp setting.
+To try it out, run the following commands under this folder:
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack the WebLLM core package, you can set the `@mlc-ai/web-llm`
+dependency to `"file:../.."` (as this example already does) and follow the build-from-source
+instructions in the project to build WebLLM locally.
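+
+At its core, the demo loads two models into one engine and routes each request to a model via the
+request's `model` field. A minimal sketch of the flow in [`src/multi_models.ts`](src/multi_models.ts):
+
+```typescript
+import * as webllm from "@mlc-ai/web-llm";
+
+// Load both models into a single engine (they are loaded sequentially).
+const engine = await webllm.CreateMLCEngine([
+  "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
+  "gemma-2-2b-it-q4f32_1-MLC-1k",
+]);
+
+// With more than one model loaded, each request must specify `model`,
+// otherwise the engine throws due to ambiguity.
+const reply = await engine.chat.completions.create({
+  messages: [{ role: "user", content: "Provide me three US states." }],
+  model: "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
+});
+console.log(reply.choices[0].message.content);
+```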
diff --git a/examples/multi-models/package.json b/examples/multi-models/package.json
new file mode 100644
index 00000000..5d7fa7c3
--- /dev/null
+++ b/examples/multi-models/package.json
@@ -0,0 +1,20 @@
+{
+ "name": "get-started",
+ "version": "0.1.0",
+ "private": true,
+ "scripts": {
+ "start": "parcel src/multi_models.html --port 8888",
+ "build": "parcel build src/multi_models.html --dist-dir lib"
+ },
+ "devDependencies": {
+ "buffer": "^5.7.1",
+ "parcel": "^2.8.3",
+ "process": "^0.11.10",
+ "tslib": "^2.3.1",
+ "typescript": "^4.9.5",
+ "url": "^0.11.3"
+ },
+ "dependencies": {
+ "@mlc-ai/web-llm": "file:../.."
+ }
+}
diff --git a/examples/multi-models/src/multi_models.html b/examples/multi-models/src/multi_models.html
new file mode 100644
index 00000000..1de9c00b
--- /dev/null
+++ b/examples/multi-models/src/multi_models.html
@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+ <script>
+ webLLMGlobal = {};
+ </script>
+ <body>
+ <h2>WebLLM Test Page</h2>
+ Open console to see output
+ <br />
+ <br />
+ <label id="init-label"> </label>
+
+ <h3>Prompt</h3>
+ <label id="prompt-label"> </label>
+
+ <h3>Response</h3>
+ <label id="generate-label"> </label>
+ <br />
+ <label id="stats-label"> </label>
+
+ <script type="module" src="./multi_models.ts"></script>
+ </body>
+</html>
diff --git a/examples/multi-models/src/multi_models.ts b/examples/multi-models/src/multi_models.ts
new file mode 100644
index 00000000..afafe684
--- /dev/null
+++ b/examples/multi-models/src/multi_models.ts
@@ -0,0 +1,76 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+function setLabel(id: string, text: string) {
+ const label = document.getElementById(id);
+ if (label == null) {
+ throw Error("Cannot find label " + id);
+ }
+ label.innerText = text;
+}
+
+/**
+ * Chat completion (OpenAI style) with streaming, with two models loaded in the engine.
+ */
+async function mainStreaming() {
+ const initProgressCallback = (report: webllm.InitProgressReport) => {
+ setLabel("init-label", report.text);
+ };
+ const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
+ const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
+
+ const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+ [selectedModel1, selectedModel2],
+ { initProgressCallback: initProgressCallback },
+ );
+
+ const request1: webllm.ChatCompletionRequest = {
+ stream: true,
+ stream_options: { include_usage: true },
+ messages: [
+ { role: "user", content: "Provide me three US states." },
+ { role: "assistant", content: "California, New York, Pennsylvania." },
+ { role: "user", content: "Two more please!" },
+ ],
+ model: selectedModel1, // if not specified, an error will be thrown due to ambiguity
+ };
+
+ const request2: webllm.ChatCompletionRequest = {
+ stream: true,
+ stream_options: { include_usage: true },
+ messages: [
+ { role: "user", content: "Provide me three cities in NY." },
+ { role: "assistant", content: "New York, Binghamton, Buffalo." },
+ { role: "user", content: "Two more please!" },
+ ],
+ model: selectedModel2, // if not specified, an error will be thrown due to ambiguity
+ };
+
+ const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
+ let message = "";
+ for await (const chunk of asyncChunkGenerator1) {
+ console.log(chunk);
+ message += chunk.choices[0]?.delta?.content || "";
+ setLabel("generate-label", message);
+ if (chunk.usage) {
+ console.log(chunk.usage); // only last chunk has usage
+ }
+ // engine.interruptGenerate(); // works with interrupt as well
+ }
+ const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
+ message += "\n\n";
+ for await (const chunk of asyncChunkGenerator2) {
+ console.log(chunk);
+ message += chunk.choices[0]?.delta?.content || "";
+ setLabel("generate-label", message);
+ if (chunk.usage) {
+ console.log(chunk.usage); // only last chunk has usage
+ }
+ // engine.interruptGenerate(); // works with interrupt as well
+ }
+
+ // if we do not specify which model to get the message from, an error will be thrown due to ambiguity
+ console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
+ console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
+}
+
+mainStreaming();
diff --git a/src/embedding.ts b/src/embedding.ts
index c620fb26..d5985ee9 100644
--- a/src/embedding.ts
+++ b/src/embedding.ts
@@ -265,6 +265,10 @@ export class EmbeddingPipeline {
await this.device.sync();
}
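+
+ /**
+ * Asynchronously load the WebGPU compute pipelines of this embedding model's module,
+ * mirroring LLMChatPipeline.asyncLoadWebGPUPipelines() so the engine can treat both
+ * pipeline types uniformly when loading a model.
+ */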
+ async asyncLoadWebGPUPipelines() {
+ await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
+ }
+
// Performance APIs below
/**
diff --git a/src/engine.ts b/src/engine.ts
index 5cb92f3f..3fe0cbb2 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -18,7 +18,6 @@ import {
ChatCompletionRequest,
ChatCompletion,
ChatCompletionChunk,
- ChatCompletionFinishReason,
ChatCompletionMessageParam,
ChatCompletionRequestNonStreaming,
ChatCompletionRequestStreaming,
@@ -51,18 +50,22 @@ import {
import {
cleanModelUrl,
findModelRecord,
+ getModelIdToUse,
getToolCallFromOutputMessage,
} from "./support";
import {
- EngineNotLoadedError,
ConfigurationNotInitializedError,
DeviceLostError,
EmbeddingUnsupportedModelError,
FeatureSupportError,
MissingModelWasmError,
- ModelNotLoadedError,
ShaderF16SupportError,
WebGPUNotAvailableError,
+ ReloadArgumentSizeUnmatchedError,
+ IncorrectPipelineLoadedError,
+ ReloadModelIdNotUniqueError,
+ SpecifiedModelNotFoundError,
+ ModelNotLoadedError,
} from "./error";
import { asyncLoadTokenizer } from "./cache_util";
import { EmbeddingPipeline } from "./embedding";
@@ -72,18 +75,20 @@ import { EmbeddingPipeline } from "./embedding";
*
* Equivalent to `new webllm.MLCEngine().reload(...)`.
*
- * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
- * `engineConfig.appConfig`.
+ * @param modelId model_id of the model to load, either string or string[]. When multiple models
+ * are provided, we load all models sequentially. Each modelId needs to either be in
+ * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`.
* @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig`.
- * @param chatOpts Extra options to override chat behavior specified in `mlc-chat-config.json`.
+ * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`.
+ * If specified, its size must match that of `modelId`; chatOpts[i] will be used for modelId[i].
* @returns An initialized `WebLLM.MLCEngine` with `modelId` loaded.
* @throws Throws error when device lost (mostly due to OOM); users should re-call `CreateMLCEngine()`,
* potentially with a smaller model or smaller context window size.
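+ * @example
+ * // (Illustrative sketch) load two models at once; later requests select one via `request.model`:
+ * const engine = await CreateMLCEngine([
+ *   "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
+ *   "gemma-2-2b-it-q4f32_1-MLC-1k",
+ * ]);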
*/
export async function CreateMLCEngine(
- modelId: string,
+ modelId: string | string[],
engineConfig?: MLCEngineConfig,
- chatOpts?: ChatOptions,
+ chatOpts?: ChatOptions | ChatOptions[],
): Promise<MLCEngine> {
const engine = new MLCEngine(engineConfig);
await engine.reload(modelId, chatOpts);
@@ -104,20 +109,27 @@ export class MLCEngine implements MLCEngineInterface {
/** For embeddings.create() */
public embeddings: API.Embeddings;
- private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded
+ /** Maps each loaded model's modelId to its pipeline */
+ private loadedModelIdToPipeline: Map<
+ string,
+ LLMChatPipeline | EmbeddingPipeline
+ >;
+ /** Maps each loaded model's modelId to its chatConfig */
+ private loadedModelIdToChatConfig: Map<string, ChatConfig>;
private logger: (msg: string) => void = log.info;
private logitProcessorRegistry?: Map<string, LogitProcessor>;
- private logitProcessor?: LogitProcessor;
- private pipeline?: LLMChatPipeline;
- private embeddingPipeline?: EmbeddingPipeline;
private initProgressCallback?: InitProgressCallback;
private interruptSignal = false;
private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
private reloadController: AbortController | undefined;
- private config?: ChatConfig;
private appConfig: AppConfig;
constructor(engineConfig?: MLCEngineConfig) {
+ this.loadedModelIdToPipeline = new Map<
+ string,
+ LLMChatPipeline | EmbeddingPipeline
+ >();
+ this.loadedModelIdToChatConfig = new Map();
this.appConfig = engineConfig?.appConfig || prebuiltAppConfig;
this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
this.setInitProgressCallback(engineConfig?.initProgressCallback);
@@ -150,24 +162,53 @@ export class MLCEngine implements MLCEngineInterface {
this.logitProcessorRegistry = logitProcessorRegistry;
}
+ /**
+ * Set MLCEngine logging output level
+ *
+ * @param logLevel The new log level
+ */
+ setLogLevel(logLevel: LogLevel) {
+ log.setLevel(logLevel);
+ }
+
//----------------------------------------
// 1. Model/pipeline loading and unloading
//----------------------------------------
- /**
- * Reload model `modelId`.
- * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
- * `engineConfig.appConfig`.
- * @param chatOpts To optionally override the `mlc-chat-config.json` of `modelId`.
- * @throws Throws error when device lost (mostly due to OOM); users should re-call reload(),
- * potentially with a smaller model or smaller context window size.
- */
- async reload(modelId: string, chatOpts?: ChatOptions): Promise<void> {
+ async reload(
+ modelId: string | string[],
+ chatOpts?: ChatOptions | ChatOptions[],
+ ): Promise<void> {
+ // 0. Unload all loaded models
await this.unload();
+ // 1. Convert inputs to arrays
+ if (!Array.isArray(modelId)) {
+ modelId = [modelId];
+ }
+ if (chatOpts !== undefined && !Array.isArray(chatOpts)) {
+ chatOpts = [chatOpts];
+ }
+ // 2. Check whether size matches
+ if (chatOpts !== undefined && modelId.length !== chatOpts.length) {
+ throw new ReloadArgumentSizeUnmatchedError(
+ modelId.length,
+ chatOpts.length,
+ );
+ }
+ // 3. Make sure each model in modelId is unique
+ if (new Set(modelId).size < modelId.length) {
+ throw new ReloadModelIdNotUniqueError(modelId);
+ }
+ // 4. Sequentially load each model
+ // Single abort should stop all to-be-loaded models
this.reloadController = new AbortController();
-
try {
- await this.reloadInternal(modelId, chatOpts);
+ for (let i = 0; i < modelId.length; i++) {
+ await this.reloadInternal(
+ modelId[i],
+ chatOpts ? chatOpts[i] : undefined,
+ );
+ }
} catch (error) {
if (error instanceof DOMException && error.name === "AbortError") {
log.warn("Reload() is aborted.", error.message);
@@ -183,7 +224,7 @@ export class MLCEngine implements MLCEngineInterface {
modelId: string,
chatOpts?: ChatOptions,
): Promise<void> {
- this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
+ const logitProcessor = this.logitProcessorRegistry?.get(modelId);
const tstart = performance.now();
const modelRecord = findModelRecord(modelId, this.appConfig);
@@ -205,7 +246,7 @@ export class MLCEngine implements MLCEngineInterface {
// load config
const configUrl = new URL("mlc-chat-config.json", modelUrl).href;
- this.config = {
+ const curModelConfig = {
...(await configCache.fetchWithCache(
configUrl,
"json",
@@ -214,6 +255,7 @@ export class MLCEngine implements MLCEngineInterface {
...modelRecord.overrides,
...chatOpts,
} as ChatConfig;
+ this.loadedModelIdToChatConfig.set(modelId, curModelConfig);
// load tvm wasm
let wasmCache: tvmjs.ArtifactCacheTemplate;
@@ -297,7 +339,7 @@ export class MLCEngine implements MLCEngineInterface {
const tokenizer = await asyncLoadTokenizer(
modelUrl,
- this.config,
+ curModelConfig,
this.appConfig,
this.logger,
);
@@ -309,23 +351,26 @@ export class MLCEngine implements MLCEngineInterface {
cacheType,
this.reloadController?.signal,
);
+
+ // Instantiate pipeline
+ // TODO: would be good to somehow check for error when LLMChatPipeline is loaded for an
+ // embedding model, and prompt user to use ModelRecord.model_type
+ let newPipeline: LLMChatPipeline | EmbeddingPipeline;
if (modelRecord.model_type === ModelType.embedding) {
- this.embeddingPipeline = new EmbeddingPipeline(
- tvm,
- tokenizer,
- this.config,
- );
+ newPipeline = new EmbeddingPipeline(tvm, tokenizer, curModelConfig);
} else {
- this.pipeline = new LLMChatPipeline(
+ newPipeline = new LLMChatPipeline(
tvm,
tokenizer,
- this.config,
- this.logitProcessor,
+ curModelConfig,
+ logitProcessor,
);
}
- await this.pipeline?.asyncLoadWebGPUPipelines();
- const tend = performance.now();
+ await newPipeline.asyncLoadWebGPUPipelines();
+ this.loadedModelIdToPipeline.set(modelId, newPipeline);
+ // Clean up
+ const tend = performance.now();
if (this.initProgressCallback !== undefined) {
const text = "Finish loading on " + gpuLabel;
this.initProgressCallback({
@@ -334,28 +379,23 @@ export class MLCEngine implements MLCEngineInterface {
text: text,
});
}
- this.currentModelId = modelId;
-
if (deviceLostInReload) {
throw new DeviceLostError();
}
}
- /**
- * Unloads the currently loaded model and destroy the webgpu device. Waits
- * until the webgpu device finishes all submitted work and destroys itself.
- * @note This is an asynchronous function.
- */
async unload() {
this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
- this.pipeline?.dispose();
- this.embeddingPipeline?.dispose();
- // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
- await this.pipeline?.sync();
- await this.embeddingPipeline?.sync();
- this.pipeline = undefined;
- this.embeddingPipeline = undefined;
- this.currentModelId = undefined;
+ // TODO: can optimize by calling dispose() on all pipelines in parallel. However, we need to
+ // wait for all sync() calls to finish before proceeding (e.g. a naive forEach does not work)
+ for (const entry of Array.from(this.loadedModelIdToPipeline.entries())) {
+ const pipeline = entry[1];
+ pipeline.dispose();
+ // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
+ await pipeline.sync();
+ }
+ this.loadedModelIdToPipeline.clear();
+ this.loadedModelIdToChatConfig.clear();
this.deviceLostIsError = true;
if (this.reloadController) {
this.reloadController.abort("Engine.unload() is called.");
@@ -369,44 +409,52 @@ export class MLCEngine implements MLCEngineInterface {
private async _generate(
input:
- | string
| ChatCompletionRequestNonStreaming
| CompletionCreateParamsNonStreaming,
- genConfig?: GenerationConfig,
+ pipeline: LLMChatPipeline,
+ chatConfig: ChatConfig,
+ genConfig: GenerationConfig,
): Promise<string> {
this.interruptSignal = false;
if (genConfig !== undefined) {
postInitAndCheckGenerationConfigValues(genConfig);
}
- await this.prefill(input, genConfig);
+ await this.prefill(input, pipeline, chatConfig, genConfig);
let counter = 1;
- while (!this.stopped()) {
+ while (!pipeline.stopped()) {
if (this.interruptSignal) {
- this.getPipeline().triggerStop();
+ pipeline.triggerStop();
break;
}
counter += 1;
- await this.decode(genConfig);
+ await this.decode(pipeline, genConfig);
}
- return await this.getMessage();
+ return pipeline.getMessage();
}
/**
* Similar to `_generate()`; but instead of using callback, we use an async iterable.
- * @param request Request for chat completion.
- * @param genConfig Generation config extraced from `request`.
*/
asyncGenerate(
request: ChatCompletionRequestStreaming,
+ model: string,
+ pipeline: LLMChatPipeline,
+ chatConfig: ChatConfig,
genConfig: GenerationConfig,
): AsyncGenerator<ChatCompletionChunk>;
asyncGenerate(
request: CompletionCreateParamsStreaming,
+ model: string,
+ pipeline: LLMChatPipeline,
+ chatConfig: ChatConfig,
genConfig: GenerationConfig,
): AsyncGenerator<Completion>;
async *asyncGenerate(
request: ChatCompletionRequestStreaming | CompletionCreateParamsStreaming,
+ model: string,
+ pipeline: LLMChatPipeline,
+ chatConfig: ChatConfig,
genConfig: GenerationConfig,
): AsyncGenerator<ChatCompletionChunk | Completion> {
// 0. Pre-processing
@@ -422,12 +470,11 @@ export class MLCEngine implements MLCEngineInterface {
}
postInitAndCheckGenerationConfigValues(genConfig);
if (request.seed !== null && request.seed !== undefined) {
- this.getPipeline().setSeed(request.seed);
+ pipeline.setSeed(request.seed);
}
// 1. Helper function that generates the chunk
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
- const model = this.currentModelId!;
const created = Date.now();
const id = crypto.randomUUID();
this.interruptSignal = false;
@@ -446,13 +493,13 @@ export class MLCEngine implements MLCEngineInterface {
}
async function _getChunk(
- thisModule: MLCEngine,
+ selectedPipeline: LLMChatPipeline,
): Promise<ChatCompletionChunk | Completion | undefined> {
// Remove the replacement character (U+FFFD) from the response to handle emojis.
// Each emoji is made up of multiples of 4 tokens; when truncated, it is displayed as �, so
// we skip this delta until a full emoji is rendered
// TODO(Charlie): This does not consider cases of � not being emoji, need to fix with Streamer
- const curMessage = await thisModule.getMessage();
+ const curMessage = selectedPipeline.getMessage();
const numTrailingReplacementChar =
_countTrailingReplacementChar(curMessage);
if (numTrailingReplacementChar % 4 !== 0) {
@@ -463,7 +510,7 @@ export class MLCEngine implements MLCEngineInterface {
prevMessageLength = curMessage.length;
const logprobs = request.logprobs
? ({
- content: thisModule.getPipeline().getTokenLogprobArray().slice(-1), // always the last entry
+ content: selectedPipeline.getTokenLogprobArray().slice(-1), // always the last entry
} as ChatCompletionChunk.Choice.Logprobs)
: null;
if (isChatCompletion) {
@@ -502,19 +549,19 @@ export class MLCEngine implements MLCEngineInterface {
}
// 2. Auto-regressive loop
- await this.prefill(request, genConfig);
- let curChunk = await _getChunk(this); // prefill produces a chunk
+ await this.prefill(request, pipeline, chatConfig, genConfig);
+ let curChunk = await _getChunk(pipeline); // prefill produces a chunk
if (curChunk) {
yield curChunk;
}
- while (!this.stopped()) {
+ while (!pipeline.stopped()) {
if (this.interruptSignal) {
- this.getPipeline().triggerStop();
+ pipeline.triggerStop();
break;
}
- await this.decode(genConfig);
- curChunk = await _getChunk(this);
+ await this.decode(pipeline, genConfig);
+ curChunk = await _getChunk(pipeline);
if (curChunk) {
yield curChunk;
}
@@ -522,20 +569,20 @@ export class MLCEngine implements MLCEngineInterface {
// Reset seed -- we do not want this seed to affect future requests
if (request.seed !== null && request.seed !== undefined) {
- this.getPipeline().setSeed(Date.now());
+ pipeline.setSeed(Date.now());
}
// 3. Last chunk empty marking the end
// If function calling, use the last chunk to return tool_calls
- let finish_reason = this.getFinishReason()!;
+ let finish_reason = pipeline.getFinishReason()!;
let tool_calls:
| Array<ChatCompletionChunk.Choice.Delta.ToolCall>
| undefined;
- if (this.getFinishReason()! == "stop" && isFunctionCalling) {
+ if (pipeline.getFinishReason() === "stop" && isFunctionCalling) {
// If stopped due to length or abort, cannot output return tool_calls field
finish_reason = "tool_calls";
- const outputMessage = await this.getMessage();
+ const outputMessage = pipeline.getMessage();
tool_calls = getToolCallFromOutputMessage(
outputMessage,
/*isStreaming=*/ true,
@@ -581,13 +628,10 @@ export class MLCEngine implements MLCEngineInterface {
// 4. Usage chunk
if (request.stream_options?.include_usage) {
- const completion_tokens =
- this.getPipeline().getCurRoundDecodingTotalTokens();
- const prompt_tokens = this.getPipeline().getCurRoundPrefillTotalTokens();
- const prefill_tokens_per_s =
- this.getPipeline().getCurRoundPrefillTokensPerSec();
- const decode_tokens_per_s =
- this.getPipeline().getCurRoundDecodingTokensPerSec();
+ const completion_tokens = pipeline.getCurRoundDecodingTotalTokens();
+ const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens();
+ const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec();
+ const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec();
const usage: CompletionUsage = {
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
@@ -649,11 +693,10 @@ export class MLCEngine implements MLCEngineInterface {
async chatCompletion(
request: ChatCompletionRequest,
): Promise<AsyncIterable<ChatCompletionChunk> | ChatCompletion> {
- // 0. Preprocess inputs
- if (!this.currentModelId) {
- throw new ModelNotLoadedError();
- }
- API.postInitAndCheckFieldsChatCompletion(request, this.currentModelId);
+ // 0. Check model loaded and preprocess inputs
+ const [selectedModelId, selectedPipeline, selectedChatConfig] =
+ this.getLLMStates("ChatCompletionRequest", request.model);
+ API.postInitAndCheckFieldsChatCompletion(request, selectedModelId);
const genConfig: GenerationConfig = {
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
@@ -669,11 +712,17 @@ export class MLCEngine implements MLCEngineInterface {
// 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`)
if (request.stream) {
- return this.asyncGenerate(request, genConfig);
+ return this.asyncGenerate(
+ request,
+ selectedModelId,
+ selectedPipeline,
+ selectedChatConfig,
+ genConfig,
+ );
}
if (request.seed !== null && request.seed !== undefined) {
- this.getPipeline().setSeed(request.seed);
+ selectedPipeline.setSeed(request.seed);
}
// 2. If request is non-streaming, directly reuse `_generate()`
@@ -687,18 +736,23 @@ export class MLCEngine implements MLCEngineInterface {
let outputMessage: string;
if (this.interruptSignal) {
// A single interrupt signal should stop all choices' generations
- this.getPipeline().triggerStop();
+ selectedPipeline.triggerStop();
outputMessage = "";
} else {
- outputMessage = await this._generate(request, genConfig);
+ outputMessage = await this._generate(
+ request,
+ selectedPipeline,
+ selectedChatConfig,
+ genConfig,
+ );
}
- let finish_reason = this.getFinishReason()!;
+ let finish_reason = selectedPipeline.getFinishReason()!;
// 3. Post processing for function calling
const isFunctionCalling =
request.tools !== undefined && request.tools !== null;
let tool_calls: Array<ChatCompletionMessageToolCall> | undefined;
- if (this.getFinishReason()! == "stop" && isFunctionCalling) {
+ if (selectedPipeline.getFinishReason() === "stop" && isFunctionCalling) {
// If stopped due to length or abort, cannot output return tool_calls field
finish_reason = "tool_calls";
tool_calls = getToolCallFromOutputMessage(
@@ -713,7 +767,7 @@ export class MLCEngine implements MLCEngineInterface {
index: i,
logprobs: request.logprobs
? ({
- content: this.getPipeline().getTokenLogprobArray(),
+ content: selectedPipeline.getTokenLogprobArray(),
} as ChatCompletion.Choice.Logprobs)
: null,
message: isFunctionCalling
@@ -727,16 +781,16 @@ export class MLCEngine implements MLCEngineInterface {
role: "assistant",
},
});
- completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens();
- prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens();
- prefill_time += this.getPipeline().getCurRoundPrefillTotalTime();
- decode_time += this.getPipeline().getCurRoundDecodingTotalTime();
+ completion_tokens += selectedPipeline.getCurRoundDecodingTotalTokens();
+ prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens();
+ prefill_time += selectedPipeline.getCurRoundPrefillTotalTime();
+ decode_time += selectedPipeline.getCurRoundDecodingTotalTime();
}
const response: ChatCompletion = {
id: crypto.randomUUID(),
choices: choices,
- model: this.currentModelId,
+ model: selectedModelId,
object: "chat.completion",
created: Date.now(),
usage: {
@@ -752,7 +806,7 @@ export class MLCEngine implements MLCEngineInterface {
// Reset seed -- we do not want this seed to affect future requests
if (request.seed !== null && request.seed !== undefined) {
- this.getPipeline().setSeed(Date.now());
+ selectedPipeline.setSeed(Date.now());
}
return response;
}
@@ -777,11 +831,10 @@ export class MLCEngine implements MLCEngineInterface {
async completion(
request: CompletionCreateParams,
): Promise<AsyncIterable<Completion> | Completion> {
- // 0. Preprocess inputs
- if (!this.currentModelId) {
- throw new ModelNotLoadedError();
- }
- API.postInitAndCheckFieldsCompletion(request, this.currentModelId);
+ // 0. Check model loaded and preprocess inputs
+ const [selectedModelId, selectedPipeline, selectedChatConfig] =
+ this.getLLMStates("ChatCompletionRequest", request.model);
+ API.postInitAndCheckFieldsCompletion(request, selectedModelId);
const genConfig: GenerationConfig = {
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
@@ -796,11 +849,17 @@ export class MLCEngine implements MLCEngineInterface {
// 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`)
if (request.stream) {
- return this.asyncGenerate(request, genConfig);
+ return this.asyncGenerate(
+ request,
+ selectedModelId,
+ selectedPipeline,
+ selectedChatConfig,
+ genConfig,
+ );
}
if (request.seed !== null && request.seed !== undefined) {
- this.getPipeline().setSeed(request.seed);
+ selectedPipeline.setSeed(request.seed);
}
// 2. If request is non-streaming, directly reuse `_generate()`
@@ -814,12 +873,17 @@ export class MLCEngine implements MLCEngineInterface {
let outputMessage: string;
if (this.interruptSignal) {
// A single interrupt signal should stop all choices' generations
- this.getPipeline().triggerStop();
+ selectedPipeline.triggerStop();
outputMessage = "";
} else {
- outputMessage = await this._generate(request, genConfig);
+ outputMessage = await this._generate(
+ request,
+ selectedPipeline,
+ selectedChatConfig,
+ genConfig,
+ );
}
- const finish_reason = this.getFinishReason()!;
+ const finish_reason = selectedPipeline.getFinishReason()!;
choices.push({
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
@@ -827,21 +891,21 @@ export class MLCEngine implements MLCEngineInterface {
index: i,
logprobs: request.logprobs
? ({
- content: this.getPipeline().getTokenLogprobArray(),
+ content: selectedPipeline.getTokenLogprobArray(),
} as ChatCompletion.Choice.Logprobs)
: null,
text: request.echo ? request.prompt + outputMessage : outputMessage,
});
- completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens();
- prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens();
- prefill_time += this.getPipeline().getCurRoundPrefillTotalTime();
- decode_time += this.getPipeline().getCurRoundDecodingTotalTime();
+ completion_tokens += selectedPipeline.getCurRoundDecodingTotalTokens();
+ prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens();
+ prefill_time += selectedPipeline.getCurRoundPrefillTotalTime();
+ decode_time += selectedPipeline.getCurRoundDecodingTotalTime();
}
const response: Completion = {
id: crypto.randomUUID(),
choices: choices,
- model: this.currentModelId,
+ model: selectedModelId,
object: "text_completion",
created: Date.now(),
usage: {
@@ -857,7 +921,7 @@ export class MLCEngine implements MLCEngineInterface {
// Reset seed -- we do not want this seed to affect future requests
if (request.seed !== null && request.seed !== undefined) {
- this.getPipeline().setSeed(Date.now());
+ selectedPipeline.setSeed(Date.now());
}
return response;
}
@@ -866,20 +930,34 @@ export class MLCEngine implements MLCEngineInterface {
request: EmbeddingCreateParams,
): Promise<CreateEmbeddingResponse> {
// 0. Preprocess inputs
- if (!this.currentModelId) {
- throw new ModelNotLoadedError();
+ const loadedModelIds: string[] = Array.from(
+ this.loadedModelIdToPipeline.keys(),
+ );
+ const selectedModelId: string = getModelIdToUse(
+ loadedModelIds,
+ request.model,
+ "EmbeddingCreateParams",
+ );
+ const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
+ if (!(selectedPipeline instanceof EmbeddingPipeline)) {
+ throw new IncorrectPipelineLoadedError(
+ selectedModelId,
+ "EmbeddingPipeline",
+ "EmbeddingCreateParams",
+ );
}
if (
- findModelRecord(this.currentModelId, this.appConfig).model_type !==
+ findModelRecord(selectedModelId, this.appConfig).model_type !==
ModelType.embedding
) {
- throw new EmbeddingUnsupportedModelError(this.currentModelId);
+ throw new EmbeddingUnsupportedModelError(selectedModelId);
}
- API.postInitAndCheckFieldsEmbedding(request, this.currentModelId);
+ API.postInitAndCheckFieldsEmbedding(request, selectedModelId);
// 1. Call EmbeddingPipeline to get embeddings
- const embedResult: Array<Array<number>> =
- await this.getEmbeddingPipeline().embedStep(request.input);
+ const embedResult: Array<Array<number>> = await selectedPipeline.embedStep(
+ request.input,
+ );
// 2. Prepare response
const batchSize = embedResult.length;
@@ -894,15 +972,13 @@ export class MLCEngine implements MLCEngineInterface {
}
return {
data: data,
- model: this.currentModelId,
+ model: selectedModelId,
object: "list",
usage: {
- prompt_tokens:
- this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(),
- total_tokens: this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(),
+ prompt_tokens: selectedPipeline.getCurRoundEmbedTotalTokens(),
+ total_tokens: selectedPipeline.getCurRoundEmbedTotalTokens(),
extra: {
- prefill_tokens_per_s:
- this.getEmbeddingPipeline().getCurRoundEmbedTokensPerSec(),
+ prefill_tokens_per_s: selectedPipeline.getCurRoundEmbedTokensPerSec(),
},
},
};
@@ -950,42 +1026,64 @@ export class MLCEngine implements MLCEngineInterface {
return gpuDetectOutput.adapterInfo.vendor;
}
- //----------------------------------------------
- // 5. Low-level APIs that interact with pipeline
- //----------------------------------------------
- private getPipeline(): LLMChatPipeline {
- if (this.pipeline === undefined) {
- throw new EngineNotLoadedError();
- }
- return this.pipeline;
- }
+ //---------------------------------------------------------------
+ // 5. Helper for querying currently loaded model/pipeline/config.
+ // Needed due to possibly multiple loaded models.
+ //---------------------------------------------------------------
- private getEmbeddingPipeline(): EmbeddingPipeline {
- if (this.embeddingPipeline === undefined) {
- throw new EngineNotLoadedError();
+ /**
+ * Return the model, its LLMChatPipeline, and ChatConfig to use. Throws error when unclear which
+ * model to use.
+ * @param requestName The type of request or API to load the model for. Needed for error throwing.
+ * @param modelId Model the user specified to load via the request. Required when multiple
+ * models are loaded
+ */
+ private getLLMStates(
+ requestName: string,
+ modelId?: string | null,
+ ): [string, LLMChatPipeline, ChatConfig] {
+ // TODO(webllm-team): when more modalities/pipelines are supported, make this method
+ // generic for different pipelines. e.g. currently embedding() does not use this method
+ const loadedModelIds: string[] = Array.from(
+ this.loadedModelIdToPipeline.keys(),
+ );
+ const selectedModelId: string = getModelIdToUse(
+ loadedModelIds,
+ modelId,
+ requestName,
+ );
+ const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
+ if (!(selectedPipeline instanceof LLMChatPipeline)) {
+ throw new IncorrectPipelineLoadedError(
+ selectedModelId,
+ "LLMChatPipeline",
+ requestName,
+ );
}
- return this.embeddingPipeline;
+ const selectedChatConfig =
+ this.loadedModelIdToChatConfig.get(selectedModelId);
+ if (selectedChatConfig === undefined) {
+ throw new Error(
+ `InternalError: chat config not registered for ${selectedModelId}.`,
+ );
+ }
+ return [selectedModelId, selectedPipeline, selectedChatConfig];
}
+ //--------------------------------------------------------------------
+ // 6. External low-level APIs that directly interacts with a pipeline.
+ //--------------------------------------------------------------------
+
async forwardTokensAndSample(
inputIds: Array<number>,
isPrefill: boolean,
+ modelId?: string,
): Promise<number> {
- return this.getPipeline().forwardTokensAndSample(inputIds, isPrefill);
- }
-
- /**
- * @returns Whether the generation stopped.
- */
- stopped(): boolean {
- return this.getPipeline().stopped();
- }
-
- /**
- * @returns Finish reason; undefined if generation not started/stopped yet.
- */
- getFinishReason(): ChatCompletionFinishReason | undefined {
- return this.getPipeline().getFinishReason();
+ const [, selectedPipeline] = this.getLLMStates(
+ "forwardTokensAndSample",
+ modelId,
+ );
+ return selectedPipeline.forwardTokensAndSample(inputIds, isPrefill);
}
/**
@@ -993,33 +1091,46 @@ export class MLCEngine implements MLCEngineInterface {
*
* @returns The current output message.
*/
- async getMessage(): Promise<string> {
- return this.getPipeline().getMessage();
- }
-
- /**
- * Set MLCEngine logging output level
- *
- * @param logLevel The new log level
- */
- setLogLevel(logLevel: LogLevel) {
- log.setLevel(logLevel);
+ async getMessage(modelId?: string): Promise<string> {
+ const [, selectedPipeline] = this.getLLMStates("getMessage", modelId);
+ return selectedPipeline.getMessage();
}
- async runtimeStatsText(): Promise<string> {
+ async runtimeStatsText(modelId?: string): Promise<string> {
log.warn(
"WARNING: `runtimeStatsText()` will soon be deprecated. " +
"Please use `ChatCompletion.usage` for non-streaming requests, or " +
"`ChatCompletionChunk.usage` for streaming requests, enabled by `stream_options`. " +
"The only flow that expects to use `runtimeStatsText()` as of now is `forwardTokensAndSample()`.",
);
- return this.getPipeline().runtimeStatsText();
+ const [, selectedPipeline] = this.getLLMStates("runtimeStatsText", modelId);
+ return selectedPipeline.runtimeStatsText();
}
- async resetChat(keepStats = false) {
- this.pipeline?.resetChat(keepStats);
+ async resetChat(keepStats = false, modelId?: string) {
+ try {
+ const [, selectedPipeline] = this.getLLMStates("resetChat", modelId);
+ selectedPipeline.resetChat(keepStats);
+ } catch (error) {
+ if (
+ error instanceof ModelNotLoadedError ||
+ error instanceof SpecifiedModelNotFoundError
+ ) {
+ // resetChat() may be called before any pipeline is instantiated, so treat these errors as a no-op.
+ log.debug(
+ "Caught an expected error in resetChat, treating it as no-op. Error: ",
+ error,
+ );
+ } else {
+ throw error;
+ }
+ }
}
+ //-----------------------------------------------
+ // 7. Prefill and decode given an LLMChatPipeline
+ //-----------------------------------------------
+
/**
* Run a prefill step with a given input.
*
@@ -1031,36 +1142,40 @@ export class MLCEngine implements MLCEngineInterface {
* performing multi-round chatting, so we do not reset, hence reusing KV cache. Otherwise, we
* reset everything, treating the request as something completely new.
*
- * @param input The input prompt, or `messages` in OpenAI-like APIs.
+ * @param input The OpenAI-style prompt to prefill.
+ * @param pipeline The loaded pipeline, hence model, to carry out this prefill.
+ * @param chatConfig The chat config to use for this model.
+ * @param genConfig Generation config.
*/
async prefill(
- input: string | ChatCompletionRequest | CompletionCreateParams,
- genConfig?: GenerationConfig,
+ input: ChatCompletionRequest | CompletionCreateParams,
+ pipeline: LLMChatPipeline,
+ chatConfig: ChatConfig,
+ genConfig: GenerationConfig,
) {
- if (this.config === undefined) {
+ // TODO: SPECIFY MODEL TO PERFORM PREFILL, HENCE RETRIEVE CONFIG
+ if (chatConfig === undefined) {
throw new ConfigurationNotInitializedError();
}
let input_str: string;
let input_role_str: string | undefined;
let lastMsgRole = Role.user;
- if (typeof input === "string") {
- input_str = input;
- } else if ("messages" in input) {
+ if ("messages" in input) {
// For ChatCompletionRequest, we prepare input using `messages`
// 1. Get new conversation based on request, determine if we are in multiround chatting
- const oldConv = this.getPipeline().getConversationObject();
+ const oldConv = pipeline.getConversationObject();
const newConv = getConversationFromChatCompletionRequest(
input,
- this.config,
+ chatConfig,
);
if (!compareConversationObject(oldConv, newConv)) {
// Not the same conversation, so not multiround chatting, reset everything (KV cache, etc.)
- this.resetChat();
- this.getPipeline().setConversation(newConv);
+ pipeline.resetChat();
+ pipeline.setConversation(newConv);
} else if (newConv.messages.length === 0) {
// Empty oldConv, and no chat history in newConv, so reset and setConversation
- this.resetChat();
- this.getPipeline().setConversation(newConv);
+ pipeline.resetChat();
+ pipeline.setConversation(newConv);
} else {
log.info("Multiround chatting, reuse KVCache.");
}
@@ -1076,15 +1191,15 @@ export class MLCEngine implements MLCEngineInterface {
} else {
// For CompletionCreateParams, the input is just the prompt
input_str = input.prompt;
- this.resetChat();
+ pipeline.resetChat();
const newConv = getConversation(
- this.config.conv_template,
- this.config.conv_config,
+ chatConfig.conv_template,
+ chatConfig.conv_config,
true,
);
- this.getPipeline().setConversation(newConv);
+ pipeline.setConversation(newConv);
}
- return this.getPipeline().prefillStep(
+ return pipeline.prefillStep(
input_str,
lastMsgRole,
input_role_str,
@@ -1095,7 +1210,7 @@ export class MLCEngine implements MLCEngineInterface {
/**
* Run a decode step to decode the next token.
*/
- async decode(genConfig?: GenerationConfig) {
- return this.getPipeline().decodeStep(genConfig);
+ async decode(pipeline: LLMChatPipeline, genConfig?: GenerationConfig) {
+ return pipeline.decodeStep(genConfig);
}
}
diff --git a/src/error.ts b/src/error.ts
index ef672892..0c4bd568 100644
--- a/src/error.ts
+++ b/src/error.ts
@@ -84,9 +84,11 @@ export class WebGPUNotFoundError extends Error {
}
export class ModelNotLoadedError extends Error {
- constructor() {
+ constructor(requestName: string) {
super(
- "Model not loaded before calling chatCompletion(). Please ensure you have called `MLCEngine.reload(model)` to load the model before initiating chat operations, or initialize your engine using `CreateMLCEngine()` with a valid model configuration.",
+ `Model not loaded before trying to complete ${requestName}. Please ensure you have called ` +
+ `MLCEngine.reload(model) to load the model before initiating API calls, ` +
+ `or initialize your engine using CreateMLCEngine() with a valid model configuration.`,
);
this.name = "ModelNotLoadedError";
}
@@ -479,3 +481,63 @@ export class EmbeddingInputEmptyError extends Error {
this.name = "EmbeddingInputEmptyError";
}
}
+
+export class ReloadArgumentSizeUnmatchedError extends Error {
+ constructor(numModelId: number, numChatOpts: number) {
+ super(
+ `Expect chatOpts, if specified, to match the size of modelId. However, got ` +
+ `${numModelId} modelId(s) but ${numChatOpts} chatOpts.`,
+ );
+ this.name = "ReloadArgumentSizeUnmatchedError";
+ }
+}
+
+export class UnclearModelToUseError extends Error {
+ constructor(loadedModels: string[], requestName: string) {
+ super(
+ `Multiple models are loaded in the engine. Please specify the model in ${requestName}.\n` +
+ `Currently loaded models are:\n${loadedModels}`,
+ );
+ this.name = "UnclearModelToUseError";
+ }
+}
+
+export class SpecifiedModelNotFoundError extends Error {
+ constructor(
+ loadedModels: string[],
+ requestedModelId: string,
+ requestName: string,
+ ) {
+ super(
+ `Specified model ${requestedModelId} for ${requestName} is not found in loaded models. ` +
+ `Please check if the correct model is loaded/specified. ` +
+ `Currently loaded models are:\n${loadedModels}`,
+ );
+ this.name = "SpecifiedModelNotFoundError";
+ }
+}
+
+export class IncorrectPipelineLoadedError extends Error {
+ constructor(
+ selectedModelId: string,
+ expectedPipeline: string,
+ requestName: string,
+ ) {
+ super(
+ `${requestName} expects the model to be loaded with ${expectedPipeline}. However, ` +
+ `${selectedModelId} is not loaded with this pipeline.`,
+ );
+ this.name = "IncorrectPipelineLoadedError";
+ }
+}
+
+export class ReloadModelIdNotUniqueError extends Error {
+ constructor(modelId: string[]) {
+ super(
+ `Models in modelId passed to reload() need to be unique. If you want to ` +
+ `load copies of the same model, consider making copies of the ModelRecord with ` +
+ `different model_id. Received modelId: ${modelId}`,
+ );
+ this.name = "ReloadModelIdNotUniqueError";
+ }
+}
diff --git a/src/extension_service_worker.ts b/src/extension_service_worker.ts
index d1e42032..97e0dcb0 100644
--- a/src/extension_service_worker.ts
+++ b/src/extension_service_worker.ts
@@ -7,7 +7,7 @@ import {
WebWorkerMLCEngineHandler,
WebWorkerMLCEngine,
} from "./web_worker";
-import { areChatOptionsEqual } from "./utils";
+import { areArraysEqual, areChatOptionsListEqual } from "./utils";
import { WebGPUNotFoundError } from "./error";
export interface ExtensionMLCEngineConfig extends MLCEngineConfig {
@@ -66,8 +66,8 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
const params = msg.content as ReloadParams;
// If the modelId, chatOpts, and appConfig are the same, immediately return
if (
- this.modelId === params.modelId &&
- areChatOptionsEqual(this.chatOpts, params.chatOpts)
+ areArraysEqual(this.modelId, params.modelId) &&
+ areChatOptionsListEqual(this.chatOpts, params.chatOpts)
) {
log.info("Already loaded the model. Skip loading");
const gpuDetectOutput = await tvmjs.detectGPUDevice();
@@ -104,18 +104,21 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
/**
* Create a ServiceWorkerMLCEngine.
*
- * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
- * `engineConfig.appConfig`.
+ * @param modelId model_id of the model to load, either string or string[]. When multiple models
+ * are provided, we load all models sequentially. Each modelId needs to either be in
+ * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`.
* @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig` for more.
+ * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`.
+ * If specified, its size must match that of `modelId`; chatOpts[i] will be used for modelId[i].
* @param keepAliveMs The interval to send keep alive messages to the service worker.
* See [Service worker lifecycle](https://developer.chrome.com/docs/extensions/develop/concepts/service-workers/lifecycle#idle-shutdown)
* The default is 10s.
* @returns An initialized `WebLLM.ServiceWorkerMLCEngine` with `modelId` loaded.
*/
export async function CreateServiceWorkerMLCEngine(
- modelId: string,
+ modelId: string | string[],
engineConfig?: ExtensionMLCEngineConfig,
- chatOpts?: ChatOptions,
+ chatOpts?: ChatOptions | ChatOptions[],
keepAliveMs = 10000,
): Promise<ServiceWorkerMLCEngine> {
const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(
diff --git a/src/message.ts b/src/message.ts
index f9cd5775..618dd51a 100644
--- a/src/message.ts
+++ b/src/message.ts
@@ -1,4 +1,4 @@
-import { AppConfig, ChatOptions, GenerationConfig } from "./config";
+import { AppConfig, ChatOptions } from "./config";
import { InitProgressReport, LogLevel } from "./types";
import {
ChatCompletionRequestStreaming,
@@ -40,55 +40,55 @@ type RequestKind =
type ResponseKind = "return" | "throw" | "initProgressCallback";
export interface ReloadParams {
- modelId: string;
- chatOpts?: ChatOptions;
+ modelId: string[];
+ chatOpts?: ChatOptions[];
}
export interface ResetChatParams {
keepStats: boolean;
+ modelId?: string;
+}
+export interface GetMessageParams {
+ modelId?: string;
+}
+export interface RuntimeStatsTextParams {
+ modelId?: string;
}
export interface ForwardTokensAndSampleParams {
inputIds: Array<number>;
isPrefill: boolean;
+ modelId?: string;
}
+
+// Notes on the following Params with modelId and chatOpts:
+// These fields are the model and chatOpts that the frontend engine expects the backend
+// to be loaded with. If they are not loaded (e.g. because the web/service worker was unexpectedly
+// killed), the handler will call reload(). An engine can load multiple models, hence both are lists.
+// TODO(webllm-team): should add appConfig here as well if rigorous.
+// For more, see https://github.com/mlc-ai/web-llm/pull/471
export interface ChatCompletionNonStreamingParams {
request: ChatCompletionRequestNonStreaming;
- // The model and chatOpts that the frontend engine expects the backend to be loaded with.
- // If not loaded due to service worker unexpectedly killed, handler will call reload().
- // TODO(webllm-team): should add appConfig here as well.
- modelId: string;
- chatOpts: ChatOptions;
+ modelId: string[];
+ chatOpts?: ChatOptions[];
}
export interface ChatCompletionStreamInitParams {
request: ChatCompletionRequestStreaming;
- // The model and chatOpts that the frontend engine expects the backend to be loaded with.
- // If not loaded due to service worker unexpectedly killed, handler will call reload().
- // TODO(webllm-team): should add appConfig here as well.
- modelId: string;
- chatOpts: ChatOptions;
+ modelId: string[];
+ chatOpts?: ChatOptions[];
}
export interface CompletionNonStreamingParams {
request: CompletionCreateParamsNonStreaming;
- // The model and chatOpts that the frontend engine expects the backend to be loaded with.
- // If not loaded due to service worker unexpectedly killed, handler will call reload().
- // TODO(webllm-team): should add appConfig here as well.
- modelId: string;
- chatOpts: ChatOptions;
+ modelId: string[];
+ chatOpts?: ChatOptions[];
}
export interface CompletionStreamInitParams {
request: CompletionCreateParamsStreaming;
- // The model and chatOpts that the frontend engine expects the backend to be loaded with.
- // If not loaded due to service worker unexpectedly killed, handler will call reload().
- // TODO(webllm-team): should add appConfig here as well.
- modelId: string;
- chatOpts: ChatOptions;
+ modelId: string[];
+ chatOpts?: ChatOptions[];
}
export interface EmbeddingParams {
request: EmbeddingCreateParams;
- // The model and chatOpts that the frontend engine expects the backend to be loaded with.
- // If not loaded due to service worker unexpectedly killed, handler will call reload().
- // TODO(webllm-team): should add appConfig here as well.
- modelId: string;
- chatOpts: ChatOptions;
+ modelId: string[];
+ chatOpts?: ChatOptions[];
}
export interface CustomRequestParams {
@@ -98,6 +98,8 @@ export interface CustomRequestParams {
export type MessageContent =
| ReloadParams
| ResetChatParams
+ | GetMessageParams
+ | RuntimeStatsTextParams
| ForwardTokensAndSampleParams
| ChatCompletionNonStreamingParams
| ChatCompletionStreamInitParams
diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts
index ff77015d..492b869f 100644
--- a/src/openai_api_protocols/chat_completion.ts
+++ b/src/openai_api_protocols/chat_completion.ts
@@ -233,12 +233,13 @@ export interface ChatCompletionRequestBase {
*/
response_format?: ResponseFormat;
- //////////////// BELOW FIELDS NOT SUPPORTED YET ////////////////
-
/**
- * Model to carry out this API.
+ * ID of the model to use. This equals `ModelRecord.model_id`, which needs to either be in
+ * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`.
*
- * @note Not supported. Instead, call `CreateMLCEngine(model)` or `engine.reload(model)`.
+ * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time.
+ * @note If only one model is loaded in the engine, this field is optional. If multiple models
+ * are loaded, this is required.
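+ * @example
+ * // (Illustrative) with multiple models loaded, select one per request:
+ * // await engine.chat.completions.create({ messages, model: "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k" });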
*/
model?: string | null;
}
@@ -363,7 +364,7 @@ export interface ChatCompletionChunk {
usage?: CompletionUsage;
}
-export const ChatCompletionRequestUnsupportedFields: Array<string> = ["model"];
+export const ChatCompletionRequestUnsupportedFields: Array<string> = []; // all supported as of now
/**
* Post init and verify whether the input of the request is valid. Thus, this function can throw
diff --git a/src/openai_api_protocols/completion.ts b/src/openai_api_protocols/completion.ts
index 66e0714c..0ed869cd 100644
--- a/src/openai_api_protocols/completion.ts
+++ b/src/openai_api_protocols/completion.ts
@@ -182,14 +182,18 @@ export interface CompletionCreateParamsBase {
*/
top_p?: number | null;
- //////////////// BELOW FIELDS NOT SUPPORTED YET ////////////////
/**
- * Model to carry out this API.
+ * ID of the model to use. This equals `ModelRecord.model_id`, which needs to either be in
+ * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`.
*
- * @note Not supported. Instead call `CreateMLCEngine(model)` or `engine.reload(model)` instead.
+ * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time.
+ * @note If only one model is loaded in the engine, this field is optional. If multiple models
+ * are loaded, this is required.
*/
model?: string | null;
+ //////////////// BELOW FIELDS NOT SUPPORTED YET ////////////////
+
/**
* The suffix that comes after a completion of inserted text.
*
@@ -305,7 +309,6 @@ export interface CompletionChoice {
//////////////////////////////// 3. POST INIT ////////////////////////////////
export const CompletionCreateParamsUnsupportedFields: Array<string> = [
- "model",
"suffix",
"user",
"best_of",
diff --git a/src/openai_api_protocols/embedding.ts b/src/openai_api_protocols/embedding.ts
index f5eeef24..5f623eb4 100644
--- a/src/openai_api_protocols/embedding.ts
+++ b/src/openai_api_protocols/embedding.ts
@@ -118,18 +118,21 @@ export interface EmbeddingCreateParams {
input: string | Array<string> | Array<number> | Array<Array<number>>;
/**
- * The format to return the embeddings in.
+ * ID of the model to use. This equals `ModelRecord.model_id`, which needs to either be in
+ * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`.
*
- * @note Currently only support `float`.
+ * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time.
+ * @note If only one model is loaded in the engine, this field is optional. If multiple models
+ * are loaded, this is required.
*/
- encoding_format?: "float" | "base64";
+ model?: string | null;
/**
- * ID of the model to use.
+ * The format to return the embeddings in.
*
- * @note Not supported. Instead, call `CreateMLCEngine(model)` or `engine.reload(model)`.
+ * @note Currently only supports `float`.
*/
- model?: string;
+ encoding_format?: "float" | "base64";
// TODO: can support matryoshka embedding models in future, hence allow `dimensions` for those.
/**
@@ -149,7 +152,6 @@ export interface EmbeddingCreateParams {
}
export const EmbeddingCreateParamsUnsupportedFields: Array<string> = [
- "model",
"dimensions",
"user",
];
diff --git a/src/service_worker.ts b/src/service_worker.ts
index b1e2d32d..6e1c24ea 100644
--- a/src/service_worker.ts
+++ b/src/service_worker.ts
@@ -8,7 +8,7 @@ import {
WebWorkerMLCEngine,
ChatWorker,
} from "./web_worker";
-import { areChatOptionsEqual } from "./utils";
+import { areArraysEqual, areChatOptionsListEqual } from "./utils";
import {
NoServiceWorkerAPIError,
NonWorkerEnvironmentError,
@@ -110,8 +110,8 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler {
const params = msg.content as ReloadParams;
// If the modelId, chatOpts, and appConfig are the same, immediately return
if (
- this.modelId === params.modelId &&
- areChatOptionsEqual(this.chatOpts, params.chatOpts)
+ areArraysEqual(this.modelId, params.modelId) &&
+ areChatOptionsListEqual(this.chatOpts, params.chatOpts)
) {
log.info("Already loaded the model. Skip loading");
const gpuDetectOutput = await tvmjs.detectGPUDevice();
@@ -181,15 +181,18 @@ export class ServiceWorker implements ChatWorker {
/**
* Create a ServiceWorkerMLCEngine.
*
- * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
- * `engineConfig.appConfig`.
+ * @param modelId model_id of the model to load, either string or string[]. When multiple models
+ * are provided, we load all models sequentially. Each modelId needs to either be in
+ * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`.
* @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig` for more.
+ * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`.
+ * If specified, its size must match that of `modelId`; chatOpts[i] will be used for modelId[i].
* @returns An initialized `WebLLM.ServiceWorkerMLCEngine` with `modelId` loaded.
*/
export async function CreateServiceWorkerMLCEngine(
- modelId: string,
+ modelId: string | string[],
engineConfig?: MLCEngineConfig,
- chatOpts?: ChatOptions,
+ chatOpts?: ChatOptions | ChatOptions[],
keepAliveMs = 10000,
): Promise<ServiceWorkerMLCEngine> {
if (!("serviceWorker" in navigator)) {
diff --git a/src/support.ts b/src/support.ts
index 95b20584..d30a94e0 100644
--- a/src/support.ts
+++ b/src/support.ts
@@ -1,15 +1,18 @@
/** Util methods. */
import { Tokenizer } from "@mlc-ai/web-tokenizers";
-import { AppConfig, MessagePlaceholders } from "./config";
+import { AppConfig, MessagePlaceholders, ModelRecord } from "./config";
import {
ChatCompletionChunk,
ChatCompletionMessageToolCall,
} from "./openai_api_protocols/index";
import {
ModelNotFoundError,
+ ModelNotLoadedError,
+ SpecifiedModelNotFoundError,
ToolCallOutputInvalidTypeError,
ToolCallOutputMissingFieldsError,
ToolCallOutputParseError,
+ UnclearModelToUseError,
} from "./error";
/**
@@ -199,10 +202,52 @@ export function getToolCallFromOutputMessage(
}
}
-export function findModelRecord(modelId: string, appConfig: AppConfig) {
+export function findModelRecord(
+ modelId: string,
+ appConfig: AppConfig,
+): ModelRecord {
const matchedItem = appConfig.model_list.find(
(item) => item.model_id == modelId,
);
if (matchedItem !== undefined) return matchedItem;
throw new ModelNotFoundError(modelId);
}
+
+/**
+ * Return the model to use given the loaded modelIds and requestModel. Throws error when unclear
+ * which model to use.
+ * @param loadedModelIds Models currently loaded in the engine.
+ * @param requestModel Model the user specified to load via the request. Required when multiple
+ * models are loaded
+ * @param requestName The type of request or API to load the model for. Needed for error throwing.
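+ * @example
+ * // Illustrative behavior (model ids are hypothetical):
+ * //   getModelIdToUse(["model-a"], undefined, "ChatCompletionRequest") === "model-a"
+ * //   getModelIdToUse(["model-a", "model-b"], "model-b", "ChatCompletionRequest") === "model-b"
+ * //   getModelIdToUse(["model-a", "model-b"], undefined, "ChatCompletionRequest") // throws UnclearModelToUseError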
+ */
+export function getModelIdToUse(
+ loadedModelIds: string[],
+ requestModel: string | undefined | null,
+ requestName: string,
+): string {
+ let selectedModelId: string;
+ if (loadedModelIds.length === 0) {
+ throw new ModelNotLoadedError(requestName);
+ }
+ if (requestModel) {
+ // If specified model
+ if (loadedModelIds.indexOf(requestModel) === -1) {
+ throw new SpecifiedModelNotFoundError(
+ loadedModelIds,
+ requestModel,
+ requestName,
+ );
+ } else {
+ selectedModelId = requestModel;
+ }
+ } else {
+ // If not specified
+ if (loadedModelIds.length > 1) {
+ throw new UnclearModelToUseError(loadedModelIds, requestName);
+ } else {
+ selectedModelId = loadedModelIds[0];
+ }
+ }
+ return selectedModelId;
+}
diff --git a/src/types.ts b/src/types.ts
index 1dc15899..d7c88846 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -99,12 +99,20 @@ export interface MLCEngineInterface {
/**
* Reload the chat with a new model.
*
- * @param modelId model_id of the model to load.
- * @param chatOpts Extra options to override chat behavior.
+ * @param modelId model_id of the model to load, either string or string[]. When multiple models
+ * are provided, we load all models sequentially. Each modelId needs to either be in
+ * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`.
+ * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`.
+ * If specified, its size must match that of `modelId`; chatOpts[i] will be used for modelId[i].
* @returns A promise when reload finishes.
+ * @throws Throws error when device lost (mostly due to OOM); users should re-call reload(),
+ * potentially with a smaller model or smaller context window size.
* @note This is an async function.
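+ * @example
+ * // (Illustrative sketch) load two models; chatOpts, if given, is matched to modelId by index:
+ * await engine.reload(
+ *   ["Phi-3-mini-4k-instruct-q4f32_1-MLC-1k", "gemma-2-2b-it-q4f32_1-MLC-1k"],
+ * );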
*/
- reload: (modelId: string, chatOpts?: ChatOptions) => Promise<void>;
+ reload: (
+ modelId: string | string[],
+ chatOpts?: ChatOptions | ChatOptions[],
+ ) => Promise<void>;
/**
* OpenAI-style API. Generate a chat completion response for the given conversation and
@@ -164,9 +172,10 @@ export interface MLCEngineInterface {
/**
* @returns A text summarizing the runtime stats.
+ * @param modelId Only required when multiple models are loaded.
* @note This is an async function
*/
- runtimeStatsText: () => Promise<string>;
+ runtimeStatsText: (modelId?: string) => Promise<string>;
/**
* Interrupt the generate process if it is already running.
@@ -174,22 +183,25 @@ export interface MLCEngineInterface {
interruptGenerate: () => void;
/**
- * Explicitly unload the current model and release the related resources.
+ * Explicitly unload the currently loaded model(s) and release the related resources. Waits until
+ * the WebGPU device finishes all submitted work and destroys itself.
+ * @note This is an asynchronous function.
*/
unload: () => Promise<void>;
/**
* Reset the current chat session by clear all memories.
* @param keepStats: If True, do not reset the statistics.
+ * @param modelId Only required when multiple models are loaded.
*/
- resetChat: (keepStats?: boolean) => Promise<void>;
+ resetChat: (keepStats?: boolean, modelId?: string) => Promise<void>;
/**
* Get the current generated response.
- *
+ * @param modelId Only required when multiple models are loaded.
* @returns The current output message.
*/
- getMessage: () => Promise<string>;
+ getMessage: (modelId?: string) => Promise<string>;
/**
* Returns the device's maxStorageBufferBindingSize, can be used to guess whether the device
@@ -210,12 +222,14 @@ export interface MLCEngineInterface {
*
* @param inputIds The input tokens.
* @param isPrefill True if prefill, false if decode; only used for statistics.
+ * @param modelId Only required when multiple models are loaded.
* @returns Next token sampled.
* @note This is an async function.
*/
forwardTokensAndSample(
inputIds: Array<number>,
isPrefill: boolean,
+ modelId?: string,
): Promise<number>;
/**
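
To illustrate the widened `MLCEngineInterface`, a hedged usage sketch (the model ids are placeholders; any ids from `webllm.prebuiltAppConfig` or a custom `appConfig` apply):

```ts
import * as webllm from "@mlc-ai/web-llm";

async function multiModelLifecycle(engine: webllm.MLCEngineInterface) {
  // reload() now also accepts lists; models are loaded sequentially and
  // chatOpts[i], when given, applies to modelId[i].
  await engine.reload(
    ["model-a", "model-b"],
    [{ context_window_size: 1024 }, {}],
  );

  // Per-model APIs take an optional modelId, required once multiple models are loaded.
  console.log(await engine.runtimeStatsText("model-a"));
  await engine.resetChat(/* keepStats= */ false, "model-b");
  console.log(await engine.getMessage("model-b"));

  // unload() releases all loaded models and waits for the WebGPU device to finish.
  await engine.unload();
}
```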
diff --git a/src/utils.ts b/src/utils.ts
index e28697a7..7c688927 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -1,7 +1,7 @@
import { AppConfig, ChatOptions, ModelRecord } from "./config";
// Helper function to compare two arrays
-function areArraysEqual(arr1?: Array<any>, arr2?: Array<any>): boolean {
+export function areArraysEqual(arr1?: Array<any>, arr2?: Array<any>): boolean {
if (!arr1 && !arr2) return true;
if (!arr1 || !arr2) return false;
if (arr1.length !== arr2.length) return false;
@@ -120,3 +120,28 @@ export function areChatOptionsEqual(
// If all checks passed, the options are equal
return true;
}
+
+export function areChatOptionsListEqual(
+ options1?: ChatOptions[],
+ options2?: ChatOptions[],
+): boolean {
+ if (options1 && options2) {
+ // Both defined, need to compare
+ if (options1.length !== options2.length) {
+ return false;
+ } else {
+ for (let i = 0; i < options1.length; i++) {
+ if (!areChatOptionsEqual(options1[i], options2[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+ } else if (!options1 && !options2) {
+ // Both undefined, equal
+ return true;
+ } else {
+ // One defined, other not
+ return false;
+ }
+}
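
A brief sketch of the comparison semantics of the new `areChatOptionsListEqual` (the option values below are arbitrary):

```ts
import { areChatOptionsListEqual } from "./utils";
import { ChatOptions } from "./config";

const optsA: ChatOptions = { top_p: 0.5 };
const optsB: ChatOptions = { top_p: 0.9 };

areChatOptionsListEqual(undefined, undefined);            // true: both unset
areChatOptionsListEqual([optsA], undefined);              // false: only one unset
areChatOptionsListEqual([optsA], [optsA, optsB]);         // false: different lengths
areChatOptionsListEqual([optsA, optsB], [optsA, optsB]);  // true: element-wise equal
areChatOptionsListEqual([optsA, optsB], [optsB, optsA]);  // false: compared in order
```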
diff --git a/src/web_worker.ts b/src/web_worker.ts
index 5f4ec6f3..6243f828 100644
--- a/src/web_worker.ts
+++ b/src/web_worker.ts
@@ -34,6 +34,8 @@ import {
CompletionNonStreamingParams,
EmbeddingParams,
CompletionStreamInitParams,
+ GetMessageParams,
+ RuntimeStatsTextParams,
} from "./message";
import log from "loglevel";
import { MLCEngine } from "./engine";
@@ -41,6 +43,7 @@ import {
UnknownMessageKindError,
WorkerEngineModelNotLoadedError,
} from "./error";
+import { areArraysEqual } from "./utils";
/**
* Worker handler that can be used in a WebWorker
@@ -56,14 +59,15 @@ import {
export class WebWorkerMLCEngineHandler {
/**
* The modelId and chatOpts that the underlying engine (backend) is currently loaded with.
+ * An engine can be loaded with multiple models, so modelId and chatOpts are lists.
*
* TODO(webllm-team): This is always in-sync with `this.engine` unless device is lost due to
* unexpected reason. Therefore, we should get it from `this.engine` directly and make handler
* stateless. Besides, consider if we should add appConfig, or use engine's API to find the
* corresponding model record rather than relying on just the modelId.
*/
- modelId?: string;
- chatOpts?: ChatOptions;
+ modelId?: string[];
+ chatOpts?: ChatOptions[];
public engine: MLCEngine;
/** ChatCompletion and Completion share the same chunk generator. */
@@ -151,6 +155,7 @@ export class WebWorkerMLCEngineHandler {
const res = await this.engine.forwardTokensAndSample(
params.inputIds,
params.isPrefill,
+ params.modelId,
);
onComplete?.(res);
return res;
@@ -238,7 +243,8 @@ export class WebWorkerMLCEngineHandler {
}
case "runtimeStatsText": {
this.handleTask(msg.uuid, async () => {
- const res = await this.engine.runtimeStatsText();
+ const params = msg.content as RuntimeStatsTextParams;
+ const res = await this.engine.runtimeStatsText(params.modelId);
onComplete?.(res);
return res;
});
@@ -266,7 +272,7 @@ export class WebWorkerMLCEngineHandler {
case "resetChat": {
this.handleTask(msg.uuid, async () => {
const params = msg.content as ResetChatParams;
- await this.engine.resetChat(params.keepStats);
+ await this.engine.resetChat(params.keepStats, params.modelId);
onComplete?.(null);
return null;
});
@@ -290,7 +296,8 @@ export class WebWorkerMLCEngineHandler {
}
case "getMessage": {
this.handleTask(msg.uuid, async () => {
- const res = await this.engine.getMessage();
+ const params = msg.content as GetMessageParams;
+ const res = await this.engine.getMessage(params.modelId);
onComplete?.(res);
return res;
});
@@ -329,10 +336,11 @@ export class WebWorkerMLCEngineHandler {
* to possibly killed service worker), we reload here.
*/
async reloadIfUnmatched(
- expectedModelId: string,
- expectedChatOpts: ChatOptions,
+ expectedModelId: string[],
+ expectedChatOpts?: ChatOptions[],
) {
- if (this.modelId !== expectedModelId) {
+ // TODO: should we also check expectedChatOpts here?
+ if (!areArraysEqual(this.modelId, expectedModelId)) {
log.warn(
"WebWorkerMLCEngine expects model is loaded in WebWorkerMLCEngineHandler, " +
"but it is not. This may due to web/service worker is unexpectedly killed.\n" +
@@ -353,19 +361,22 @@ export interface ChatWorker {
*
* Equivalent to `new webllm.WebWorkerMLCEngine(worker).reload(...)`.
*
- * @param worker The worker that holds the actual MLCEngine, intialized with `new Worker()`.
- * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
- * `engineConfig.appConfig`.
+ * @param worker The worker that holds the actual MLCEngine, initialized with `new Worker()`.
+ * @param modelId model_id of the model to load, either string or string[]. When multiple models
+ * are provided, we load all models sequentially. Each modelId needs to either be in
+ * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`.
* @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig` for more.
+ * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`.
+ * Its size needs to match that of `modelId`; chatOpts[i] will be used for modelId[i].
* @returns An initialized `WebLLM.WebWorkerMLCEngine` with `modelId` loaded.
*
* @note engineConfig.logitProcessorRegistry is ignored for `CreateWebWorkMLCEngine()`.
*/
export async function CreateWebWorkerMLCEngine(
worker: any,
- modelId: string,
+ modelId: string | string[],
engineConfig?: MLCEngineConfig,
- chatOpts?: ChatOptions,
+ chatOpts?: ChatOptions | ChatOptions[],
): Promise<WebWorkerMLCEngine> {
const webWorkerMLCEngine = new WebWorkerMLCEngine(worker, engineConfig);
await webWorkerMLCEngine.reload(modelId, chatOpts);
@@ -395,9 +406,10 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
* The modelId and chatOpts that the frontend expects the backend engine is currently loaded
* with. Needed for service worker. It is the backend and handler's job to match up with the
* expectation despite the web/service worker possibly being killed.
+ * Since an engine can load multiple models, both modelId and chatOpts are lists.
*/
- modelId?: string;
- chatOpts?: ChatOptions;
+ modelId?: string[];
+ chatOpts?: ChatOptions[];
private initProgressCallback?: InitProgressCallback;
private pendingPromise = new Map<string, (msg: any) => void>();
@@ -481,7 +493,18 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
return promise;
}
- async reload(modelId: string, chatOpts?: ChatOptions): Promise<void> {
+ async reload(
+ modelId: string | string[],
+ chatOpts?: ChatOptions | ChatOptions[],
+ ): Promise<void> {
+ // Always convert modelId and chatOpts to lists internally for ease of manipulation
+ if (!Array.isArray(modelId)) {
+ modelId = [modelId];
+ }
+ if (chatOpts !== undefined && !Array.isArray(chatOpts)) {
+ chatOpts = [chatOpts];
+ }
+
const msg: WorkerRequest = {
kind: "reload",
uuid: crypto.randomUUID(),
@@ -513,20 +536,24 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
return await this.getPromise(msg);
}
- async getMessage(): Promise<string> {
+ async getMessage(modelId?: string): Promise<string> {
const msg: WorkerRequest = {
kind: "getMessage",
uuid: crypto.randomUUID(),
- content: null,
+ content: {
+ modelId: modelId,
+ },
};
return await this.getPromise(msg);
}
- async runtimeStatsText(): Promise<string> {
+ async runtimeStatsText(modelId?: string): Promise<string> {
const msg: WorkerRequest = {
kind: "runtimeStatsText",
uuid: crypto.randomUUID(),
- content: null,
+ content: {
+ modelId: modelId,
+ },
};
return await this.getPromise(msg);
}
@@ -551,12 +578,13 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
this.chatOpts = undefined;
}
- async resetChat(keepStats = false): Promise<void> {
+ async resetChat(keepStats = false, modelId?: string): Promise<void> {
const msg: WorkerRequest = {
kind: "resetChat",
uuid: crypto.randomUUID(),
content: {
keepStats: keepStats,
+ modelId: modelId,
},
};
await this.getPromise(msg);
@@ -565,6 +593,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
async forwardTokensAndSample(
inputIds: Array<number>,
isPrefill: boolean,
+ modelId?: string,
): Promise<number> {
const msg: WorkerRequest = {
kind: "forwardTokensAndSample",
@@ -572,6 +601,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
content: {
inputIds: inputIds,
isPrefill: isPrefill,
+ modelId: modelId,
},
};
return await this.getPromise(msg);
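
For the worker path, a hedged sketch of the same multi-model surface as exposed by `WebWorkerMLCEngine` (the worker script path and model ids are placeholders):

```ts
import * as webllm from "@mlc-ai/web-llm";

async function main() {
  const engine = await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    ["model-a", "model-b"], // loaded sequentially; normalized to a list internally
    { initProgressCallback: (r) => console.log(r.text) },
  );

  // Requests that target a specific pipeline carry the modelId in the message
  // payload, which the handler forwards to the backend engine.
  console.log(await engine.runtimeStatsText("model-b"));
  await engine.resetChat(false, "model-a");
}
```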
diff --git a/tests/openai_chat_completion.test.ts b/tests/openai_chat_completion.test.ts
index 96cd8d98..f7176650 100644
--- a/tests/openai_chat_completion.test.ts
+++ b/tests/openai_chat_completion.test.ts
@@ -33,21 +33,6 @@ describe("Check chat completion unsupported requests", () => {
}).toThrow("Only specify stream_options when stream=True.");
});
- test("High-level unsupported fields", () => {
- expect(() => {
- const request: ChatCompletionRequest = {
- model: "phi-2-q4f32_1-MLC", // this raises error
- messages: [
- { role: "system", content: "You are a helpful assistant." },
- { role: "user", content: "Hello! " },
- ],
- };
- postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
- }).toThrow(
- "The following fields in ChatCompletionRequest are not yet supported",
- );
- });
-
test("Last message should be from user or tool", () => {
expect(() => {
const request: ChatCompletionRequest = {
diff --git a/tests/openai_completion.test.ts b/tests/openai_completion.test.ts
index bb22e8df..00b12ffb 100644
--- a/tests/openai_completion.test.ts
+++ b/tests/openai_completion.test.ts
@@ -55,16 +55,6 @@ describe("Check completion unsupported requests", () => {
});
test("High-level unsupported fields", () => {
- expect(() => {
- const request: CompletionCreateParams = {
- model: "phi-2-q4f32_1-MLC", // this raises error
- prompt: "Hello, ",
- };
- postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
- }).toThrow(
- "The following fields in CompletionCreateParams are not yet supported",
- );
-
expect(() => {
const request: CompletionCreateParams = {
prompt: "Hello, ",
diff --git a/tests/openai_embeddings.test.ts b/tests/openai_embeddings.test.ts
index dd704ad4..09fdec5c 100644
--- a/tests/openai_embeddings.test.ts
+++ b/tests/openai_embeddings.test.ts
@@ -98,17 +98,6 @@ describe("Check embeddings unsupported requests", () => {
}).toThrow(new EmbeddingUnsupportedEncodingFormatError());
});
- test("model", () => {
- expect(() => {
- const request: EmbeddingCreateParams = {
- input: ["Hello", "Hi"],
- encoding_format: "float",
- model: "snowflake-arctic-embed-m-q0f32-MLC",
- };
- postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC");
- }).toThrow("The following fields in");
- });
-
test("user", () => {
expect(() => {
const request: EmbeddingCreateParams = {
diff --git a/tests/util.test.ts b/tests/util.test.ts
index 8cdc7955..f35e729d 100644
--- a/tests/util.test.ts
+++ b/tests/util.test.ts
@@ -1,4 +1,12 @@
-import { cleanModelUrl, getTopProbs } from "../src/support";
+import { ChatOptions } from "../src/config";
+import {
+ ModelNotLoadedError,
+ SpecifiedModelNotFoundError,
+ UnclearModelToUseError,
+} from "../src/error";
+import { cleanModelUrl, getModelIdToUse, getTopProbs } from "../src/support";
+import { areChatOptionsListEqual } from "../src/utils";
+import { MLCEngine } from "../src/engine";
describe("Check getTopLogprobs correctness", () => {
test("Correctness test 1", () => {
@@ -56,3 +64,252 @@ describe("Test clean model URL", () => {
expect(output).toEqual(expected);
});
});
+
+describe("Test getModelIdToUse", () => {
+ test("Specified model not found", () => {
+ const loadedModelIds = ["a", "b", "c"];
+ const requestModel = "d";
+ const requestName = "ChatCompletionRequest";
+ expect(() => {
+ getModelIdToUse(loadedModelIds, requestModel, requestName);
+ }).toThrow(
+ new SpecifiedModelNotFoundError(
+ loadedModelIds,
+ requestModel,
+ requestName,
+ ),
+ );
+ });
+
+ test("No model loaded", () => {
+ const loadedModelIds: string[] = [];
+ const requestModel = "d";
+ const requestName = "ChatCompletionRequest";
+ expect(() => {
+ getModelIdToUse(loadedModelIds, requestModel, requestName);
+ }).toThrow(new ModelNotLoadedError(requestName));
+ });
+
+ test("Unclear what model to use, undefined", () => {
+ const loadedModelIds = ["a", "b", "c"];
+ const requestModel = undefined;
+ const requestName = "ChatCompletionRequest";
+ expect(() => {
+ getModelIdToUse(loadedModelIds, requestModel, requestName);
+ }).toThrow(new UnclearModelToUseError(loadedModelIds, requestName));
+ });
+
+ test("Unclear what model to use, null", () => {
+ const loadedModelIds = ["a", "b", "c"];
+ const requestModel = null;
+ const requestName = "ChatCompletionRequest";
+ expect(() => {
+ getModelIdToUse(loadedModelIds, requestModel, requestName);
+ }).toThrow(new UnclearModelToUseError(loadedModelIds, requestName));
+ });
+
+ test("Valid config, unspecified request model", () => {
+ const loadedModelIds = ["a"];
+ const requestModel = null;
+ const requestName = "ChatCompletionRequest";
+ const selectedModelId = getModelIdToUse(
+ loadedModelIds,
+ requestModel,
+ requestName,
+ );
+ expect(selectedModelId).toEqual("a");
+ });
+
+ test("Valid config, specified request model", () => {
+ const loadedModelIds = ["a"];
+ const requestModel = "a";
+ const requestName = "ChatCompletionRequest";
+ const selectedModelId = getModelIdToUse(
+ loadedModelIds,
+ requestModel,
+ requestName,
+ );
+ expect(selectedModelId).toEqual("a");
+ });
+
+ test("Valid config, specified request model, multi models loaded", () => {
+ const loadedModelIds = ["a", "b", "c"];
+ const requestModel = "c";
+ const requestName = "ChatCompletionRequest";
+ const selectedModelId = getModelIdToUse(
+ loadedModelIds,
+ requestModel,
+ requestName,
+ );
+ expect(selectedModelId).toEqual("c");
+ });
+
+ // Cannot test MLCEngine.getLLMStates E2E because `instanceof LLMChatPipeline` would not pass
+ // with dummy pipeline variables
+ test("E2E test with MLCEngine not loading a model for APIs", () => {
+ const engine = new MLCEngine();
+ expect(async () => {
+ await engine.chatCompletion({
+ messages: [{ role: "user", content: "hi" }],
+ });
+ }).rejects.toThrow(new ModelNotLoadedError("ChatCompletionRequest"));
+ expect(async () => {
+ await engine.getMessage();
+ }).rejects.toThrow(new ModelNotLoadedError("getMessage"));
+
+ // resetChat should not throw because calling it before a pipeline is established is
+ // allowed and treated as a no-op
+ expect(async () => {
+ await engine.resetChat();
+ }).not.toThrow(new ModelNotLoadedError("resetChat"));
+ });
+
+ test("E2E test with MLCEngine with two models without specifying a model", () => {
+ const engine = new MLCEngine() as any;
+ engine.loadedModelIdToPipeline = new Map();
+ engine.loadedModelIdToPipeline.set("model1", "dummyLLMChatPipeline");
+ engine.loadedModelIdToPipeline.set("model2", "dummyLLMChatPipeline");
+ const loadedModelIds = ["model1", "model2"];
+
+ expect(async () => {
+ await engine.chatCompletion({
+ messages: [{ role: "user", content: "hi" }],
+ });
+ }).rejects.toThrow(
+ new UnclearModelToUseError(loadedModelIds, "ChatCompletionRequest"),
+ );
+ expect(async () => {
+ await engine.getMessage();
+ }).rejects.toThrow(
+ new UnclearModelToUseError(loadedModelIds, "getMessage"),
+ );
+ expect(async () => {
+ await engine.resetChat();
+ }).rejects.toThrow(new UnclearModelToUseError(loadedModelIds, "resetChat"));
+ });
+
+ test("E2E test with MLCEngine with two models specifying wrong model", () => {
+ const engine = new MLCEngine() as any;
+ engine.loadedModelIdToPipeline = new Map();
+ engine.loadedModelIdToPipeline.set("model1", "dummyLLMChatPipeline");
+ engine.loadedModelIdToPipeline.set("model2", "dummyLLMChatPipeline");
+ const loadedModelIds = ["model1", "model2"];
+ const requestedModelId = "model3";
+
+ expect(async () => {
+ await engine.chatCompletion({
+ messages: [{ role: "user", content: "hi" }],
+ model: requestedModelId,
+ });
+ }).rejects.toThrow(
+ new SpecifiedModelNotFoundError(
+ loadedModelIds,
+ requestedModelId,
+ "ChatCompletionRequest",
+ ),
+ );
+ expect(async () => {
+ await engine.getMessage(requestedModelId);
+ }).rejects.toThrow(
+ new SpecifiedModelNotFoundError(
+ loadedModelIds,
+ requestedModelId,
+ "getMessage",
+ ),
+ );
+ expect(async () => {
+ await engine.runtimeStatsText(requestedModelId);
+ }).rejects.toThrow(
+ new SpecifiedModelNotFoundError(
+ loadedModelIds,
+ requestedModelId,
+ "runtimeStatsText",
+ ),
+ );
+
+ // resetChat should not throw because calling it before a pipeline is established is
+ // allowed and treated as a no-op
+ expect(async () => {
+ await engine.resetChat(false, requestedModelId);
+ }).not.toThrow(
+ new SpecifiedModelNotFoundError(
+ loadedModelIds,
+ requestedModelId,
+ "resetChat",
+ ),
+ );
+ });
+});
+
+describe("Test areChatOptionsListEqual", () => {
+ const dummyChatOpts1: ChatOptions = { tokenizer_files: ["a", "b"] };
+ const dummyChatOpts2: ChatOptions = {};
+ const dummyChatOpts3: ChatOptions = { tokenizer_files: ["a", "b"] };
+ const dummyChatOpts4: ChatOptions = {
+ tokenizer_files: ["a", "b"],
+ top_p: 0.5,
+ };
+
+ test("Two undefined", () => {
+ const options1: ChatOptions[] | undefined = undefined;
+ const options2: ChatOptions[] | undefined = undefined;
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(true);
+ });
+
+ test("One undefined", () => {
+ const options1: ChatOptions[] | undefined = [dummyChatOpts1];
+ const options2: ChatOptions[] | undefined = undefined;
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(false);
+ });
+
+ test("Both defined, not equal", () => {
+ const options1: ChatOptions[] | undefined = [dummyChatOpts1];
+ const options2: ChatOptions[] | undefined = [dummyChatOpts2];
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(false);
+ });
+
+ test("Different size", () => {
+ const options1: ChatOptions[] | undefined = [
+ dummyChatOpts1,
+ dummyChatOpts3,
+ ];
+ const options2: ChatOptions[] | undefined = [dummyChatOpts2];
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(false);
+ });
+
+ test("Same size, not equal 1", () => {
+ const options1: ChatOptions[] | undefined = [
+ dummyChatOpts1,
+ dummyChatOpts3,
+ ];
+ const options2: ChatOptions[] | undefined = [
+ dummyChatOpts1,
+ dummyChatOpts2,
+ ];
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(false);
+ });
+
+ test("Same size, not equal 2", () => {
+ const options1: ChatOptions[] | undefined = [
+ dummyChatOpts1,
+ dummyChatOpts3,
+ ];
+ const options2: ChatOptions[] | undefined = [
+ dummyChatOpts1,
+ dummyChatOpts4,
+ ];
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(false);
+ });
+
+ test("Same size, equal", () => {
+ const options1: ChatOptions[] | undefined = [
+ dummyChatOpts1,
+ dummyChatOpts3,
+ ];
+ const options2: ChatOptions[] | undefined = [
+ dummyChatOpts3,
+ dummyChatOpts1,
+ ];
+ expect(areChatOptionsListEqual(options1, options2)).toEqual(true);
+ });
+});