diff --git a/.changeset/polite-forks-whisper.md b/.changeset/polite-forks-whisper.md new file mode 100644 index 0000000..c09c26c --- /dev/null +++ b/.changeset/polite-forks-whisper.md @@ -0,0 +1,5 @@ +--- +"chrome-ai": patch +--- + +feat: stream text/object support abort signal diff --git a/src/chromeai.test.ts b/src/chromeai.test.ts new file mode 100644 index 0000000..de3ebac --- /dev/null +++ b/src/chromeai.test.ts @@ -0,0 +1,11 @@ +import { describe, it, expect } from 'vitest'; +import { chromeai } from './chromeai'; + +describe('chromeai', () => { + it('should correctly create instance', async () => { + expect(chromeai().modelId).toBe('generic'); + expect(chromeai('text').modelId).toBe('text'); + expect(chromeai('embedding').modelId).toBe('embedding'); + expect(chromeai.embedding().modelId).toBe('embedding'); + }); +}); diff --git a/src/chromeai.ts b/src/chromeai.ts new file mode 100644 index 0000000..e3a8b8e --- /dev/null +++ b/src/chromeai.ts @@ -0,0 +1,40 @@ +import { + ChromeAIEmbeddingModel, + ChromeAIEmbeddingModelSettings, +} from './embedding-model'; +import { + ChromeAIChatLanguageModel, + ChromeAIChatModelId, + ChromeAIChatSettings, +} from './language-model'; +import createDebug from 'debug'; + +const debug = createDebug('chromeai'); + +/** + * Create a new ChromeAI model/embedding instance. + * @param modelId 'generic' | 'text' | 'embedding' + * @param settings Options for the model + */ +export function chromeai( + modelId?: ChromeAIChatModelId, + settings?: ChromeAIChatSettings +): ChromeAIChatLanguageModel; +export function chromeai( + modelId?: 'embedding', + settings?: ChromeAIEmbeddingModelSettings +): ChromeAIEmbeddingModel; +export function chromeai(modelId: string = 'generic', settings: any = {}) { + debug('create instance', modelId, settings); + if (modelId === 'embedding') { + return new ChromeAIEmbeddingModel(settings); + } + return new ChromeAIChatLanguageModel( + modelId as ChromeAIChatModelId, + settings + ); +} + +/** @deprecated use `chromeai('embedding'[, options])` */ +chromeai.embedding = (settings: ChromeAIEmbeddingModelSettings = {}) => + new ChromeAIEmbeddingModel(settings); diff --git a/src/embedding-model.test.ts b/src/embedding-model.test.ts index 3f9c7ca..10f5d35 100644 --- a/src/embedding-model.test.ts +++ b/src/embedding-model.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi } from 'vitest'; -import { ChromeAIEmbeddingModel, chromeEmbedding } from './embedding-model'; +import { ChromeAIEmbeddingModel } from './embedding-model'; import { embed } from 'ai'; vi.mock('@mediapipe/tasks-text', async () => ({ @@ -23,10 +23,10 @@ vi.mock('@mediapipe/tasks-text', async () => ({ describe('embedding-model', () => { it('should instantiation anyways', async () => { expect(new ChromeAIEmbeddingModel()).toBeInstanceOf(ChromeAIEmbeddingModel); - expect(chromeEmbedding()).toBeInstanceOf(ChromeAIEmbeddingModel); + expect(new ChromeAIEmbeddingModel()).toBeInstanceOf(ChromeAIEmbeddingModel); }); it('should embed', async () => { - const model = chromeEmbedding(); + const model = new ChromeAIEmbeddingModel(); expect( await embed({ model, @@ -45,7 +45,7 @@ describe('embedding-model', () => { it('should embed result empty', async () => { expect( await embed({ - model: chromeEmbedding({ l2Normalize: true }), + model: new ChromeAIEmbeddingModel({ l2Normalize: true }), value: 'undefined', }) ).toMatchObject({ embedding: [] }); diff --git a/src/embedding-model.ts b/src/embedding-model.ts index 565d7f6..55c5b01 100644 --- a/src/embedding-model.ts +++ b/src/embedding-model.ts @@ -36,10 +36,13 @@ export interface ChromeAIEmbeddingModelSettings { delegate?: 'CPU' | 'GPU'; } +// See more: +// - https://github.com/google-ai-edge/mediapipe +// - https://ai.google.dev/edge/mediapipe/solutions/text/text_embedder/web_js export class ChromeAIEmbeddingModel implements EmbeddingModelV1 { readonly specificationVersion = 'v1'; readonly provider = 'google-mediapipe'; - readonly modelId: string = 'mediapipe'; + readonly modelId: string = 'embedding'; readonly supportsParallelCalls = true; readonly maxEmbeddingsPerCall = undefined; @@ -80,18 +83,12 @@ export class ChromeAIEmbeddingModel implements EmbeddingModelV1 { rawResponse?: Record; }> => { // if (options.abortSignal) console.warn('abortSignal is not supported'); - const embedder = await this.getTextEmbedder(); - const embeddings = await Promise.all( - options.values.map((text) => { - const embedderResult = embedder.embed(text); - const [embedding] = embedderResult.embeddings; - return embedding?.floatEmbedding ?? []; - }) - ); + const embeddings = options.values.map((text) => { + const embedderResult = embedder.embed(text); + const [embedding] = embedderResult.embeddings; + return embedding?.floatEmbedding ?? []; + }); return { embeddings }; }; } - -export const chromeEmbedding = (options?: ChromeAIEmbeddingModelSettings) => - new ChromeAIEmbeddingModel(options); diff --git a/src/index.ts b/src/index.ts index b26e9d6..31cd553 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,2 +1,3 @@ export * from './language-model'; export * from './embedding-model'; +export * from './chromeai'; diff --git a/src/language-model.test.ts b/src/language-model.test.ts index b396f86..15b2f8b 100644 --- a/src/language-model.test.ts +++ b/src/language-model.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi, afterEach } from 'vitest'; -import { chromeai, ChromeAIChatLanguageModel } from './index'; +import { ChromeAIChatLanguageModel } from './index'; import { generateText, streamText, generateObject, streamObject } from 'ai'; import { LoadSettingError, @@ -14,17 +14,22 @@ describe('language-model', () => { }); it('should instantiation anyways', () => { - expect(chromeai()).toBeInstanceOf(ChromeAIChatLanguageModel); - expect(chromeai().modelId).toBe('generic'); - expect(chromeai('text').modelId).toBe('text'); + expect(new ChromeAIChatLanguageModel('generic')).toBeInstanceOf( + ChromeAIChatLanguageModel + ); + expect(new ChromeAIChatLanguageModel('text').modelId).toBe('text'); expect( - chromeai('text', { temperature: 1, topK: 10 }).options + new ChromeAIChatLanguageModel('text', { temperature: 1, topK: 10 }) + .options ).toMatchObject({ temperature: 1, topK: 10 }); }); it('should throw when not support', async () => { await expect(() => - generateText({ model: chromeai(), prompt: 'empty' }) + generateText({ + model: new ChromeAIChatLanguageModel('generic'), + prompt: 'empty', + }) ).rejects.toThrowError(LoadSettingError); const cannotCreateSession = vi.fn(async () => 'no'); @@ -34,12 +39,18 @@ describe('language-model', () => { }); await expect(() => - generateText({ model: chromeai('text'), prompt: 'empty' }) + generateText({ + model: new ChromeAIChatLanguageModel('text'), + prompt: 'empty', + }) ).rejects.toThrowError(LoadSettingError); expect(cannotCreateSession).toHaveBeenCalledTimes(1); await expect(() => - generateText({ model: chromeai('generic'), prompt: 'empty' }) + generateText({ + model: new ChromeAIChatLanguageModel('generic'), + prompt: 'empty', + }) ).rejects.toThrowError(LoadSettingError); expect(cannotCreateSession).toHaveBeenCalledTimes(2); }); @@ -58,17 +69,23 @@ describe('language-model', () => { createTextSession: createSession, }); - await generateText({ model: chromeai('text'), prompt: 'test' }); + await generateText({ + model: new ChromeAIChatLanguageModel('text'), + prompt: 'test', + }); expect(getOptions).toHaveBeenCalledTimes(1); - const result = await generateText({ model: chromeai(), prompt: 'test' }); + const result = await generateText({ + model: new ChromeAIChatLanguageModel('generic'), + prompt: 'test', + }); expect(result).toMatchObject({ finishReason: 'stop', text: 'test', }); const resultForMessages = await generateText({ - model: chromeai(), + model: new ChromeAIChatLanguageModel('generic'), messages: [ { role: 'user', content: 'test' }, { role: 'assistant', content: 'assistant' }, @@ -95,7 +112,10 @@ describe('language-model', () => { createGenericSession: vi.fn(async () => ({ promptStreaming })), }); - const result = await streamText({ model: chromeai(), prompt: 'test' }); + const result = await streamText({ + model: new ChromeAIChatLanguageModel('generic'), + prompt: 'test', + }); for await (const textPart of result.textStream) { expect(textPart).toBe('test'); } @@ -110,7 +130,7 @@ describe('language-model', () => { }); const { object } = await generateObject({ - model: chromeai(), + model: new ChromeAIChatLanguageModel('generic'), schema: z.object({ hello: z.string(), }), @@ -129,7 +149,7 @@ describe('language-model', () => { }); await expect(() => generateText({ - model: chromeai(), + model: new ChromeAIChatLanguageModel('generic'), messages: [ { role: 'tool', @@ -161,7 +181,7 @@ describe('language-model', () => { await expect(() => generateObject({ - model: chromeai(), + model: new ChromeAIChatLanguageModel('generic'), mode: 'grammar', schema: z.object({}), prompt: 'test', @@ -170,7 +190,7 @@ describe('language-model', () => { await expect(() => streamObject({ - model: chromeai(), + model: new ChromeAIChatLanguageModel('generic'), mode: 'grammar', schema: z.object({}), prompt: 'test', diff --git a/src/language-model.ts b/src/language-model.ts index 274f5b3..8fb2e58 100644 --- a/src/language-model.ts +++ b/src/language-model.ts @@ -17,7 +17,10 @@ import { import { ChromeAISession, ChromeAISessionOptions } from './global'; import createDebug from 'debug'; import { StreamAI } from './stream-ai'; -import { chromeEmbedding } from './embedding-model'; +import { + ChromeAIEmbeddingModel, + ChromeAIEmbeddingModelSettings, +} from './embedding-model'; const debug = createDebug('chromeai'); @@ -224,7 +227,7 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { const session = await this.getSession(); const message = this.formatMessages(options); const promptStream = session.promptStreaming(message); - const transformStream = new StreamAI(); + const transformStream = new StreamAI(options.abortSignal); const stream = promptStream.pipeThrough(transformStream); return { @@ -233,10 +236,3 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { }; }; } - -export const chromeai = ( - modelId: ChromeAIChatModelId = 'generic', - settings: ChromeAIChatSettings = {} -) => new ChromeAIChatLanguageModel(modelId, settings); - -chromeai.embedding = chromeEmbedding; diff --git a/src/stream-ai.test.ts b/src/stream-ai.test.ts index 3142e32..37755fe 100644 --- a/src/stream-ai.test.ts +++ b/src/stream-ai.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi, afterEach } from 'vitest'; +import { describe, it, expect } from 'vitest'; import { StreamAI } from './stream-ai'; describe('stream-ai', () => { @@ -23,4 +23,22 @@ describe('stream-ai', () => { value: { type: 'finish' }, }); }); + + it('should abort when signal', async () => { + const controller = new AbortController(); + const transformStream = new StreamAI(controller.signal); + + const writer = transformStream.writable.getWriter(); + const reader = transformStream.readable.getReader(); + + writer.write('hello'); + + expect(await reader.read()).toMatchObject({ + value: { type: 'text-delta', textDelta: 'hello' }, + done: false, + }); + + controller.abort(); + expect(await reader.read()).toMatchObject({ done: true }); + }); }); diff --git a/src/stream-ai.ts b/src/stream-ai.ts index 57f6a12..d440e80 100644 --- a/src/stream-ai.ts +++ b/src/stream-ai.ts @@ -7,11 +7,17 @@ export class StreamAI extends TransformStream< string, LanguageModelV1StreamPart > { - public constructor() { + public constructor(abortSignal?: AbortSignal) { let textTemp = ''; super({ - start: () => { + start: (controller) => { textTemp = ''; + if (!abortSignal) return; + abortSignal.addEventListener('abort', () => { + debug('streamText terminate by abortSignal'); + controller.terminate(); + textTemp = ''; + }); }, transform: (chunk, controller) => { const textDelta = chunk.replace(textTemp, '');