Skip to content

Commit

Permalink
feat: adding new Playground Manager V2 (chat based) (containers#504)
Browse files Browse the repository at this point in the history
* feat: adding new Playground Manager V2

Signed-off-by: axel7083 <[email protected]>

* fix: modelId value

Signed-off-by: axel7083 <[email protected]>

* fix: linter

Signed-off-by: axel7083 <[email protected]>

* test: ensuring playground manager v2 has expected behaviour

Signed-off-by: axel7083 <[email protected]>

* fix: removing uppercase

Signed-off-by: axel7083 <[email protected]>

* fix: prettier

Signed-off-by: axel7083 <[email protected]>

* fix: linter

Signed-off-by: axel7083 <[email protected]>

* fix: linter&prettier

Signed-off-by: axel7083 <[email protected]>

* feat: creating conversation registry to handle conversation history

Signed-off-by: axel7083 <[email protected]>

* feat: improving apis

Signed-off-by: axel7083 <[email protected]>

* fix: update StudioAPI

Signed-off-by: axel7083 <[email protected]>

* test: ensuring playground works as expected

Signed-off-by: axel7083 <[email protected]>

* fix: prettier&linter

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Mar 15, 2024
1 parent 7e01521 commit a0bf692
Show file tree
Hide file tree
Showing 12 changed files with 680 additions and 0 deletions.
8 changes: 8 additions & 0 deletions packages/backend/src/managers/inference/inferenceManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
return Array.from(this.#servers.values());
}

/**
* return an inference server
* @param containerId the containerId of the inference server
*/
public get(containerId: string): InferenceServer | undefined {
return this.#servers.get(containerId);
}

/**
* Given an engineId, it will create an inference server.
* @param config
Expand Down
42 changes: 42 additions & 0 deletions packages/backend/src/managers/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ const PLAYGROUND_IMAGE = 'quay.io/bootsy/playground:v0';

const STARTING_TIME_MAX = 3600 * 1000;

/**
* @deprecated
*/
export class PlayGroundManager {
private queryIdCounter = 0;

Expand All @@ -56,6 +59,9 @@ export class PlayGroundManager {
this.queries = new Map<number, QueryState>();
}

/**
* @deprecated
*/
adoptRunningPlaygrounds() {
this.podmanConnection.startupSubscribe(() => {
containerEngine
Expand Down Expand Up @@ -96,11 +102,17 @@ export class PlayGroundManager {
});
}

/**
* @deprecated
*/
async selectImage(image: string): Promise<ImageInfo | undefined> {
const images = (await containerEngine.listImages()).filter(im => im.RepoTags?.some(tag => tag === image));
return images.length > 0 ? images[0] : undefined;
}

/**
* @deprecated
*/
setPlaygroundStatus(modelId: string, status: PlaygroundStatus): void {
this.updatePlaygroundState(modelId, {
modelId: modelId,
Expand All @@ -109,6 +121,9 @@ export class PlayGroundManager {
});
}

/**
* @deprecated
*/
setPlaygroundError(modelId: string, error: string): void {
const state: Partial<PlaygroundState> = this.playgrounds.get(modelId) || {};
this.updatePlaygroundState(modelId, {
Expand All @@ -119,6 +134,9 @@ export class PlayGroundManager {
});
}

/**
* @deprecated
*/
updatePlaygroundState(modelId: string, state: PlaygroundState): void {
this.playgrounds.set(modelId, {
...state,
Expand All @@ -127,6 +145,9 @@ export class PlayGroundManager {
this.sendPlaygroundState();
}

/**
* @deprecated
*/
sendPlaygroundState() {
this.webview
.postMessage({
Expand All @@ -138,6 +159,9 @@ export class PlayGroundManager {
});
}

/**
* @deprecated
*/
async startPlayground(modelId: string, modelPath: string): Promise<string> {
const startTime = performance.now();
// TODO(feloy) remove previous query from state?
Expand Down Expand Up @@ -265,6 +289,9 @@ export class PlayGroundManager {
return result.id;
}

/**
* @deprecated
*/
async stopPlayground(modelId: string): Promise<void> {
const startTime = performance.now();
const state = this.playgrounds.get(modelId);
Expand All @@ -291,6 +318,9 @@ export class PlayGroundManager {
this.telemetry.logUsage('playground.stop', { 'model.id': modelId, durationSeconds });
}

/**
* @deprecated
*/
async askPlayground(modelInfo: ModelInfo, prompt: string): Promise<number> {
const startTime = performance.now();
const state = this.playgrounds.get(modelInfo.id);
Expand Down Expand Up @@ -338,17 +368,29 @@ export class PlayGroundManager {
return query.id;
}

/**
* @deprecated
*/
getNextQueryId() {
return ++this.queryIdCounter;
}
/**
* @deprecated
*/
getQueriesState(): QueryState[] {
return Array.from(this.queries.values());
}

/**
* @deprecated
*/
getPlaygroundsState(): PlaygroundState[] {
return Array.from(this.playgrounds.values());
}

/**
* @deprecated
*/
sendQueriesState(): void {
this.webview
.postMessage({
Expand Down
192 changes: 192 additions & 0 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/**********************************************************************
* Copyright (C) 2024 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import { expect, test, vi, beforeEach, afterEach } from 'vitest';
import OpenAI from 'openai';
import { PlaygroundV2Manager } from './playgroundV2Manager';
import type { Webview } from '@podman-desktop/api';
import type { InferenceServer } from '@shared/src/models/IInference';
import type { InferenceManager } from './inference/inferenceManager';
import { Messages } from '@shared/Messages';

vi.mock('openai', () => ({
default: vi.fn(),
}));

const webviewMock = {
postMessage: vi.fn(),
} as unknown as Webview;

const inferenceManagerMock = {
get: vi.fn(),
} as unknown as InferenceManager;

beforeEach(() => {
vi.resetAllMocks();
vi.mocked(webviewMock.postMessage).mockResolvedValue(undefined);
vi.useFakeTimers();
});

afterEach(() => {
vi.useRealTimers();
});

test('manager should be properly initialized', () => {
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
expect(manager.getConversations().length).toBe(0);
});

test('submit should throw an error is the server is stopped', async () => {
vi.mocked(inferenceManagerMock.get).mockReturnValue({
status: 'stopped',
} as unknown as InferenceServer);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
await expect(
manager.submit('dummyContainerId', 'dummyModelId', 'dummyConversationId', 'dummyUserInput'),
).rejects.toThrowError('Inference server is not running.');
});

test('submit should throw an error is the server is unhealthy', async () => {
vi.mocked(inferenceManagerMock.get).mockReturnValue({
status: 'running',
health: {
Status: 'unhealthy',
},
} as unknown as InferenceServer);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
await expect(
manager.submit('dummyContainerId', 'dummyModelId', 'dummyConversationId', 'dummyUserInput'),
).rejects.toThrowError('Inference server is not healthy, currently status: unhealthy.');
});

test('submit should throw an error is the model id provided does not exist.', async () => {
vi.mocked(inferenceManagerMock.get).mockReturnValue({
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
},
],
} as unknown as InferenceServer);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
await expect(
manager.submit('dummyContainerId', 'invalidModelId', 'dummyConversationId', 'dummyUserInput'),
).rejects.toThrowError(
`modelId 'invalidModelId' is not available on the inference server, valid model ids are: dummyModelId.`,
);
});

test('submit should throw an error is the conversation id provided does not exist.', async () => {
vi.mocked(inferenceManagerMock.get).mockReturnValue({
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
} as unknown as InferenceServer);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
await expect(
manager.submit('dummyContainerId', 'dummyModelId', 'dummyConversationId', 'dummyUserInput'),
).rejects.toThrowError(`conversation with id dummyConversationId does not exist.`);
});

test('create conversation should create conversation.', async () => {
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
expect(manager.getConversations().length).toBe(0);
const conversationId = manager.createConversation();

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
expect(conversations[0].id).toBe(conversationId);
});

test('valid submit should create IPlaygroundMessage and notify the webview', async () => {
const createMock = vi.fn().mockResolvedValue([]);
vi.mocked(OpenAI).mockReturnValue({
chat: {
completions: {
create: createMock,
},
},
} as unknown as OpenAI);

vi.mocked(inferenceManagerMock.get).mockReturnValue({
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
connection: {
port: 8888,
},
} as unknown as InferenceServer);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock);
const conversationId = manager.createConversation();

const date = new Date(2000, 1, 1, 13);
vi.setSystemTime(date);

await manager.submit('dummyContainerId', 'dummyModelId', conversationId, 'dummyUserInput');

// Wait for assistant message to be completed
await vi.waitFor(() => {
expect(manager.getConversations()[0].messages[1].content).toBeDefined();
});

const conversations = manager.getConversations();

expect(conversations.length).toBe(1);
expect(conversations[0].messages.length).toBe(2);
expect(conversations[0].messages[0]).toStrictEqual({
content: 'dummyUserInput',
id: expect.anything(),
options: undefined,
role: 'user',
timestamp: date.getTime(),
});
expect(conversations[0].messages[1]).toStrictEqual({
choices: undefined,
completed: true,
content: '',
id: expect.anything(),
role: 'assistant',
timestamp: date.getTime(),
});

expect(webviewMock.postMessage).toHaveBeenLastCalledWith({
id: Messages.MSG_CONVERSATIONS_UPDATE,
body: conversations,
});
});
Loading

0 comments on commit a0bf692

Please sign in to comment.