diff --git a/src/global.d.ts b/src/global.d.ts
index 6a84c80..33ab90a 100644
--- a/src/global.d.ts
+++ b/src/global.d.ts
@@ -1,6 +1,12 @@
 export type ChromeAISessionAvailable = 'no' | 'after-download' | 'readily';
 
-export interface ChromeAISessionOptions {
+export interface ChromeAIModelInfo {
+  defaultTemperature: number;
+  defaultTopK: number;
+  maxTopK: number;
+}
+
+export interface ChromeAISessionOptions extends Record<string, unknown> {
   temperature?: number;
   topK?: number;
 }
@@ -13,7 +19,7 @@ export interface ChromeAISession {
 
 export interface ChromePromptAPI {
   canCreateTextSession: () => Promise<ChromeAISessionAvailable>;
-  defaultTextSessionOptions: () => Promise<ChromeAISessionOptions>;
+  textModelInfo: () => Promise<ChromeAIModelInfo>;
   createTextSession: (
     options?: ChromeAISessionOptions
   ) => Promise<ChromeAISession>;
diff --git a/src/language-model.ts b/src/language-model.ts
index a768edc..dc94eeb 100644
--- a/src/language-model.ts
+++ b/src/language-model.ts
@@ -24,33 +24,7 @@ export type ChromeAIChatModelId = 'text';
 
 export interface ChromeAIChatSettings extends Record<string, unknown> {
   temperature?: number;
-  /**
-   * Optional. The maximum number of tokens to consider when sampling.
-   *
-   * Models use nucleus sampling or combined Top-k and nucleus sampling.
-   * Top-k sampling considers the set of topK most probable tokens.
-   * Models running with nucleus sampling don't allow topK setting.
-   */
   topK?: number;
-
-  /**
-   * Optional. A list of unique safety settings for blocking unsafe content.
-   * @note this is not working yet
-   */
-  safetySettings?: Array<{
-    category:
-      | 'HARM_CATEGORY_HATE_SPEECH'
-      | 'HARM_CATEGORY_DANGEROUS_CONTENT'
-      | 'HARM_CATEGORY_HARASSMENT'
-      | 'HARM_CATEGORY_SEXUALLY_EXPLICIT';
-
-    threshold:
-      | 'HARM_BLOCK_THRESHOLD_UNSPECIFIED'
-      | 'BLOCK_LOW_AND_ABOVE'
-      | 'BLOCK_MEDIUM_AND_ABOVE'
-      | 'BLOCK_ONLY_HIGH'
-      | 'BLOCK_NONE';
-  }>;
 }
 
 function getStringContent(
@@ -105,8 +79,13 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
       throw new LoadSettingError({ message: 'Built-in model not ready' });
     }
 
-    const defaultOptions = await ai.defaultTextSessionOptions();
-    this.options = { ...defaultOptions, ...this.options, ...options };
+    const defaultOptions = await ai.textModelInfo();
+    this.options = {
+      temperature: defaultOptions.defaultTemperature,
+      topK: defaultOptions.defaultTopK,
+      ...this.options,
+      ...options,
+    };
 
     this.session = await ai.createTextSession(this.options);
diff --git a/src/polyfill/session.ts b/src/polyfill/session.ts
index 42511e8..ce8bedb 100644
--- a/src/polyfill/session.ts
+++ b/src/polyfill/session.ts
@@ -1,5 +1,6 @@
 import { LlmInference, ProgressListener } from '@mediapipe/tasks-genai';
 import {
+  ChromeAIModelInfo,
   ChromeAISession,
   ChromeAISessionAvailable,
   ChromeAISessionOptions,
@@ -87,16 +88,17 @@ export class PolyfillChromeAI implements ChromePromptAPI {
     return isModelAssetBufferReady ? 'readily' : 'after-download';
   };
 
-  public defaultTextSessionOptions =
-    async (): Promise<ChromeAISessionOptions> => ({
-      temperature: 0.8,
-      topK: 3,
-    });
+  public textModelInfo = async (): Promise<ChromeAIModelInfo> => ({
+    defaultTemperature: 0.8,
+    defaultTopK: 3,
+    maxTopK: 128,
+  });
 
   public createTextSession = async (
     options?: ChromeAISessionOptions
   ): Promise<ChromeAISession> => {
-    const argv = options ?? (await this.defaultTextSessionOptions());
+    const defaultParams = await this.textModelInfo();
+    const argv = options ?? { temperature: defaultParams.defaultTemperature, topK: defaultParams.defaultTopK };
     const llm = await LlmInference.createFromOptions(
       {
         wasmLoaderPath: this.aiOptions.wasmLoaderPath!,
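
For consumers, a minimal sketch of migrating from the removed `defaultTextSessionOptions()` to the new `textModelInfo()` shape. The `ai` global, the import path, and the clamp against `maxTopK` are illustrative assumptions, not part of this diff:

```ts
import type { ChromePromptAPI } from './global'; // path assumed; types match src/global.d.ts

// Assumes Chrome (or the polyfill) has exposed the Prompt API as a global.
declare const ai: ChromePromptAPI;

const createSessionWithDefaults = async (requestedTopK?: number) => {
  if ((await ai.canCreateTextSession()) !== 'readily') {
    throw new Error('Built-in model not ready');
  }
  // textModelInfo() replaces defaultTextSessionOptions(): defaults now arrive
  // as defaultTemperature/defaultTopK, alongside a maxTopK upper bound.
  const info = await ai.textModelInfo();
  return ai.createTextSession({
    temperature: info.defaultTemperature,
    // Clamping a caller-supplied topK is an illustrative use of the new maxTopK field.
    topK: Math.min(requestedTopK ?? info.defaultTopK, info.maxTopK),
  });
};
```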