From a0bf692ee61ff7eb8e97cdf2e806abc7401c654f Mon Sep 17 00:00:00 2001
From: axel7083 <42176370+axel7083@users.noreply.github.com>
Date: Fri, 15 Mar 2024 13:54:19 +0100
Subject: [PATCH] feat: adding new Playground Manager V2 (chat based) (#504)

* feat: adding new Playground Manager V2

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: modelId value

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: linter

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* test: ensuring playground manager v2 has expected behaviour

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: removing uppercase

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: prettier

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: linter

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: linter&prettier

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* feat: creating conversation registry to handle conversation history

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* feat: improving apis

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: update StudioAPI

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* test: ensuring playground works as expected

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

* fix: prettier&linter

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>

---------

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>
---
 .../managers/inference/inferenceManager.ts    |   8 +
 packages/backend/src/managers/playground.ts   |  42 ++++
 .../src/managers/playgroundV2Manager.spec.ts  | 192 ++++++++++++++++++
 .../src/managers/playgroundV2Manager.ts       | 148 ++++++++++++++
 .../src/registries/conversationRegistry.ts    | 157 ++++++++++++++
 packages/backend/src/studio-api-impl.spec.ts  |   2 +
 packages/backend/src/studio-api-impl.ts       |  22 ++
 packages/backend/src/studio.ts                |   4 +
 packages/shared/Messages.ts                   |   8 +
 packages/shared/src/StudioAPI.ts              |  43 ++++
 packages/shared/src/models/IModelOptions.ts   |   4 +
 .../shared/src/models/IPlaygroundMessage.ts   |  50 +++++
 12 files changed, 680 insertions(+)
 create mode 100644 packages/backend/src/managers/playgroundV2Manager.spec.ts
 create mode 100644 packages/backend/src/managers/playgroundV2Manager.ts
 create mode 100644 packages/backend/src/registries/conversationRegistry.ts
 create mode 100644 packages/shared/src/models/IModelOptions.ts
 create mode 100644 packages/shared/src/models/IPlaygroundMessage.ts

diff --git a/packages/backend/src/managers/inference/inferenceManager.ts b/packages/backend/src/managers/inference/inferenceManager.ts
index 58484719c..014e1f715 100644
--- a/packages/backend/src/managers/inference/inferenceManager.ts
+++ b/packages/backend/src/managers/inference/inferenceManager.ts
@@ -93,6 +93,14 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
     return Array.from(this.#servers.values());
   }
 
+  /**
+   * return an inference server
+   * @param containerId the containerId of the inference server
+   */
+  public get(containerId: string): InferenceServer | undefined {
+    return this.#servers.get(containerId);
+  }
+
   /**
    * Given an engineId, it will create an inference server.
    * @param config
diff --git a/packages/backend/src/managers/playground.ts b/packages/backend/src/managers/playground.ts
index d3a218e3a..c51d47e25 100644
--- a/packages/backend/src/managers/playground.ts
+++ b/packages/backend/src/managers/playground.ts
@@ -39,6 +39,9 @@ const PLAYGROUND_IMAGE = 'quay.io/bootsy/playground:v0';
 
 const STARTING_TIME_MAX = 3600 * 1000;
 
+/**
+ * @deprecated
+ */
 export class PlayGroundManager {
   private queryIdCounter = 0;
 
@@ -56,6 +59,9 @@
     this.queries = new Map();
   }
 
+  /**
+   * @deprecated
+   */
   adoptRunningPlaygrounds() {
     this.podmanConnection.startupSubscribe(() => {
       containerEngine
@@ -96,11 +102,17 @@ export class PlayGroundManager {
     });
   }
 
+  /**
+   * @deprecated
+   */
   async selectImage(image: string): Promise<ImageInfo | undefined> {
     const images = (await containerEngine.listImages()).filter(im => im.RepoTags?.some(tag => tag === image));
     return images.length > 0 ? images[0] : undefined;
   }
 
+  /**
+   * @deprecated
+   */
   setPlaygroundStatus(modelId: string, status: PlaygroundStatus): void {
     this.updatePlaygroundState(modelId, {
       modelId: modelId,
@@ -109,6 +121,9 @@
     });
   }
 
+  /**
+   * @deprecated
+   */
   setPlaygroundError(modelId: string, error: string): void {
     const state: Partial<PlaygroundState> = this.playgrounds.get(modelId) || {};
     this.updatePlaygroundState(modelId, {
@@ -119,6 +134,9 @@
     });
   }
 
+  /**
+   * @deprecated
+   */
   updatePlaygroundState(modelId: string, state: PlaygroundState): void {
     this.playgrounds.set(modelId, {
       ...state,
@@ -127,6 +145,9 @@
     });
     this.sendPlaygroundState();
   }
 
+  /**
+   * @deprecated
+   */
   sendPlaygroundState() {
     this.webview
       .postMessage({
@@ -138,6 +159,9 @@
     });
   }
 
+  /**
+   * @deprecated
+   */
   async startPlayground(modelId: string, modelPath: string): Promise<string> {
     const startTime = performance.now();
     // TODO(feloy) remove previous query from state?
@@ -265,6 +289,9 @@
     return result.id;
   }
 
+  /**
+   * @deprecated
+   */
   async stopPlayground(modelId: string): Promise<void> {
     const startTime = performance.now();
     const state = this.playgrounds.get(modelId);
@@ -291,6 +318,9 @@
     this.telemetry.logUsage('playground.stop', { 'model.id': modelId, durationSeconds });
   }
 
+  /**
+   * @deprecated
+   */
   async askPlayground(modelInfo: ModelInfo, prompt: string): Promise<number> {
     const startTime = performance.now();
     const state = this.playgrounds.get(modelInfo.id);
@@ -338,17 +368,29 @@
     return query.id;
   }
 
+  /**
+   * @deprecated
+   */
   getNextQueryId() {
     return ++this.queryIdCounter;
   }
 
+  /**
+   * @deprecated
+   */
   getQueriesState(): QueryState[] {
     return Array.from(this.queries.values());
   }
 
+  /**
+   * @deprecated
+   */
   getPlaygroundsState(): PlaygroundState[] {
     return Array.from(this.playgrounds.values());
   }
 
+  /**
+   * @deprecated
+   */
   sendQueriesState(): void {
     this.webview
       .postMessage({
diff --git a/packages/backend/src/managers/playgroundV2Manager.spec.ts b/packages/backend/src/managers/playgroundV2Manager.spec.ts
new file mode 100644
index 000000000..40d2869e5
--- /dev/null
+++ b/packages/backend/src/managers/playgroundV2Manager.spec.ts
@@ -0,0 +1,192 @@
+/**********************************************************************
+ * Copyright (C) 2024 Red Hat, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ***********************************************************************/
+
+import { expect, test, vi, beforeEach, afterEach } from 'vitest';
+import OpenAI from 'openai';
+import { PlaygroundV2Manager } from './playgroundV2Manager';
+import type { Webview } from '@podman-desktop/api';
+import type { InferenceServer } from '@shared/src/models/IInference';
+import type { InferenceManager } from './inference/inferenceManager';
+import { Messages } from '@shared/Messages';
+
+vi.mock('openai', () => ({
+  default: vi.fn(),
+}));
+
+const webviewMock = {
+  postMessage: vi.fn(),
+} as unknown as Webview;
+
+const inferenceManagerMock = {
+  get: vi.fn(),
+} as unknown as InferenceManager;
+
+beforeEach(() => {
+  vi.resetAllMocks();
+  vi.mocked(webviewMock.postMessage).mockResolvedValue(undefined);
+  vi.useFakeTimers();
+});
+
+afterEach(() => {
+  vi.useRealTimers();
+});
+
+test('manager should be properly initialized', () => {
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  expect(manager.getConversations().length).toBe(0);
+});
+
+test('submit should throw an error if the server is stopped', async () => {
+  vi.mocked(inferenceManagerMock.get).mockReturnValue({
+    status: 'stopped',
+  } as unknown as InferenceServer);
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  await expect(
+    manager.submit('dummyContainerId', 'dummyModelId', 'dummyConversationId', 'dummyUserInput'),
+  ).rejects.toThrowError('Inference server is not running.');
+});
+
+test('submit should throw an error if the server is unhealthy', async () => {
+  vi.mocked(inferenceManagerMock.get).mockReturnValue({
+    status: 'running',
+    health: {
+      Status: 'unhealthy',
+    },
+  } as unknown as InferenceServer);
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  await expect(
+    manager.submit('dummyContainerId', 'dummyModelId', 'dummyConversationId', 'dummyUserInput'),
+  ).rejects.toThrowError('Inference server is not healthy, currently status: unhealthy.');
+});
+
+test('submit should throw an error if the model id provided does not exist.', async () => {
+  vi.mocked(inferenceManagerMock.get).mockReturnValue({
+    status: 'running',
+    health: {
+      Status: 'healthy',
+    },
+    models: [
+      {
+        id: 'dummyModelId',
+      },
+    ],
+  } as unknown as InferenceServer);
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  await expect(
+    manager.submit('dummyContainerId', 'invalidModelId', 'dummyConversationId', 'dummyUserInput'),
+  ).rejects.toThrowError(
+    `modelId 'invalidModelId' is not available on the inference server, valid model ids are: dummyModelId.`,
+  );
+});
+
+test('submit should throw an error if the conversation id provided does not exist.', async () => {
+  vi.mocked(inferenceManagerMock.get).mockReturnValue({
+    status: 'running',
+    health: {
+      Status: 'healthy',
+    },
+    models: [
+      {
+        id: 'dummyModelId',
+        file: {
+          file: 'dummyModelFile',
+        },
+      },
+    ],
+  } as unknown as InferenceServer);
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  await expect(
+    manager.submit('dummyContainerId', 'dummyModelId', 'dummyConversationId', 'dummyUserInput'),
+  ).rejects.toThrowError(`conversation with id dummyConversationId does not exist.`);
+});
+
+test('create conversation should create conversation.', async () => {
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  expect(manager.getConversations().length).toBe(0);
+  const conversationId = manager.createConversation();
+
+  const conversations = manager.getConversations();
+  expect(conversations.length).toBe(1);
+  expect(conversations[0].id).toBe(conversationId);
+});
+
+test('valid submit should create IPlaygroundMessage and notify the webview', async () => {
+  const createMock = vi.fn().mockResolvedValue([]);
+  vi.mocked(OpenAI).mockReturnValue({
+    chat: {
+      completions: {
+        create: createMock,
+      },
+    },
+  } as unknown as OpenAI);
+
+  vi.mocked(inferenceManagerMock.get).mockReturnValue({
+    status: 'running',
+    health: {
+      Status: 'healthy',
+    },
+    models: [
+      {
+        id: 'dummyModelId',
+        file: {
+          file: 'dummyModelFile',
+        },
+      },
+    ],
+    connection: {
+      port: 8888,
+    },
+  } as unknown as InferenceServer);
+  const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
+  const conversationId = manager.createConversation();
+
+  const date = new Date(2000, 1, 1, 13);
+  vi.setSystemTime(date);
+
+  await manager.submit('dummyContainerId', 'dummyModelId', conversationId, 'dummyUserInput');
+
+  // Wait for assistant message to be completed
+  await vi.waitFor(() => {
+    expect(manager.getConversations()[0].messages[1].content).toBeDefined();
+  });
+
+  const conversations = manager.getConversations();
+
+  expect(conversations.length).toBe(1);
+  expect(conversations[0].messages.length).toBe(2);
+  expect(conversations[0].messages[0]).toStrictEqual({
+    content: 'dummyUserInput',
+    id: expect.anything(),
+    options: undefined,
+    role: 'user',
+    timestamp: date.getTime(),
+  });
+  expect(conversations[0].messages[1]).toStrictEqual({
+    choices: undefined,
+    completed: true,
+    content: '',
+    id: expect.anything(),
+    role: 'assistant',
+    timestamp: date.getTime(),
+  });
+
+  expect(webviewMock.postMessage).toHaveBeenLastCalledWith({
+    id: Messages.MSG_CONVERSATIONS_UPDATE,
+    body: conversations,
+  });
+});
diff --git a/packages/backend/src/managers/playgroundV2Manager.ts b/packages/backend/src/managers/playgroundV2Manager.ts
new file mode 100644
index 000000000..4dccccca6
--- /dev/null
+++ b/packages/backend/src/managers/playgroundV2Manager.ts
@@ -0,0 +1,148 @@
+/**********************************************************************
+ * Copyright (C) 2024 Red Hat, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ***********************************************************************/
+import type { Disposable, Webview } from '@podman-desktop/api';
+import type { InferenceManager } from './inference/inferenceManager';
+import OpenAI from 'openai';
+import type { ChatCompletionChunk, ChatCompletionMessageParam } from 'openai/src/resources/chat/completions';
+import type { ModelOptions } from '@shared/src/models/IModelOptions';
+import type { Stream } from 'openai/streaming';
+import { ConversationRegistry } from '../registries/conversationRegistry';
+import type { Conversation, PendingChat, UserChat } from '@shared/src/models/IPlaygroundMessage';
+
+export class PlaygroundV2Manager implements Disposable {
+  #conversationRegistry: ConversationRegistry;
+  #counter: number;
+
+  constructor(
+    webview: Webview,
+    private inferenceManager: InferenceManager,
+  ) {
+    this.#conversationRegistry = new ConversationRegistry(webview);
+    this.#counter = 0;
+  }
+
+  private getUniqueId(): string {
+    return `playground-${++this.#counter}`;
+  }
+
+  createConversation(): string {
+    return this.#conversationRegistry.createConversation();
+  }
+
+  /**
+   * @param containerId must correspond to an inference server container
+   * @param modelId the model to use, should be included in the inference server matching the containerId
+   * @param conversationId the conversation id to append the message to.
+   * @param userInput the user input
+   * @param options the model configuration
+   */
+  async submit(
+    containerId: string,
+    modelId: string,
+    conversationId: string,
+    userInput: string,
+    options?: ModelOptions,
+  ): Promise<void> {
+    const server = this.inferenceManager.get(containerId);
+    if (server === undefined) throw new Error('Inference server not found.');
+
+    if (server.status !== 'running') throw new Error('Inference server is not running.');
+
+    if (server.health?.Status !== 'healthy')
+      throw new Error(`Inference server is not healthy, currently status: ${server.health?.Status}.`);
+
+    const modelInfo = server.models.find(model => model.id === modelId);
+    if (modelInfo === undefined)
+      throw new Error(
+        `modelId '${modelId}' is not available on the inference server, valid model ids are: ${server.models.map(model => model.id).join(', ')}.`,
+      );
+
+    const conversation = this.#conversationRegistry.get(conversationId);
+    if (conversation === undefined) throw new Error(`conversation with id ${conversationId} does not exist.`);
+
+    this.#conversationRegistry.submit(conversation.id, {
+      content: userInput,
+      options: options,
+      role: 'user',
+      id: this.getUniqueId(),
+      timestamp: Date.now(),
+    } as UserChat);
+
+    const client = new OpenAI({
+      baseURL: `http://localhost:${server.connection.port}/v1`,
+      apiKey: 'dummy',
+    });
+
+    const response = await client.chat.completions.create({
+      messages: this.getFormattedMessages(conversationId),
+      stream: true,
+      model: modelInfo.file.file,
+      ...options,
+    });
+    // process stream async
+    this.processStream(conversationId, response).catch((err: unknown) => {
+      console.error('Something went wrong while processing stream', err);
+    });
+  }
+
+  /**
+   * Given a Stream from the OpenAI library, update the conversation and notify the publisher
+   * @param conversationId
+   * @param stream
+   */
+  private async processStream(conversationId: string, stream: Stream<ChatCompletionChunk>): Promise<void> {
+    const messageId = this.getUniqueId();
+    this.#conversationRegistry.submit(conversationId, {
+      role: 'assistant',
+      choices: [],
+      completed: false,
+      id: messageId,
+      timestamp: Date.now(),
+    } as PendingChat);
+
+    for await (const chunk of stream) {
+      this.#conversationRegistry.appendChoice(conversationId, messageId, {
+        content: chunk.choices[0]?.delta?.content || '',
+      });
+    }
+
+    this.#conversationRegistry.completeMessage(conversationId, messageId);
+  }
+
+  /**
+   * Transform the ChatMessage interface to the OpenAI ChatCompletionMessageParam
+   * @private
+   */
+  private getFormattedMessages(conversationId: string): ChatCompletionMessageParam[] {
+    return this.#conversationRegistry.get(conversationId).messages.map(
+      message =>
+        ({
+          name: undefined,
+          ...message,
+        }) as ChatCompletionMessageParam,
+    );
+  }
+
+  getConversations(): Conversation[] {
+    return this.#conversationRegistry.getAll();
+  }
+
+  dispose(): void {
+    this.#conversationRegistry.dispose();
+  }
+}
diff --git a/packages/backend/src/registries/conversationRegistry.ts b/packages/backend/src/registries/conversationRegistry.ts
new file mode 100644
index 000000000..ccb68ee36
--- /dev/null
+++ b/packages/backend/src/registries/conversationRegistry.ts
@@ -0,0 +1,157 @@
+/**********************************************************************
+ * Copyright (C) 2024 Red Hat, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ***********************************************************************/
+
+import { Publisher } from '../utils/Publisher';
+import type {
+  AssistantChat,
+  ChatMessage,
+  Choice,
+  Conversation,
+  PendingChat,
+} from '@shared/src/models/IPlaygroundMessage';
+import type { Disposable, Webview } from '@podman-desktop/api';
+import { Messages } from '@shared/Messages';
+
+export class ConversationRegistry extends Publisher<Conversation[]> implements Disposable {
+  #conversations: Map<string, Conversation>;
+  #counter: number;
+
+  constructor(webview: Webview) {
+    super(webview, Messages.MSG_CONVERSATIONS_UPDATE, () => this.getAll());
+    this.#conversations = new Map();
+    this.#counter = 0;
+  }
+
+  init(): void {
+    // TODO: load from file
+  }
+
+  private getUniqueId(): string {
+    return `conversation-${++this.#counter}`;
+  }
+
+  /**
+   * Utility method to update a message's content in a given conversation
+   * @param conversationId
+   * @param messageId
+   * @param message
+   */
+  update(conversationId: string, messageId: string, message: Partial<ChatMessage>) {
+    const conversation = this.#conversations.get(conversationId);
+
+    if (conversation === undefined) {
+      throw new Error(`conversation with id ${conversationId} does not exist.`);
+    }
+
+    const messageIndex = conversation.messages.findIndex(message => message.id === messageId);
+    if (messageIndex === -1)
+      throw new Error(`message with id ${messageId} does not exist in conversation ${conversationId}.`);
+
+    // Update the message with the provided content
+    conversation.messages[messageIndex] = {
+      ...conversation.messages[messageIndex],
+      ...message,
+      id: messageId, // ensure we are not updating the id
+    };
+    this.notify();
+  }
+
+  createConversation(): string {
+    const conversationId = this.getUniqueId();
+    this.#conversations.set(conversationId, {
+      messages: [],
+      id: conversationId,
+    });
+    this.notify();
+    return conversationId;
+  }
+
+  /**
+   * Finalize the message by concatenating all the choices
+   * @param conversationId
+   * @param messageId
+   */
+  completeMessage(conversationId: string, messageId: string): void {
+    const conversation = this.#conversations.get(conversationId);
+    if (conversation === undefined) throw new Error(`conversation with id ${conversationId} does not exist.`);
+
+    const messageIndex = conversation.messages.findIndex(message => message.id === messageId);
+    if (messageIndex === -1)
+      throw new Error(`message with id ${messageId} does not exist in conversation ${conversationId}.`);
+
+    const content = ((conversation.messages[messageIndex] as PendingChat)?.choices || [])
+      .map(choice => choice.content)
+      .join('');
+
+    this.update(conversationId, messageId, {
+      ...conversation.messages[messageIndex],
+      choices: undefined,
+      role: 'assistant',
+      completed: true,
+      content: content,
+    } as AssistantChat);
+  }
+
+  /**
+   * Utility method to quickly add a choice to a given message inside a conversation
+   * @param conversationId
+   * @param messageId
+   * @param choice
+   */
+  appendChoice(conversationId: string, messageId: string, choice: Choice): void {
+    const conversation = this.#conversations.get(conversationId);
+    if (conversation === undefined) throw new Error(`conversation with id ${conversationId} does not exist.`);
+
+    const messageIndex = conversation.messages.findIndex(message => message.id === messageId);
+    if (messageIndex === -1)
+      throw new Error(`message with id ${messageId} does not exist in conversation ${conversationId}.`);
+
+    this.update(conversationId, messageId, {
+      ...conversation.messages[messageIndex],
+      choices: [...((conversation.messages[messageIndex] as PendingChat)?.choices || []), choice],
+    } as PendingChat);
+  }
+
+  /**
+   * Utility method to add a new Message to a given conversation
+   * @param conversationId
+   * @param message
+   */
+  submit(conversationId: string, message: ChatMessage): void {
+    const conversation = this.#conversations.get(conversationId);
+    if (conversation === undefined) throw new Error(`conversation with id ${conversationId} does not exist.`);
+
+    this.#conversations.set(conversationId, {
+      ...conversation,
+      messages: [...conversation.messages, message],
+    });
+    this.notify();
+  }
+
+  dispose(): void {
+    this.#conversations.clear();
+  }
+
+  get(conversationId: string): Conversation | undefined {
+    return this.#conversations.get(conversationId);
+  }
+
+  getAll(): Conversation[] {
+    return Array.from(this.#conversations.values());
+  }
+}
diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts
index fc7813ab0..c0b4daeb3 100644
--- a/packages/backend/src/studio-api-impl.spec.ts
+++ b/packages/backend/src/studio-api-impl.spec.ts
@@ -35,6 +35,7 @@ import { timeout } from './utils/utils';
 import type { TaskRegistry } from './registries/TaskRegistry';
 import type { LocalRepositoryRegistry } from './registries/LocalRepositoryRegistry';
 import type { Recipe } from '@shared/src/models/IRecipe';
+import type { PlaygroundV2Manager } from './managers/playgroundV2Manager';
 
 vi.mock('./ai.json', () => {
   return {
@@ -105,6 +106,7 @@ beforeEach(async () => {
     {} as LocalRepositoryRegistry,
     {} as unknown as TaskRegistry,
     {} as unknown as InferenceManager,
+    {} as unknown as PlaygroundV2Manager,
   );
 
   vi.mock('node:fs');
diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts
index 87f4e3b59..151b5bb4d 100644
--- a/packages/backend/src/studio-api-impl.ts
+++ b/packages/backend/src/studio-api-impl.ts
@@ -36,8 +36,11 @@ import path from 'node:path';
 import type { InferenceServer } from '@shared/src/models/IInference';
 import type { CreationInferenceServerOptions } from '@shared/src/models/InferenceServerConfig';
 import type { InferenceManager } from './managers/inference/inferenceManager';
+import type { Conversation } from '@shared/src/models/IPlaygroundMessage';
+import type { PlaygroundV2Manager } from './managers/playgroundV2Manager';
 import { getFreeRandomPort } from './utils/ports';
 import { withDefaultConfiguration } from './utils/inferenceUtils';
+import type { ModelOptions } from '@shared/src/models/IModelOptions';
 
 export class StudioApiImpl implements StudioAPI {
   constructor(
@@ -49,8 +52,27 @@ export class StudioApiImpl implements StudioAPI {
     private localRepositories: LocalRepositoryRegistry,
     private taskRegistry: TaskRegistry,
     private inferenceManager: InferenceManager,
+    private playgroundV2: PlaygroundV2Manager,
   ) {}
 
+  submitPlaygroundMessage(
+    containerId: string,
+    modelId: string,
+    conversationId: string,
+    userInput: string,
+    options?: ModelOptions,
+  ): Promise<void> {
+    return this.playgroundV2.submit(containerId, modelId, conversationId, userInput, options);
+  }
+
+  async createPlaygroundConversation(): Promise<string> {
+    return this.playgroundV2.createConversation();
+  }
+
+  async getPlaygroundConversations(): Promise<Conversation[]> {
+    return this.playgroundV2.getConversations();
+  }
+
   async getInferenceServers(): Promise<InferenceServer[]> {
     return this.inferenceManager.getServers();
   }
diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts
index 91e849d74..5af613da0 100644
--- a/packages/backend/src/studio.ts
+++ b/packages/backend/src/studio.ts
@@ -39,6 +39,7 @@ import { ContainerRegistry } from './registries/ContainerRegistry';
 import { PodmanConnection } from './managers/podmanConnection';
 import { LocalRepositoryRegistry } from './registries/LocalRepositoryRegistry';
 import { InferenceManager } from './managers/inference/inferenceManager';
+import { PlaygroundV2Manager } from './managers/playgroundV2Manager';
 
 // TODO: Need to be configured
 export const AI_STUDIO_FOLDER = path.join('podman-desktop', 'ai-studio');
@@ -168,6 +169,8 @@ export class Studio {
       this.telemetry.logUsage(e.webviewPanel.visible ? 'opened' : 'closed');
     });
 
+    const playgroundV2 = new PlaygroundV2Manager(this.#panel.webview, this.#inferenceManager);
+
     // Creating StudioApiImpl
     this.studioApi = new StudioApiImpl(
       applicationManager,
@@ -178,6 +181,7 @@
       localRepositoryRegistry,
       taskRegistry,
       this.#inferenceManager,
+      playgroundV2,
     );
 
     this.catalogManager.init();
diff --git a/packages/shared/Messages.ts b/packages/shared/Messages.ts
index 46cb6c717..d91661214 100644
--- a/packages/shared/Messages.ts
+++ b/packages/shared/Messages.ts
@@ -17,8 +17,15 @@
  ***********************************************************************/
 
 export enum Messages {
+  /**
+   * @deprecated
+   */
   MSG_PLAYGROUNDS_STATE_UPDATE = 'playgrounds-state-update',
+  /**
+   * @deprecated
+   */
   MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state',
+  MSG_PLAYGROUNDS_MESSAGES_UPDATE = 'playgrounds-messages-update',
   MSG_NEW_CATALOG_STATE = 'new-catalog-state',
   MSG_TASKS_UPDATE = 'tasks-update',
   MSG_NEW_MODELS_STATE = 'new-models-state',
@@ -27,4 +34,5 @@ export enum Messages {
   MSG_INFERENCE_SERVERS_UPDATE = 'inference-servers-update',
   MSG_MONITORING_UPDATE = 'monitoring-update',
   MSG_SUPPORTED_LANGUAGES_UPDATE = 'supported-languages-supported',
+  MSG_CONVERSATIONS_UPDATE = 'conversations-update',
 }
diff --git a/packages/shared/src/StudioAPI.ts b/packages/shared/src/StudioAPI.ts
index c7ee6c060..db9581a46 100644
--- a/packages/shared/src/StudioAPI.ts
+++ b/packages/shared/src/StudioAPI.ts
@@ -26,6 +26,8 @@ import type { Task } from './models/ITask';
 import type { LocalRepository } from './models/ILocalRepository';
 import type { InferenceServer } from './models/IInference';
 import type { CreationInferenceServerOptions } from './models/InferenceServerConfig';
+import type { ModelOptions } from './models/IModelOptions';
+import type { Conversation } from './models/IPlaygroundMessage';
 
 export abstract class StudioAPI {
   abstract ping(): Promise<string>;
@@ -41,10 +43,25 @@ export abstract class StudioAPI {
    * Delete the folder containing the model from local storage
    */
   abstract requestRemoveLocalModel(modelId: string): Promise<void>;
+  /**
+   * @deprecated
+   */
   abstract startPlayground(modelId: string): Promise<void>;
+  /**
+   * @deprecated
+   */
   abstract stopPlayground(modelId: string): Promise<void>;
+  /**
+   * @deprecated
+   */
   abstract askPlayground(modelId: string, prompt: string): Promise<number>;
+  /**
+   * @deprecated
+   */
   abstract getPlaygroundQueriesState(): Promise<QueryState[]>;
+  /**
+   * @deprecated
+   */
   abstract getPlaygroundsState(): Promise<PlaygroundState[]>;
 
   abstract getModelsDirectory(): Promise<string>;
@@ -101,4 +118,30 @@ export abstract class StudioAPI {
    * Return a free random port on the host machine
    */
   abstract getHostFreePort(): Promise<number>;
+
+  /**
+   * Submit a user input to the Playground linked to a conversation, model, and inference server
+   * @param containerId the container id of the inference server we want to use
+   * @param modelId the model to use
+   * @param conversationId the conversation to append the message to
+   * @param userInput the user input, e.g. 'What is the capital of France?'
+   * @param options the options for the model, e.g. temperature
+   */
+  abstract submitPlaygroundMessage(
+    containerId: string,
+    modelId: string,
+    conversationId: string,
+    userInput: string,
+    options?: ModelOptions,
+  ): Promise<void>;
+
+  /**
+   * Return the conversations
+   */
+  abstract getPlaygroundConversations(): Promise<Conversation[]>;
+
+  /**
+   * Create a new conversation and return a conversationId
+   */
+  abstract createPlaygroundConversation(): Promise<string>;
 }
diff --git a/packages/shared/src/models/IModelOptions.ts b/packages/shared/src/models/IModelOptions.ts
new file mode 100644
index 000000000..f04fffce0
--- /dev/null
+++ b/packages/shared/src/models/IModelOptions.ts
@@ -0,0 +1,4 @@
+export interface ModelOptions {
+  temperature?: number;
+  max_tokens?: number;
+}
diff --git a/packages/shared/src/models/IPlaygroundMessage.ts b/packages/shared/src/models/IPlaygroundMessage.ts
new file mode 100644
index 000000000..e308447b2
--- /dev/null
+++ b/packages/shared/src/models/IPlaygroundMessage.ts
@@ -0,0 +1,50 @@
+/**********************************************************************
+ * Copyright (C) 2024 Red Hat, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ***********************************************************************/
+
+import type { ModelOptions } from './IModelOptions';
+
+export interface ChatMessage {
+  id: string;
+  role: 'system' | 'user' | 'assistant';
+  content?: string;
+  timestamp: number;
+}
+
+export interface AssistantChat extends ChatMessage {
+  role: 'assistant';
+  completed: boolean;
+}
+
+export interface PendingChat extends AssistantChat {
+  completed: false;
+  choices: Choice[];
+}
+
+export interface UserChat extends ChatMessage {
+  role: 'user';
+  options?: ModelOptions;
+}
+
+export interface Conversation {
+  id: string;
+  messages: ChatMessage[];
+}
+
+export interface Choice {
+  content: string;
+}
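
A note on the streaming design in this patch: processStream() first registers an incomplete PendingChat, each streamed chunk is appended to it as a Choice through ConversationRegistry.appendChoice(), and completeMessage() finally concatenates the choices into the definitive AssistantChat content. The standalone TypeScript sketch below illustrates that lifecycle; the interfaces mirror IPlaygroundMessage.ts, but the id and the sample deltas are made up for illustration, and the registry bookkeeping and webview notification are left out.

// Standalone sketch of the pending -> completed message lifecycle
// (interfaces mirror IPlaygroundMessage.ts; id and deltas are illustrative).
interface Choice {
  content: string;
}

interface AssistantChat {
  id: string;
  role: 'assistant';
  completed: boolean;
  content?: string;
  timestamp: number;
}

interface PendingChat extends AssistantChat {
  completed: false;
  choices: Choice[];
}

// While the response streams, each chunk delta is appended as a Choice
// (what appendChoice() does, notifying the webview on every update).
const pending: PendingChat = {
  id: 'playground-2',
  role: 'assistant',
  completed: false,
  choices: [],
  timestamp: Date.now(),
};
for (const delta of ['Paris', ' is the capital', ' of France.']) {
  pending.choices.push({ content: delta });
}

// completeMessage() then concatenates the choices into the final content
// and drops the choices array from the stored message.
const completed: AssistantChat = {
  id: pending.id,
  role: 'assistant',
  completed: true,
  content: pending.choices.map(choice => choice.content).join(''),
  timestamp: pending.timestamp,
};
console.log(completed.content); // Paris is the capital of France.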
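From the webview side, a full round-trip through the new StudioAPI methods could look like the sketch below. This is a hypothetical consumer, not part of the patch: `studioClient` stands in for whatever RPC proxy the frontend uses to reach StudioApiImpl, and `chatOnce` is an illustrative helper name. A real UI would react to the MSG_CONVERSATIONS_UPDATE message posted by the backend rather than poll.

// Hypothetical frontend-side helper; `studioClient` is assumed to be an RPC
// proxy implementing StudioAPI (frontend plumbing outside this patch).
import type { StudioAPI } from '@shared/src/StudioAPI';
import type { AssistantChat, Conversation } from '@shared/src/models/IPlaygroundMessage';

async function chatOnce(studioClient: StudioAPI, containerId: string, modelId: string): Promise<string | undefined> {
  // 1. Create a conversation to hold the message history.
  const conversationId = await studioClient.createPlaygroundConversation();

  // 2. Submit the user input: the backend stores the user message, starts
  //    streaming from the inference server, and resolves immediately.
  await studioClient.submitPlaygroundMessage(containerId, modelId, conversationId, 'What is the capital of France?', {
    temperature: 0.8,
  });

  // 3. The assistant message fills in asynchronously; the backend emits
  //    MSG_CONVERSATIONS_UPDATE on every appended choice. Polling is used
  //    here only to keep the sketch self-contained.
  for (;;) {
    const conversations: Conversation[] = await studioClient.getPlaygroundConversations();
    const messages = conversations.find(conversation => conversation.id === conversationId)?.messages ?? [];
    const last = messages[messages.length - 1] as AssistantChat | undefined;
    if (last?.role === 'assistant' && last.completed) return last.content;
    await new Promise(resolve => setTimeout(resolve, 500));
  }
}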